Provide RestOperations in DefaultOAuth2UserService

Fixes gh-5600
This commit is contained in:
Joe Grandja 2018-08-28 07:53:49 -04:00 committed by Rob Winch
parent 25d1f49d84
commit 4a8c95a3e8
7 changed files with 498 additions and 203 deletions

View File

@ -15,11 +15,15 @@
*/
package org.springframework.security.oauth2.client.http;
import com.nimbusds.oauth2.sdk.token.BearerTokenError;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.util.StringUtils;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.ResponseErrorHandler;
@ -44,10 +48,39 @@ public class OAuth2ErrorResponseErrorHandler implements ResponseErrorHandler {
@Override
public void handleError(ClientHttpResponse response) throws IOException {
if (HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) {
OAuth2Error oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
if (!HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) {
this.defaultErrorHandler.handleError(response);
}
this.defaultErrorHandler.handleError(response);
// A Bearer Token Error may be in the WWW-Authenticate response header
// See https://tools.ietf.org/html/rfc6750#section-3
OAuth2Error oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders());
if (oauth2Error == null) {
oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
}
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
private OAuth2Error readErrorFromWwwAuthenticate(HttpHeaders headers) {
String wwwAuthenticateHeader = headers.getFirst(HttpHeaders.WWW_AUTHENTICATE);
if (!StringUtils.hasText(wwwAuthenticateHeader)) {
return null;
}
BearerTokenError bearerTokenError;
try {
bearerTokenError = BearerTokenError.parse(wwwAuthenticateHeader);
} catch (Exception ex) {
return null;
}
String errorCode = bearerTokenError.getCode() != null ?
bearerTokenError.getCode() : OAuth2ErrorCodes.SERVER_ERROR;
String errorDescription = bearerTokenError.getDescription();
String errorUri = bearerTokenError.getURI() != null ?
bearerTokenError.getURI().toString() : null;
return new OAuth2Error(errorCode, errorDescription, errorUri);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,7 +16,12 @@
package org.springframework.security.oauth2.client.userinfo;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
@ -24,8 +29,12 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import java.util.HashSet;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
@ -34,7 +43,7 @@ import java.util.Set;
* <p>
* For standard OAuth 2.0 Provider's, the attribute name used to access the user's name
* from the UserInfo response is required and therefore must be available via
* {@link org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}.
* {@link ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}.
* <p>
* <b>NOTE:</b> Attribute names are <b>not</b> standardized between providers and therefore will vary.
* Please consult the provider's API documentation for the set of supported user attribute names.
@ -48,8 +57,23 @@ import java.util.Set;
*/
public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserRequest, OAuth2User> {
private static final String MISSING_USER_INFO_URI_ERROR_CODE = "missing_user_info_uri";
private static final String MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE = "missing_user_name_attribute";
private NimbusUserInfoResponseClient userInfoResponseClient = new NimbusUserInfoResponseClient();
private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response";
private static final ParameterizedTypeReference<Map<String, Object>> PARAMETERIZED_RESPONSE_TYPE =
new ParameterizedTypeReference<Map<String, Object>>() {};
private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();
private RestOperations restOperations;
public DefaultOAuth2UserService() {
RestTemplate restTemplate = new RestTemplate();
restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());
this.restOperations = restTemplate;
}
@Override
public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
@ -64,7 +88,8 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName();
String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails()
.getUserInfoEndpoint().getUserNameAttributeName();
if (!StringUtils.hasText(userNameAttributeName)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE,
@ -75,13 +100,63 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
ParameterizedTypeReference<Map<String, Object>> typeReference =
new ParameterizedTypeReference<Map<String, Object>>() {};
Map<String, Object> userAttributes = this.userInfoResponseClient.getUserInfoResponse(userRequest, typeReference);
GrantedAuthority authority = new OAuth2UserAuthority(userAttributes);
Set<GrantedAuthority> authorities = new HashSet<>();
authorities.add(authority);
RequestEntity<?> request = this.requestEntityConverter.convert(userRequest);
ResponseEntity<Map<String, Object>> response;
try {
response = this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE);
} catch (OAuth2AuthenticationException ex) {
OAuth2Error oauth2Error = ex.getError();
StringBuilder errorDetails = new StringBuilder();
errorDetails.append("Error details: [");
errorDetails.append("UserInfo Uri: ").append(
userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri());
errorDetails.append(", Error Code: ").append(oauth2Error.getErrorCode());
if (oauth2Error.getDescription() != null) {
errorDetails.append(", Error Description: ").append(oauth2Error.getDescription());
}
errorDetails.append("]");
oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the UserInfo Resource: " + errorDetails.toString(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
} catch (RestClientException ex) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
"An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
}
Map<String, Object> userAttributes = response.getBody();
Set<GrantedAuthority> authorities = Collections.singleton(new OAuth2UserAuthority(userAttributes));
return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName);
}
/**
* Sets the {@link Converter} used for converting the {@link OAuth2UserRequest}
* to a {@link RequestEntity} representation of the UserInfo Request.
*
* @since 5.1
* @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request
*/
public final void setRequestEntityConverter(Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter) {
Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null");
this.requestEntityConverter = requestEntityConverter;
}
/**
* Sets the {@link RestOperations} used when requesting the UserInfo resource.
*
* <p>
* <b>NOTE:</b> At a minimum, the supplied {@code restOperations} must be configured with the following:
* <ol>
* <li>{@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}</li>
* </ol>
*
* @since 5.1
* @param restOperations the {@link RestOperations} used when requesting the UserInfo resource
*/
public final void setRestOperations(RestOperations restOperations) {
Assert.notNull(restOperations, "restOperations cannot be null");
this.restOperations = restOperations;
}
}

View File

@ -0,0 +1,81 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.userinfo;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.Collections;
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
/**
* A {@link Converter} that converts the provided {@link OAuth2UserRequest}
* to a {@link RequestEntity} representation of a request for the UserInfo Endpoint.
*
* @author Joe Grandja
* @since 5.1
* @see Converter
* @see OAuth2UserRequest
* @see RequestEntity
*/
public class OAuth2UserRequestEntityConverter implements Converter<OAuth2UserRequest, RequestEntity<?>> {
private static final MediaType DEFAULT_CONTENT_TYPE = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8");
/**
* Returns the {@link RequestEntity} used for the UserInfo Request.
*
* @param userRequest the user request
* @return the {@link RequestEntity} used for the UserInfo Request
*/
@Override
public RequestEntity<?> convert(OAuth2UserRequest userRequest) {
ClientRegistration clientRegistration = userRequest.getClientRegistration();
HttpMethod httpMethod = HttpMethod.GET;
if (AuthenticationMethod.FORM.equals(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod())) {
httpMethod = HttpMethod.POST;
}
HttpHeaders headers = new HttpHeaders();
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri())
.build()
.toUri();
RequestEntity<?> request;
if (HttpMethod.POST.equals(httpMethod)) {
headers.setContentType(DEFAULT_CONTENT_TYPE);
MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
formParameters.add(OAuth2ParameterNames.ACCESS_TOKEN, userRequest.getAccessToken().getTokenValue());
request = new RequestEntity<>(formParameters, headers, httpMethod, uri);
} else {
headers.setBearerAuth(userRequest.getAccessToken().getTokenValue());
request = new RequestEntity<>(headers, httpMethod, uri);
}
return request;
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.http;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@ -31,7 +32,7 @@ public class OAuth2ErrorResponseErrorHandlerTests {
private OAuth2ErrorResponseErrorHandler errorHandler = new OAuth2ErrorResponseErrorHandler();
@Test
public void handleErrorWhenStatusBadRequestThenHandled() {
public void handleErrorWhenErrorResponseBodyThenHandled() {
String errorResponse = "{\n" +
" \"error\": \"unauthorized_client\",\n" +
" \"error_description\": \"The client is not authorized\"\n" +
@ -44,4 +45,17 @@ public class OAuth2ErrorResponseErrorHandlerTests {
.isInstanceOf(OAuth2AuthenticationException.class)
.hasMessage("[unauthorized_client] The client is not authorized");
}
@Test
public void handleErrorWhenErrorResponseWwwAuthenticateHeaderThenHandled() {
String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\"";
MockClientHttpResponse response = new MockClientHttpResponse(
new byte[0], HttpStatus.BAD_REQUEST);
response.getHeaders().add(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader);
assertThatThrownBy(() -> this.errorHandler.handleError(response))
.isInstanceOf(OAuth2AuthenticationException.class)
.hasMessage("[insufficient_scope] The access token expired");
}
}

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client.oidc.userinfo;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@ -29,7 +30,6 @@ import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
import org.springframework.security.oauth2.core.AuthenticationMethod;
@ -71,12 +71,15 @@ public class OidcUserServiceTests {
private OAuth2AccessToken accessToken;
private OidcIdToken idToken;
private OidcUserService userService = new OidcUserService();
private MockWebServer server;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before
public void setUp() throws Exception {
public void setup() throws Exception {
this.server = new MockWebServer();
this.server.start();
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
@ -101,6 +104,11 @@ public class OidcUserServiceTests {
this.userService.setOauth2UserService(new DefaultOAuth2UserService());
}
@After
public void cleanup() throws Exception {
this.server.shutdown();
}
@Test
public void setOauth2UserServiceWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.userService.setOauth2UserService(null))
@ -135,9 +143,7 @@ public class OidcUserServiceTests {
}
@Test
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
MockWebServer server = new MockWebServer();
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
String userInfoResponse = "{\n" +
" \"sub\": \"subject1\",\n" +
" \"name\": \"first last\",\n" +
@ -146,13 +152,9 @@ public class OidcUserServiceTests {
" \"preferred_username\": \"user1\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
@ -160,8 +162,6 @@ public class OidcUserServiceTests {
OidcUser user = this.userService.loadUser(
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
server.shutdown();
assertThat(user.getIdToken()).isNotNull();
assertThat(user.getUserInfo()).isNotNull();
assertThat(user.getUserInfo().getClaims().size()).isEqualTo(6);
@ -184,69 +184,47 @@ public class OidcUserServiceTests {
// gh-5447
@Test
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() throws Exception {
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"email\": \"full_name@provider.com\",\n" +
" \"name\": \"full name\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
} finally {
server.shutdown();
}
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
}
@Test
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() throws Exception {
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"sub\": \"other-subject\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
} finally {
server.shutdown();
}
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
}
@Test
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
String userInfoResponse = "{\n" +
" \"sub\": \"subject1\",\n" +
@ -256,48 +234,35 @@ public class OidcUserServiceTests {
" \"preferred_username\": \"user1\",\n" +
" \"email\": \"user1@example.com\"\n";
// "}\n"; // Make the JSON invalid/malformed
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
} finally {
server.shutdown();
}
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
}
@Test
public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
MockWebServer server = new MockWebServer();
server.enqueue(new MockResponse().setResponseCode(500));
server.start();
this.server.enqueue(new MockResponse().setResponseCode(500));
String userInfoUri = server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
} finally {
server.shutdown();
}
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
}
@Test
public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
this.exception.expect(AuthenticationServiceException.class);
public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
String userInfoUri = "http://invalid-provider.com/user";
@ -308,9 +273,7 @@ public class OidcUserServiceTests {
}
@Test
public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() throws Exception {
MockWebServer server = new MockWebServer();
public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() {
String userInfoResponse = "{\n" +
" \"sub\": \"subject1\",\n" +
" \"name\": \"first last\",\n" +
@ -319,13 +282,9 @@ public class OidcUserServiceTests {
" \"preferred_username\": \"user1\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL);
@ -334,16 +293,12 @@ public class OidcUserServiceTests {
OidcUser user = this.userService.loadUser(
new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
server.shutdown();
assertThat(user.getName()).isEqualTo("user1@example.com");
}
// gh-5294
@Test
public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception {
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"sub\": \"subject1\",\n" +
" \"name\": \"first last\",\n" +
@ -352,28 +307,21 @@ public class OidcUserServiceTests {
" \"preferred_username\": \"user1\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
server.shutdown();
assertThat(server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
.isEqualTo(MediaType.APPLICATION_JSON_VALUE);
assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
.isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
}
// gh-5500
@Test
public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception {
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"sub\": \"subject1\",\n" +
" \"name\": \"first last\",\n" +
@ -382,31 +330,24 @@ public class OidcUserServiceTests {
" \"preferred_username\": \"user1\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
server.shutdown();
RecordedRequest request = server.takeRequest();
RecordedRequest request = this.server.takeRequest();
assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name());
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
}
// gh-5500
@Test
public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception {
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"sub\": \"subject1\",\n" +
" \"name\": \"first last\",\n" +
@ -415,24 +356,25 @@ public class OidcUserServiceTests {
" \"preferred_username\": \"user1\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM);
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
server.shutdown();
RecordedRequest request = server.takeRequest();
RecordedRequest request = this.server.takeRequest();
assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name());
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE)).contains(MediaType.APPLICATION_FORM_URLENCODED_VALUE);
assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue());
}
private MockResponse jsonResponse(String json) {
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
}
}

View File

@ -18,7 +18,7 @@ package org.springframework.security.oauth2.client.userinfo;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@ -30,7 +30,6 @@ import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
@ -59,12 +58,15 @@ public class DefaultOAuth2UserServiceTests {
private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
private OAuth2AccessToken accessToken;
private DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
private MockWebServer server;
@Rule
public ExpectedException exception = ExpectedException.none();
@Before
public void setUp() throws Exception {
public void setup() throws Exception {
this.server = new MockWebServer();
this.server.start();
this.clientRegistration = mock(ClientRegistration.class);
this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
@ -73,6 +75,23 @@ public class DefaultOAuth2UserServiceTests {
this.accessToken = mock(OAuth2AccessToken.class);
}
@After
public void cleanup() throws Exception {
this.server.shutdown();
}
@Test
public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
this.userService.setRequestEntityConverter(null);
}
@Test
public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
this.userService.setRestOperations(null);
}
@Test
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
this.exception.expect(IllegalArgumentException.class);
@ -99,9 +118,7 @@ public class DefaultOAuth2UserServiceTests {
}
@Test
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
MockWebServer server = new MockWebServer();
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
String userInfoResponse = "{\n" +
" \"user-name\": \"user1\",\n" +
" \"first-name\": \"first\",\n" +
@ -110,13 +127,9 @@ public class DefaultOAuth2UserServiceTests {
" \"address\": \"address\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
@ -125,8 +138,6 @@ public class DefaultOAuth2UserServiceTests {
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
server.shutdown();
assertThat(user.getName()).isEqualTo("user1");
assertThat(user.getAttributes().size()).isEqualTo(6);
assertThat(user.getAttributes().get("user-name")).isEqualTo("user1");
@ -144,11 +155,9 @@ public class DefaultOAuth2UserServiceTests {
}
@Test
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
MockWebServer server = new MockWebServer();
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
String userInfoResponse = "{\n" +
" \"user-name\": \"user1\",\n" +
@ -158,52 +167,83 @@ public class DefaultOAuth2UserServiceTests {
" \"address\": \"address\",\n" +
" \"email\": \"user1@example.com\"\n";
// "}\n"; // Make the JSON invalid/malformed
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
} finally {
server.shutdown();
}
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
}
@Test
public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
public void loadUserWhenUserInfoErrorResponseWwwAuthenticateHeaderThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response"));
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
this.exception.expectMessage(containsString("Error Code: insufficient_scope, Error Description: The access token expired"));
MockWebServer server = new MockWebServer();
server.enqueue(new MockResponse().setResponseCode(500));
server.start();
String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\"";
String userInfoUri = server.url("/user").toString();
MockResponse response = new MockResponse();
response.setHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader);
response.setResponseCode(400);
this.server.enqueue(response);
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
try {
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
} finally {
server.shutdown();
}
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
}
@Test
public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
this.exception.expect(AuthenticationServiceException.class);
public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
this.exception.expectMessage(containsString("Error Code: invalid_token"));
String userInfoErrorResponse = "{\n" +
" \"error\": \"invalid_token\"\n" +
"}\n";
this.server.enqueue(jsonResponse(userInfoErrorResponse).setResponseCode(400));
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
}
@Test
public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
this.server.enqueue(new MockResponse().setResponseCode(500));
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
}
@Test
public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
String userInfoUri = "http://invalid-provider.com/user";
@ -218,8 +258,6 @@ public class DefaultOAuth2UserServiceTests {
// gh-5294
@Test
public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception {
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"user-name\": \"user1\",\n" +
" \"first-name\": \"first\",\n" +
@ -228,13 +266,9 @@ public class DefaultOAuth2UserServiceTests {
" \"address\": \"address\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
@ -242,16 +276,13 @@ public class DefaultOAuth2UserServiceTests {
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
server.shutdown();
assertThat(server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
.isEqualTo(MediaType.APPLICATION_JSON_VALUE);
assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
.isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
}
// gh-5500
@Test
public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception {
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"user-name\": \"user1\",\n" +
" \"first-name\": \"first\",\n" +
@ -260,13 +291,9 @@ public class DefaultOAuth2UserServiceTests {
" \"address\": \"address\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
@ -274,18 +301,15 @@ public class DefaultOAuth2UserServiceTests {
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
server.shutdown();
RecordedRequest request = server.takeRequest();
RecordedRequest request = this.server.takeRequest();
assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name());
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
}
// gh-5500
@Test
public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception {
MockWebServer server = new MockWebServer();
String userInfoResponse = "{\n" +
" \"user-name\": \"user1\",\n" +
" \"first-name\": \"first\",\n" +
@ -294,13 +318,9 @@ public class DefaultOAuth2UserServiceTests {
" \"address\": \"address\",\n" +
" \"email\": \"user1@example.com\"\n" +
"}\n";
server.enqueue(new MockResponse()
.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.setBody(userInfoResponse));
this.server.enqueue(jsonResponse(userInfoResponse));
server.start();
String userInfoUri = server.url("/user").toString();
String userInfoUri = this.server.url("/user").toString();
when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM);
@ -308,11 +328,16 @@ public class DefaultOAuth2UserServiceTests {
when(this.accessToken.getTokenValue()).thenReturn("access-token");
this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
server.shutdown();
RecordedRequest request = server.takeRequest();
RecordedRequest request = this.server.takeRequest();
assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name());
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE)).contains(MediaType.APPLICATION_FORM_URLENCODED_VALUE);
assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue());
}
private MockResponse jsonResponse(String json) {
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
}
}

View File

@ -0,0 +1,125 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.userinfo;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.MultiValueMap;
import java.time.Instant;
import java.util.Arrays;
import java.util.LinkedHashSet;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
/**
* Tests for {@link OAuth2UserRequestEntityConverter}.
*
* @author Joe Grandja
*/
public class OAuth2UserRequestEntityConverterTests {
private OAuth2UserRequestEntityConverter converter = new OAuth2UserRequestEntityConverter();
private OAuth2UserRequest userRequest;
@Before
public void setup() {
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1")
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUriTemplate("https://client.com/callback/client-1")
.scope("read", "write")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/user")
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("id")
.build();
OAuth2AccessToken accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(),
Instant.now().plusSeconds(3600), new LinkedHashSet<>(Arrays.asList("read", "write")));
this.userRequest = new OAuth2UserRequest(clientRegistration, accessToken);
}
@SuppressWarnings("unchecked")
@Test
public void convertWhenAuthenticationMethodHeaderThenGetRequest() {
RequestEntity<?> requestEntity = this.converter.convert(this.userRequest);
ClientRegistration clientRegistration = this.userRequest.getClientRegistration();
assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.GET);
assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo(
clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri());
HttpHeaders headers = requestEntity.getHeaders();
assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8);
assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo(
"Bearer " + this.userRequest.getAccessToken().getTokenValue());
}
@SuppressWarnings("unchecked")
@Test
public void convertWhenAuthenticationMethodFormThenPostRequest() {
ClientRegistration clientRegistration = this.from(this.userRequest.getClientRegistration())
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.build();
OAuth2UserRequest userRequest = new OAuth2UserRequest(
clientRegistration, this.userRequest.getAccessToken());
RequestEntity<?> requestEntity = this.converter.convert(userRequest);
assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST);
assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo(
clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri());
HttpHeaders headers = requestEntity.getHeaders();
assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8);
assertThat(headers.getContentType()).isEqualTo(
MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"));
MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody();
assertThat(formParameters.getFirst(OAuth2ParameterNames.ACCESS_TOKEN)).isEqualTo(
this.userRequest.getAccessToken().getTokenValue());
}
private ClientRegistration.Builder from(ClientRegistration registration) {
return ClientRegistration.withRegistrationId(registration.getRegistrationId())
.clientId(registration.getClientId())
.clientSecret(registration.getClientSecret())
.clientAuthenticationMethod(registration.getClientAuthenticationMethod())
.authorizationGrantType(registration.getAuthorizationGrantType())
.redirectUriTemplate(registration.getRedirectUriTemplate())
.scope(registration.getScopes())
.authorizationUri(registration.getProviderDetails().getAuthorizationUri())
.tokenUri(registration.getProviderDetails().getTokenUri())
.userInfoUri(registration.getProviderDetails().getUserInfoEndpoint().getUri())
.userNameAttributeName(registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName())
.clientName(registration.getClientName());
}
}