Provide RestOperations in DefaultOAuth2UserService
Fixes gh-5600
This commit is contained in:
parent
25d1f49d84
commit
4a8c95a3e8
|
@ -15,11 +15,15 @@
|
|||
*/
|
||||
package org.springframework.security.oauth2.client.http;
|
||||
|
||||
import com.nimbusds.oauth2.sdk.token.BearerTokenError;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.client.ClientHttpResponse;
|
||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
||||
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.client.DefaultResponseErrorHandler;
|
||||
import org.springframework.web.client.ResponseErrorHandler;
|
||||
|
||||
|
@ -44,10 +48,39 @@ public class OAuth2ErrorResponseErrorHandler implements ResponseErrorHandler {
|
|||
|
||||
@Override
|
||||
public void handleError(ClientHttpResponse response) throws IOException {
|
||||
if (HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) {
|
||||
OAuth2Error oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
|
||||
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
|
||||
if (!HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) {
|
||||
this.defaultErrorHandler.handleError(response);
|
||||
}
|
||||
this.defaultErrorHandler.handleError(response);
|
||||
|
||||
// A Bearer Token Error may be in the WWW-Authenticate response header
|
||||
// See https://tools.ietf.org/html/rfc6750#section-3
|
||||
OAuth2Error oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders());
|
||||
if (oauth2Error == null) {
|
||||
oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
|
||||
}
|
||||
|
||||
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
|
||||
}
|
||||
|
||||
private OAuth2Error readErrorFromWwwAuthenticate(HttpHeaders headers) {
|
||||
String wwwAuthenticateHeader = headers.getFirst(HttpHeaders.WWW_AUTHENTICATE);
|
||||
if (!StringUtils.hasText(wwwAuthenticateHeader)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
BearerTokenError bearerTokenError;
|
||||
try {
|
||||
bearerTokenError = BearerTokenError.parse(wwwAuthenticateHeader);
|
||||
} catch (Exception ex) {
|
||||
return null;
|
||||
}
|
||||
|
||||
String errorCode = bearerTokenError.getCode() != null ?
|
||||
bearerTokenError.getCode() : OAuth2ErrorCodes.SERVER_ERROR;
|
||||
String errorDescription = bearerTokenError.getDescription();
|
||||
String errorUri = bearerTokenError.getURI() != null ?
|
||||
bearerTokenError.getURI().toString() : null;
|
||||
|
||||
return new OAuth2Error(errorCode, errorDescription, errorUri);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,7 +16,12 @@
|
|||
package org.springframework.security.oauth2.client.userinfo;
|
||||
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.convert.converter.Converter;
|
||||
import org.springframework.http.RequestEntity;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.security.core.GrantedAuthority;
|
||||
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
|
||||
|
@ -24,8 +29,12 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
|
|||
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.client.ResponseErrorHandler;
|
||||
import org.springframework.web.client.RestClientException;
|
||||
import org.springframework.web.client.RestOperations;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
|
@ -34,7 +43,7 @@ import java.util.Set;
|
|||
* <p>
|
||||
* For standard OAuth 2.0 Provider's, the attribute name used to access the user's name
|
||||
* from the UserInfo response is required and therefore must be available via
|
||||
* {@link org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}.
|
||||
* {@link ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}.
|
||||
* <p>
|
||||
* <b>NOTE:</b> Attribute names are <b>not</b> standardized between providers and therefore will vary.
|
||||
* Please consult the provider's API documentation for the set of supported user attribute names.
|
||||
|
@ -48,8 +57,23 @@ import java.util.Set;
|
|||
*/
|
||||
public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserRequest, OAuth2User> {
|
||||
private static final String MISSING_USER_INFO_URI_ERROR_CODE = "missing_user_info_uri";
|
||||
|
||||
private static final String MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE = "missing_user_name_attribute";
|
||||
private NimbusUserInfoResponseClient userInfoResponseClient = new NimbusUserInfoResponseClient();
|
||||
|
||||
private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response";
|
||||
|
||||
private static final ParameterizedTypeReference<Map<String, Object>> PARAMETERIZED_RESPONSE_TYPE =
|
||||
new ParameterizedTypeReference<Map<String, Object>>() {};
|
||||
|
||||
private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();
|
||||
|
||||
private RestOperations restOperations;
|
||||
|
||||
public DefaultOAuth2UserService() {
|
||||
RestTemplate restTemplate = new RestTemplate();
|
||||
restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());
|
||||
this.restOperations = restTemplate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
|
||||
|
@ -64,7 +88,8 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
|
|||
);
|
||||
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
|
||||
}
|
||||
String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName();
|
||||
String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails()
|
||||
.getUserInfoEndpoint().getUserNameAttributeName();
|
||||
if (!StringUtils.hasText(userNameAttributeName)) {
|
||||
OAuth2Error oauth2Error = new OAuth2Error(
|
||||
MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE,
|
||||
|
@ -75,13 +100,63 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
|
|||
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
|
||||
}
|
||||
|
||||
ParameterizedTypeReference<Map<String, Object>> typeReference =
|
||||
new ParameterizedTypeReference<Map<String, Object>>() {};
|
||||
Map<String, Object> userAttributes = this.userInfoResponseClient.getUserInfoResponse(userRequest, typeReference);
|
||||
GrantedAuthority authority = new OAuth2UserAuthority(userAttributes);
|
||||
Set<GrantedAuthority> authorities = new HashSet<>();
|
||||
authorities.add(authority);
|
||||
RequestEntity<?> request = this.requestEntityConverter.convert(userRequest);
|
||||
|
||||
ResponseEntity<Map<String, Object>> response;
|
||||
try {
|
||||
response = this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE);
|
||||
} catch (OAuth2AuthenticationException ex) {
|
||||
OAuth2Error oauth2Error = ex.getError();
|
||||
StringBuilder errorDetails = new StringBuilder();
|
||||
errorDetails.append("Error details: [");
|
||||
errorDetails.append("UserInfo Uri: ").append(
|
||||
userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri());
|
||||
errorDetails.append(", Error Code: ").append(oauth2Error.getErrorCode());
|
||||
if (oauth2Error.getDescription() != null) {
|
||||
errorDetails.append(", Error Description: ").append(oauth2Error.getDescription());
|
||||
}
|
||||
errorDetails.append("]");
|
||||
oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
|
||||
"An error occurred while attempting to retrieve the UserInfo Resource: " + errorDetails.toString(), null);
|
||||
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
|
||||
} catch (RestClientException ex) {
|
||||
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
|
||||
"An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null);
|
||||
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
|
||||
}
|
||||
|
||||
Map<String, Object> userAttributes = response.getBody();
|
||||
Set<GrantedAuthority> authorities = Collections.singleton(new OAuth2UserAuthority(userAttributes));
|
||||
|
||||
return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link Converter} used for converting the {@link OAuth2UserRequest}
|
||||
* to a {@link RequestEntity} representation of the UserInfo Request.
|
||||
*
|
||||
* @since 5.1
|
||||
* @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request
|
||||
*/
|
||||
public final void setRequestEntityConverter(Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter) {
|
||||
Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null");
|
||||
this.requestEntityConverter = requestEntityConverter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link RestOperations} used when requesting the UserInfo resource.
|
||||
*
|
||||
* <p>
|
||||
* <b>NOTE:</b> At a minimum, the supplied {@code restOperations} must be configured with the following:
|
||||
* <ol>
|
||||
* <li>{@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}</li>
|
||||
* </ol>
|
||||
*
|
||||
* @since 5.1
|
||||
* @param restOperations the {@link RestOperations} used when requesting the UserInfo resource
|
||||
*/
|
||||
public final void setRestOperations(RestOperations restOperations) {
|
||||
Assert.notNull(restOperations, "restOperations cannot be null");
|
||||
this.restOperations = restOperations;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.oauth2.client.userinfo;
|
||||
|
||||
import org.springframework.core.convert.converter.Converter;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.RequestEntity;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.AuthenticationMethod;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
|
||||
|
||||
/**
|
||||
* A {@link Converter} that converts the provided {@link OAuth2UserRequest}
|
||||
* to a {@link RequestEntity} representation of a request for the UserInfo Endpoint.
|
||||
*
|
||||
* @author Joe Grandja
|
||||
* @since 5.1
|
||||
* @see Converter
|
||||
* @see OAuth2UserRequest
|
||||
* @see RequestEntity
|
||||
*/
|
||||
public class OAuth2UserRequestEntityConverter implements Converter<OAuth2UserRequest, RequestEntity<?>> {
|
||||
private static final MediaType DEFAULT_CONTENT_TYPE = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8");
|
||||
|
||||
/**
|
||||
* Returns the {@link RequestEntity} used for the UserInfo Request.
|
||||
*
|
||||
* @param userRequest the user request
|
||||
* @return the {@link RequestEntity} used for the UserInfo Request
|
||||
*/
|
||||
@Override
|
||||
public RequestEntity<?> convert(OAuth2UserRequest userRequest) {
|
||||
ClientRegistration clientRegistration = userRequest.getClientRegistration();
|
||||
|
||||
HttpMethod httpMethod = HttpMethod.GET;
|
||||
if (AuthenticationMethod.FORM.equals(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod())) {
|
||||
httpMethod = HttpMethod.POST;
|
||||
}
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
|
||||
URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri())
|
||||
.build()
|
||||
.toUri();
|
||||
|
||||
RequestEntity<?> request;
|
||||
if (HttpMethod.POST.equals(httpMethod)) {
|
||||
headers.setContentType(DEFAULT_CONTENT_TYPE);
|
||||
MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
|
||||
formParameters.add(OAuth2ParameterNames.ACCESS_TOKEN, userRequest.getAccessToken().getTokenValue());
|
||||
request = new RequestEntity<>(formParameters, headers, httpMethod, uri);
|
||||
} else {
|
||||
headers.setBearerAuth(userRequest.getAccessToken().getTokenValue());
|
||||
request = new RequestEntity<>(headers, httpMethod, uri);
|
||||
}
|
||||
|
||||
return request;
|
||||
}
|
||||
}
|
|
@ -16,6 +16,7 @@
|
|||
package org.springframework.security.oauth2.client.http;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.mock.http.client.MockClientHttpResponse;
|
||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||
|
@ -31,7 +32,7 @@ public class OAuth2ErrorResponseErrorHandlerTests {
|
|||
private OAuth2ErrorResponseErrorHandler errorHandler = new OAuth2ErrorResponseErrorHandler();
|
||||
|
||||
@Test
|
||||
public void handleErrorWhenStatusBadRequestThenHandled() {
|
||||
public void handleErrorWhenErrorResponseBodyThenHandled() {
|
||||
String errorResponse = "{\n" +
|
||||
" \"error\": \"unauthorized_client\",\n" +
|
||||
" \"error_description\": \"The client is not authorized\"\n" +
|
||||
|
@ -44,4 +45,17 @@ public class OAuth2ErrorResponseErrorHandlerTests {
|
|||
.isInstanceOf(OAuth2AuthenticationException.class)
|
||||
.hasMessage("[unauthorized_client] The client is not authorized");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handleErrorWhenErrorResponseWwwAuthenticateHeaderThenHandled() {
|
||||
String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\"";
|
||||
|
||||
MockClientHttpResponse response = new MockClientHttpResponse(
|
||||
new byte[0], HttpStatus.BAD_REQUEST);
|
||||
response.getHeaders().add(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader);
|
||||
|
||||
assertThatThrownBy(() -> this.errorHandler.handleError(response))
|
||||
.isInstanceOf(OAuth2AuthenticationException.class)
|
||||
.hasMessage("[insufficient_scope] The access token expired");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client.oidc.userinfo;
|
|||
import okhttp3.mockwebserver.MockResponse;
|
||||
import okhttp3.mockwebserver.MockWebServer;
|
||||
import okhttp3.mockwebserver.RecordedRequest;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
|
@ -29,7 +30,6 @@ import org.powermock.modules.junit4.PowerMockRunner;
|
|||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.security.authentication.AuthenticationServiceException;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
|
||||
import org.springframework.security.oauth2.core.AuthenticationMethod;
|
||||
|
@ -71,12 +71,15 @@ public class OidcUserServiceTests {
|
|||
private OAuth2AccessToken accessToken;
|
||||
private OidcIdToken idToken;
|
||||
private OidcUserService userService = new OidcUserService();
|
||||
private MockWebServer server;
|
||||
|
||||
@Rule
|
||||
public ExpectedException exception = ExpectedException.none();
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
public void setup() throws Exception {
|
||||
this.server = new MockWebServer();
|
||||
this.server.start();
|
||||
this.clientRegistration = mock(ClientRegistration.class);
|
||||
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
|
||||
this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
|
||||
|
@ -101,6 +104,11 @@ public class OidcUserServiceTests {
|
|||
this.userService.setOauth2UserService(new DefaultOAuth2UserService());
|
||||
}
|
||||
|
||||
@After
|
||||
public void cleanup() throws Exception {
|
||||
this.server.shutdown();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setOauth2UserServiceWhenNullThenThrowIllegalArgumentException() {
|
||||
assertThatThrownBy(() -> this.userService.setOauth2UserService(null))
|
||||
|
@ -135,9 +143,7 @@ public class OidcUserServiceTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"sub\": \"subject1\",\n" +
|
||||
" \"name\": \"first last\",\n" +
|
||||
|
@ -146,13 +152,9 @@ public class OidcUserServiceTests {
|
|||
" \"preferred_username\": \"user1\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
@ -160,8 +162,6 @@ public class OidcUserServiceTests {
|
|||
OidcUser user = this.userService.loadUser(
|
||||
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
|
||||
server.shutdown();
|
||||
|
||||
assertThat(user.getIdToken()).isNotNull();
|
||||
assertThat(user.getUserInfo()).isNotNull();
|
||||
assertThat(user.getUserInfo().getClaims().size()).isEqualTo(6);
|
||||
|
@ -184,69 +184,47 @@ public class OidcUserServiceTests {
|
|||
|
||||
// gh-5447
|
||||
@Test
|
||||
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() throws Exception {
|
||||
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() {
|
||||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("invalid_user_info_response"));
|
||||
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"email\": \"full_name@provider.com\",\n" +
|
||||
" \"name\": \"full name\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL);
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
try {
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
} finally {
|
||||
server.shutdown();
|
||||
}
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() throws Exception {
|
||||
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() {
|
||||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("invalid_user_info_response"));
|
||||
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"sub\": \"other-subject\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
try {
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
} finally {
|
||||
server.shutdown();
|
||||
}
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
|
||||
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
|
||||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("invalid_user_info_response"));
|
||||
|
||||
MockWebServer server = new MockWebServer();
|
||||
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
|
||||
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"sub\": \"subject1\",\n" +
|
||||
|
@ -256,48 +234,35 @@ public class OidcUserServiceTests {
|
|||
" \"preferred_username\": \"user1\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n";
|
||||
// "}\n"; // Make the JSON invalid/malformed
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
try {
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
} finally {
|
||||
server.shutdown();
|
||||
}
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
|
||||
public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
|
||||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("invalid_user_info_response"));
|
||||
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
|
||||
|
||||
MockWebServer server = new MockWebServer();
|
||||
server.enqueue(new MockResponse().setResponseCode(500));
|
||||
server.start();
|
||||
this.server.enqueue(new MockResponse().setResponseCode(500));
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
try {
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
} finally {
|
||||
server.shutdown();
|
||||
}
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
|
||||
this.exception.expect(AuthenticationServiceException.class);
|
||||
public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
|
||||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
|
||||
|
||||
String userInfoUri = "http://invalid-provider.com/user";
|
||||
|
||||
|
@ -308,9 +273,7 @@ public class OidcUserServiceTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() throws Exception {
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() {
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"sub\": \"subject1\",\n" +
|
||||
" \"name\": \"first last\",\n" +
|
||||
|
@ -319,13 +282,9 @@ public class OidcUserServiceTests {
|
|||
" \"preferred_username\": \"user1\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL);
|
||||
|
@ -334,16 +293,12 @@ public class OidcUserServiceTests {
|
|||
OidcUser user = this.userService.loadUser(
|
||||
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
|
||||
server.shutdown();
|
||||
|
||||
assertThat(user.getName()).isEqualTo("user1@example.com");
|
||||
}
|
||||
|
||||
// gh-5294
|
||||
@Test
|
||||
public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception {
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"sub\": \"subject1\",\n" +
|
||||
" \"name\": \"first last\",\n" +
|
||||
|
@ -352,28 +307,21 @@ public class OidcUserServiceTests {
|
|||
" \"preferred_username\": \"user1\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
server.shutdown();
|
||||
assertThat(server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
|
||||
.isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
||||
assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
|
||||
.isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
|
||||
}
|
||||
|
||||
// gh-5500
|
||||
@Test
|
||||
public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception {
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"sub\": \"subject1\",\n" +
|
||||
" \"name\": \"first last\",\n" +
|
||||
|
@ -382,31 +330,24 @@ public class OidcUserServiceTests {
|
|||
" \"preferred_username\": \"user1\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
server.shutdown();
|
||||
RecordedRequest request = server.takeRequest();
|
||||
RecordedRequest request = this.server.takeRequest();
|
||||
assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name());
|
||||
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
||||
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
|
||||
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
|
||||
}
|
||||
|
||||
// gh-5500
|
||||
@Test
|
||||
public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception {
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"sub\": \"subject1\",\n" +
|
||||
" \"name\": \"first last\",\n" +
|
||||
|
@ -415,24 +356,25 @@ public class OidcUserServiceTests {
|
|||
" \"preferred_username\": \"user1\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM);
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
||||
server.shutdown();
|
||||
RecordedRequest request = server.takeRequest();
|
||||
RecordedRequest request = this.server.takeRequest();
|
||||
assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name());
|
||||
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
||||
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
|
||||
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE)).contains(MediaType.APPLICATION_FORM_URLENCODED_VALUE);
|
||||
assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue());
|
||||
}
|
||||
|
||||
private MockResponse jsonResponse(String json) {
|
||||
return new MockResponse()
|
||||
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(json);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.springframework.security.oauth2.client.userinfo;
|
|||
import okhttp3.mockwebserver.MockResponse;
|
||||
import okhttp3.mockwebserver.MockWebServer;
|
||||
import okhttp3.mockwebserver.RecordedRequest;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
|
@ -30,7 +30,6 @@ import org.powermock.modules.junit4.PowerMockRunner;
|
|||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.security.authentication.AuthenticationServiceException;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.AuthenticationMethod;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
|
@ -59,12 +58,15 @@ public class DefaultOAuth2UserServiceTests {
|
|||
private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
|
||||
private OAuth2AccessToken accessToken;
|
||||
private DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
|
||||
private MockWebServer server;
|
||||
|
||||
@Rule
|
||||
public ExpectedException exception = ExpectedException.none();
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
public void setup() throws Exception {
|
||||
this.server = new MockWebServer();
|
||||
this.server.start();
|
||||
this.clientRegistration = mock(ClientRegistration.class);
|
||||
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
|
||||
this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
|
||||
|
@ -73,6 +75,23 @@ public class DefaultOAuth2UserServiceTests {
|
|||
this.accessToken = mock(OAuth2AccessToken.class);
|
||||
}
|
||||
|
||||
@After
|
||||
public void cleanup() throws Exception {
|
||||
this.server.shutdown();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() {
|
||||
this.exception.expect(IllegalArgumentException.class);
|
||||
this.userService.setRequestEntityConverter(null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
|
||||
this.exception.expect(IllegalArgumentException.class);
|
||||
this.userService.setRestOperations(null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
|
||||
this.exception.expect(IllegalArgumentException.class);
|
||||
|
@ -99,9 +118,7 @@ public class DefaultOAuth2UserServiceTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"user-name\": \"user1\",\n" +
|
||||
" \"first-name\": \"first\",\n" +
|
||||
|
@ -110,13 +127,9 @@ public class DefaultOAuth2UserServiceTests {
|
|||
" \"address\": \"address\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
|
||||
|
@ -125,8 +138,6 @@ public class DefaultOAuth2UserServiceTests {
|
|||
|
||||
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
|
||||
|
||||
server.shutdown();
|
||||
|
||||
assertThat(user.getName()).isEqualTo("user1");
|
||||
assertThat(user.getAttributes().size()).isEqualTo(6);
|
||||
assertThat(user.getAttributes().get("user-name")).isEqualTo("user1");
|
||||
|
@ -144,11 +155,9 @@ public class DefaultOAuth2UserServiceTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
|
||||
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
|
||||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("invalid_user_info_response"));
|
||||
|
||||
MockWebServer server = new MockWebServer();
|
||||
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
|
||||
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"user-name\": \"user1\",\n" +
|
||||
|
@ -158,52 +167,83 @@ public class DefaultOAuth2UserServiceTests {
|
|||
" \"address\": \"address\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n";
|
||||
// "}\n"; // Make the JSON invalid/malformed
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
|
||||
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
try {
|
||||
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
|
||||
} finally {
|
||||
server.shutdown();
|
||||
}
|
||||
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
|
||||
public void loadUserWhenUserInfoErrorResponseWwwAuthenticateHeaderThenThrowOAuth2AuthenticationException() {
|
||||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("invalid_user_info_response"));
|
||||
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
|
||||
this.exception.expectMessage(containsString("Error Code: insufficient_scope, Error Description: The access token expired"));
|
||||
|
||||
MockWebServer server = new MockWebServer();
|
||||
server.enqueue(new MockResponse().setResponseCode(500));
|
||||
server.start();
|
||||
String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\"";
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
MockResponse response = new MockResponse();
|
||||
response.setHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader);
|
||||
response.setResponseCode(400);
|
||||
this.server.enqueue(response);
|
||||
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
|
||||
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
try {
|
||||
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
|
||||
} finally {
|
||||
server.shutdown();
|
||||
}
|
||||
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
|
||||
this.exception.expect(AuthenticationServiceException.class);
|
||||
public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() {
|
||||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
|
||||
this.exception.expectMessage(containsString("Error Code: invalid_token"));
|
||||
|
||||
String userInfoErrorResponse = "{\n" +
|
||||
" \"error\": \"invalid_token\"\n" +
|
||||
"}\n";
|
||||
this.server.enqueue(jsonResponse(userInfoErrorResponse).setResponseCode(400));
|
||||
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
|
||||
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
|
||||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
|
||||
|
||||
this.server.enqueue(new MockResponse().setResponseCode(500));
|
||||
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
|
||||
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
|
||||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
|
||||
this.exception.expect(OAuth2AuthenticationException.class);
|
||||
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
|
||||
|
||||
String userInfoUri = "http://invalid-provider.com/user";
|
||||
|
||||
|
@ -218,8 +258,6 @@ public class DefaultOAuth2UserServiceTests {
|
|||
// gh-5294
|
||||
@Test
|
||||
public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception {
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"user-name\": \"user1\",\n" +
|
||||
" \"first-name\": \"first\",\n" +
|
||||
|
@ -228,13 +266,9 @@ public class DefaultOAuth2UserServiceTests {
|
|||
" \"address\": \"address\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
|
||||
|
@ -242,16 +276,13 @@ public class DefaultOAuth2UserServiceTests {
|
|||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
|
||||
server.shutdown();
|
||||
assertThat(server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
|
||||
.isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
||||
assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
|
||||
.isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
|
||||
}
|
||||
|
||||
// gh-5500
|
||||
@Test
|
||||
public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception {
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"user-name\": \"user1\",\n" +
|
||||
" \"first-name\": \"first\",\n" +
|
||||
|
@ -260,13 +291,9 @@ public class DefaultOAuth2UserServiceTests {
|
|||
" \"address\": \"address\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
|
||||
|
@ -274,18 +301,15 @@ public class DefaultOAuth2UserServiceTests {
|
|||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
|
||||
server.shutdown();
|
||||
RecordedRequest request = server.takeRequest();
|
||||
RecordedRequest request = this.server.takeRequest();
|
||||
assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name());
|
||||
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
||||
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
|
||||
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
|
||||
}
|
||||
|
||||
// gh-5500
|
||||
@Test
|
||||
public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception {
|
||||
MockWebServer server = new MockWebServer();
|
||||
|
||||
String userInfoResponse = "{\n" +
|
||||
" \"user-name\": \"user1\",\n" +
|
||||
" \"first-name\": \"first\",\n" +
|
||||
|
@ -294,13 +318,9 @@ public class DefaultOAuth2UserServiceTests {
|
|||
" \"address\": \"address\",\n" +
|
||||
" \"email\": \"user1@example.com\"\n" +
|
||||
"}\n";
|
||||
server.enqueue(new MockResponse()
|
||||
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(userInfoResponse));
|
||||
this.server.enqueue(jsonResponse(userInfoResponse));
|
||||
|
||||
server.start();
|
||||
|
||||
String userInfoUri = server.url("/user").toString();
|
||||
String userInfoUri = this.server.url("/user").toString();
|
||||
|
||||
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
||||
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM);
|
||||
|
@ -308,11 +328,16 @@ public class DefaultOAuth2UserServiceTests {
|
|||
when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
||||
|
||||
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
|
||||
server.shutdown();
|
||||
RecordedRequest request = server.takeRequest();
|
||||
RecordedRequest request = this.server.takeRequest();
|
||||
assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name());
|
||||
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
||||
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
|
||||
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE)).contains(MediaType.APPLICATION_FORM_URLENCODED_VALUE);
|
||||
assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue());
|
||||
}
|
||||
|
||||
private MockResponse jsonResponse(String json) {
|
||||
return new MockResponse()
|
||||
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
|
||||
.setBody(json);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.security.oauth2.client.userinfo;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.RequestEntity;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.core.AuthenticationMethod;
|
||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.LinkedHashSet;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
|
||||
|
||||
/**
|
||||
* Tests for {@link OAuth2UserRequestEntityConverter}.
|
||||
*
|
||||
* @author Joe Grandja
|
||||
*/
|
||||
public class OAuth2UserRequestEntityConverterTests {
|
||||
private OAuth2UserRequestEntityConverter converter = new OAuth2UserRequestEntityConverter();
|
||||
private OAuth2UserRequest userRequest;
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1")
|
||||
.clientId("client-1")
|
||||
.clientSecret("secret")
|
||||
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
||||
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
||||
.redirectUriTemplate("https://client.com/callback/client-1")
|
||||
.scope("read", "write")
|
||||
.authorizationUri("https://provider.com/oauth2/authorize")
|
||||
.tokenUri("https://provider.com/oauth2/token")
|
||||
.userInfoUri("https://provider.com/user")
|
||||
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
|
||||
.userNameAttributeName("id")
|
||||
.build();
|
||||
OAuth2AccessToken accessToken = new OAuth2AccessToken(
|
||||
OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(),
|
||||
Instant.now().plusSeconds(3600), new LinkedHashSet<>(Arrays.asList("read", "write")));
|
||||
this.userRequest = new OAuth2UserRequest(clientRegistration, accessToken);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Test
|
||||
public void convertWhenAuthenticationMethodHeaderThenGetRequest() {
|
||||
RequestEntity<?> requestEntity = this.converter.convert(this.userRequest);
|
||||
|
||||
ClientRegistration clientRegistration = this.userRequest.getClientRegistration();
|
||||
|
||||
assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.GET);
|
||||
assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo(
|
||||
clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri());
|
||||
|
||||
HttpHeaders headers = requestEntity.getHeaders();
|
||||
assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8);
|
||||
assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo(
|
||||
"Bearer " + this.userRequest.getAccessToken().getTokenValue());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Test
|
||||
public void convertWhenAuthenticationMethodFormThenPostRequest() {
|
||||
ClientRegistration clientRegistration = this.from(this.userRequest.getClientRegistration())
|
||||
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
|
||||
.build();
|
||||
OAuth2UserRequest userRequest = new OAuth2UserRequest(
|
||||
clientRegistration, this.userRequest.getAccessToken());
|
||||
|
||||
RequestEntity<?> requestEntity = this.converter.convert(userRequest);
|
||||
|
||||
assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST);
|
||||
assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo(
|
||||
clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri());
|
||||
|
||||
HttpHeaders headers = requestEntity.getHeaders();
|
||||
assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8);
|
||||
assertThat(headers.getContentType()).isEqualTo(
|
||||
MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"));
|
||||
|
||||
MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody();
|
||||
assertThat(formParameters.getFirst(OAuth2ParameterNames.ACCESS_TOKEN)).isEqualTo(
|
||||
this.userRequest.getAccessToken().getTokenValue());
|
||||
}
|
||||
|
||||
private ClientRegistration.Builder from(ClientRegistration registration) {
|
||||
return ClientRegistration.withRegistrationId(registration.getRegistrationId())
|
||||
.clientId(registration.getClientId())
|
||||
.clientSecret(registration.getClientSecret())
|
||||
.clientAuthenticationMethod(registration.getClientAuthenticationMethod())
|
||||
.authorizationGrantType(registration.getAuthorizationGrantType())
|
||||
.redirectUriTemplate(registration.getRedirectUriTemplate())
|
||||
.scope(registration.getScopes())
|
||||
.authorizationUri(registration.getProviderDetails().getAuthorizationUri())
|
||||
.tokenUri(registration.getProviderDetails().getTokenUri())
|
||||
.userInfoUri(registration.getProviderDetails().getUserInfoEndpoint().getUri())
|
||||
.userNameAttributeName(registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName())
|
||||
.clientName(registration.getClientName());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue