diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClient.java new file mode 100644 index 0000000000..cb937ea12d --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClient.java @@ -0,0 +1,171 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials; + +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.util.CollectionUtils; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.ExchangeFilterFunctions; +import org.springframework.web.reactive.function.client.WebClient; + +import com.nimbusds.oauth2.sdk.AccessTokenResponse; +import com.nimbusds.oauth2.sdk.ErrorObject; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.TokenErrorResponse; +import com.nimbusds.oauth2.sdk.TokenResponse; +import com.nimbusds.oauth2.sdk.token.AccessToken; + +import net.minidev.json.JSONObject; +import reactor.core.publisher.Mono; + +/** + * An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges" + * an authorization code credential for an access token credential + * at the Authorization Server's Token Endpoint. + * + *

+ * NOTE: This implementation uses the Nimbus OAuth 2.0 SDK internally. + * + * @author Rob Winch + * @since 5.1 + * @see OAuth2AccessTokenResponseClient + * @see OAuth2AuthorizationCodeGrantRequest + * @see OAuth2AccessTokenResponse + * @see Nimbus OAuth 2.0 SDK + * @see Section 4.1.3 Access Token Request (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response (Authorization Code Grant) + */ +public class NimbusReactiveAuthorizationCodeTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient { + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; + + private WebClient webClient = WebClient.builder() + .filter(ExchangeFilterFunctions.basicAuthentication()) + .build(); + + @Override + public Mono getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) + throws OAuth2AuthenticationException { + + return Mono.defer(() -> { + ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); + + OAuth2AuthorizationExchange authorizationExchange = authorizationGrantRequest.getAuthorizationExchange(); + String tokenUri = clientRegistration.getProviderDetails().getTokenUri(); + BodyInserters.FormInserter body = body(authorizationExchange); + + return this.webClient.post() + .uri(tokenUri) + .accept(MediaType.APPLICATION_JSON) + .attributes(basicAuthenticationCredentials(clientRegistration.getClientId(), clientRegistration.getClientSecret())) + .body(body) + .retrieve() + .onStatus(s -> false, response -> { + throw new IllegalStateException("Disabled Status Handlers"); + }) + .bodyToMono(new ParameterizedTypeReference>() {}) + .map(json -> parse(json)) + .flatMap(tokenResponse -> accessTokenResponse(tokenResponse)) + .map(accessTokenResponse -> { + AccessToken accessToken = accessTokenResponse.getTokens().getAccessToken(); + OAuth2AccessToken.TokenType accessTokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase( + accessToken.getType().getValue())) { + accessTokenType = OAuth2AccessToken.TokenType.BEARER; + } + long expiresIn = accessToken.getLifetime(); + + // As per spec, in section 5.1 Successful Access Token Response + // https://tools.ietf.org/html/rfc6749#section-5.1 + // If AccessTokenResponse.scope is empty, then default to the scope + // originally requested by the client in the Authorization Request + Set scopes; + if (CollectionUtils.isEmpty( + accessToken.getScope())) { + scopes = new LinkedHashSet<>( + authorizationExchange.getAuthorizationRequest().getScopes()); + } + else { + scopes = new LinkedHashSet<>( + accessToken.getScope().toStringList()); + } + + Map additionalParameters = new LinkedHashMap<>( + accessTokenResponse.getCustomParameters()); + + return OAuth2AccessTokenResponse.withToken(accessToken.getValue()) + .tokenType(accessTokenType) + .expiresIn(expiresIn) + .scopes(scopes) + .additionalParameters(additionalParameters) + .build(); + }); + }); + } + + private static BodyInserters.FormInserter body(OAuth2AuthorizationExchange authorizationExchange) { + OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse(); + String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri(); + BodyInserters.FormInserter body = BodyInserters + .fromFormData("grant_type", AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .with("code", authorizationResponse.getCode()); + if (redirectUri != null) { + body.with("redirect_uri", redirectUri); + } + return body; + } + + private static Mono accessTokenResponse(TokenResponse tokenResponse) { + if (tokenResponse.indicatesSuccess()) { + return Mono.just(tokenResponse) + .cast(AccessTokenResponse.class); + } + TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse; + ErrorObject errorObject = tokenErrorResponse.getErrorObject(); + OAuth2Error oauth2Error = new OAuth2Error(errorObject.getCode(), + errorObject.getDescription(), (errorObject.getURI() != null ? + errorObject.getURI().toString() : + null)); + + return Mono.error(new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString())); + } + + private static TokenResponse parse(Map json) { + try { + return TokenResponse.parse(new JSONObject(json)); + } + catch (ParseException pe) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred parsing the Access Token response: " + pe.getMessage(), null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), pe); + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/ReactiveOAuth2AccessTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/ReactiveOAuth2AccessTokenResponseClient.java new file mode 100644 index 0000000000..76acf19956 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/ReactiveOAuth2AccessTokenResponseClient.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; + +import reactor.core.publisher.Mono; + +/** + * A reactive strategy for "exchanging" an authorization grant credential + * (e.g. an Authorization Code) for an access token credential + * at the Authorization Server's Token Endpoint. + * + * @author Rob Winch + * @since 5.1 + * @see AbstractOAuth2AuthorizationGrantRequest + * @see OAuth2AccessTokenResponse + * @see AuthorizationGrantType + * @see Section 1.3 Authorization Grant + * @see Section 4.1.3 Access Token Request (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response (Authorization Code Grant) + */ +public interface ReactiveOAuth2AccessTokenResponseClient { + + /** + * Exchanges the authorization grant credential, provided in the authorization grant request, + * for an access token credential at the Authorization Server's Token Endpoint. + * + * @param authorizationGrantRequest the authorization grant request that contains the authorization grant credential + * @return an {@link OAuth2AccessTokenResponse} that contains the {@link OAuth2AccessTokenResponse#getAccessToken() access token} credential + * @throws OAuth2AuthenticationException if an error occurs while attempting to exchange for the access token credential + */ + Mono getTokenResponse(T authorizationGrantRequest) throws OAuth2AuthenticationException; + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClientTests.java new file mode 100644 index 0000000000..6f3c93ad24 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClientTests.java @@ -0,0 +1,262 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.time.Instant; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; + +/** + * @author Rob Winch + * @since 5.1 + */ +public class NimbusReactiveAuthorizationCodeTokenResponseClientTests { + private ClientRegistration.Builder clientRegistration; + + private NimbusReactiveAuthorizationCodeTokenResponseClient tokenResponseClient = new NimbusReactiveAuthorizationCodeTokenResponseClient(); + + private MockWebServer server; + + @Before + public void setup() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + + String tokenUri = this.server.url("/oauth2/token").toString(); + + this.clientRegistration = ClientRegistration.withRegistrationId("github") + .redirectUriTemplate("https://example.com/oauth2/code/github") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .scope("read:user") + .authorizationUri("https://github.com/login/oauth/authorize") + .tokenUri(tokenUri) + .userInfoUri("https://api.example.com/user") + .userNameAttributeName("user-name") + .clientName("GitHub") + .clientId("clientId") + .jwkSetUri("https://example.com/oauth2/jwk") + .clientSecret("clientSecret"); + } + + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"openid profile\",\n" + + " \"custom_parameter_1\": \"custom-value-1\",\n" + + " \"custom_parameter_2\": \"custom-value-2\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block(); + + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo( + OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile"); + assertThat(accessTokenResponse.getAdditionalParameters().size()).isEqualTo(2); + assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_1", "custom-value-1"); + assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_2", "custom-value-2"); + } + +// @Test +// public void getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() throws Exception { +// this.exception.expect(IllegalArgumentException.class); +// +// String redirectUri = "http:\\example.com"; +// when(this.clientRegistration.getRedirectUriTemplate()).thenReturn(redirectUri); +// +// this.tokenResponseClient.getTokenResponse( +// new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); +// } +// +// @Test +// public void getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() throws Exception { +// this.exception.expect(IllegalArgumentException.class); +// +// String tokenUri = "http:\\provider.com\\oauth2\\token"; +// when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); +// +// this.tokenResponseClient.getTokenResponse( +// new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); +// } +// +// @Test +// public void getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception { +// this.exception.expect(OAuth2AuthenticationException.class); +// this.exception.expectMessage(containsString("invalid_token_response")); +// +// MockWebServer server = new MockWebServer(); +// +// String accessTokenSuccessResponse = "{\n" + +// " \"access_token\": \"access-token-1234\",\n" + +// " \"token_type\": \"bearer\",\n" + +// " \"expires_in\": \"3600\",\n" + +// " \"scope\": \"openid profile\",\n" + +// " \"custom_parameter_1\": \"custom-value-1\",\n" + +// " \"custom_parameter_2\": \"custom-value-2\"\n"; +// // "}\n"; // Make the JSON invalid/malformed +// +// server.enqueue(new MockResponse() +// .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) +// .setBody(accessTokenSuccessResponse)); +// server.start(); +// +// String tokenUri = server.url("/oauth2/token").toString(); +// when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); +// +// try { +// this.tokenResponseClient.getTokenResponse( +// new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); +// } finally { +// server.shutdown(); +// } +// } +// +// @Test +// public void getTokenResponseWhenTokenUriInvalidThenThrowAuthenticationServiceException() throws Exception { +// this.exception.expect(AuthenticationServiceException.class); +// +// String tokenUri = "http://invalid-provider.com/oauth2/token"; +// when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); +// +// this.tokenResponseClient.getTokenResponse( +// new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); +// } +// + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthenticationException() throws Exception { + String accessTokenErrorResponse = "{\n" + + " \"error\": \"unauthorized_client\"\n" + + "}\n"; + + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value())); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("unauthorized_client"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthenticationException() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) + .isInstanceOf(OAuth2AuthenticationException.class) + .hasMessageContaining("invalid_token_response"); + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"openid profile\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + this.clientRegistration.scope("openid", "profile", "email", "address"); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block(); + + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + + this.clientRegistration.scope("openid", "profile", "email", "address"); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block(); + + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile", "email", "address"); + } + + private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest() { + ClientRegistration registration = this.clientRegistration.build(); + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest + .authorizationCode() + .clientId(registration.getClientId()) + .state("state") + .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) + .redirectUri(registration.getRedirectUriTemplate()) + .scopes(registration.getScopes()) + .build(); + OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse + .success("code") + .state("state") + .redirectUri(registration.getRedirectUriTemplate()) + .build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, + authorizationResponse); + return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } +}