From 23f4b9d3d1f9de9e38b22af0fad45d81bd8c2c50 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 1 May 2018 16:44:11 -0500 Subject: [PATCH] Add OAuth2AuthorizationRequestRedirectWebFilter Issue: gh-4807 --- ...AuthorizationRequestRedirectWebFilter.java | 213 ++++++++++++++++++ ...rizationRequestRedirectWebFilterTests.java | 136 +++++++++++ 2 files changed, 349 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilterTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java new file mode 100644 index 0000000000..f49eb4d6df --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java @@ -0,0 +1,213 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import java.net.URI; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; + +import org.springframework.http.HttpStatus; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpRequestDecorator; +import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; +import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.web.server.DefaultServerRedirectStrategy; +import org.springframework.security.web.server.ServerRedirectStrategy; +import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.util.UriComponentsBuilder; + +import reactor.core.publisher.Mono; + +/** + * This {@code WebFilter} initiates the authorization code grant or implicit grant flow + * by redirecting the End-User's user-agent to the Authorization Server's Authorization Endpoint. + * + *

+ * It builds the OAuth 2.0 Authorization Request, + * which is used as the redirect {@code URI} to the Authorization Endpoint. + * The redirect {@code URI} will include the client identifier, requested scope(s), state, + * response type, and a redirection URI which the authorization server will send the user-agent back to + * once access is granted (or denied) by the End-User (Resource Owner). + * + *

+ * By default, this {@code Filter} responds to authorization requests + * at the {@code URI} {@code /oauth2/authorization/{registrationId}}. + * The {@code URI} template variable {@code {registrationId}} represents the + * {@link ClientRegistration#getRegistrationId() registration identifier} of the client + * that is used for initiating the OAuth 2.0 Authorization Request. + * + *

+ * NOTE: The default base {@code URI} {@code /oauth2/authorization} may be overridden + * via it's constructor {@link #OAuth2AuthorizationRequestRedirectWebFilter(ReactiveClientRegistrationRepository, String)}. + + * @author Rob Winch + * @since 5.1 + * @see OAuth2AuthorizationRequest + * @see AuthorizationRequestRepository + * @see ClientRegistration + * @see ClientRegistrationRepository + * @see Section 4.1 Authorization Code Grant + * @see Section 4.1.1 Authorization Request (Authorization Code) + * @see Section 4.2 Implicit Grant + * @see Section 4.2.1 Authorization Request (Implicit) + */ +public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter { + /** + * The default base {@code URI} used for authorization requests. + */ + public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization"; + private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId"; + private static final String AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME = + ClientAuthorizationRequiredException.class.getName() + ".AUTHORIZATION_REQUIRED_EXCEPTION"; + private final ServerWebExchangeMatcher authorizationRequestMatcher; + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizationRequestUriBuilder authorizationRequestUriBuilder = new OAuth2AuthorizationRequestUriBuilder(); + private final ServerRedirectStrategy authorizationRedirectStrategy = new DefaultServerRedirectStrategy(); + private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); + private ReactiveAuthorizationRequestRepository authorizationRequestRepository = + new WebSessionOAuth2ReactiveAuthorizationRequestRepository(); + + /** + * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + */ + public OAuth2AuthorizationRequestRedirectWebFilter(ReactiveClientRegistrationRepository clientRegistrationRepository) { + this(clientRegistrationRepository, DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); + } + + /** + * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizationRequestBaseUri the base {@code URI} used for authorization requests + */ + public OAuth2AuthorizationRequestRedirectWebFilter( + ReactiveClientRegistrationRepository clientRegistrationRepository, String authorizationRequestBaseUri) { + + Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty"); + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + this.authorizationRequestMatcher = new PathPatternParserServerWebExchangeMatcher( + authorizationRequestBaseUri + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}"); + this.clientRegistrationRepository = clientRegistrationRepository; + } + + /** + * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. + * + * @param authorizationRequestRepository the repository used for storing {@link OAuth2AuthorizationRequest}'s + */ + public final void setAuthorizationRequestRepository(ReactiveAuthorizationRequestRepository authorizationRequestRepository) { + Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); + this.authorizationRequestRepository = authorizationRequestRepository; + } + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + return this.authorizationRequestMatcher.matches(exchange) + .filter(matchResult -> matchResult.isMatch()) + .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) + .map(ServerWebExchangeMatcher.MatchResult::getVariables) + .map(variables -> variables.get(REGISTRATION_ID_URI_VARIABLE_NAME)) + .cast(String.class) + .flatMap(clientRegistrationId -> this.findByRegistrationId(exchange, clientRegistrationId)) + .flatMap(clientRegistration -> sendRedirectForAuthorization(exchange, clientRegistration)); + } + + private Mono findByRegistrationId(ServerWebExchange exchange, String clientRegistration) { + return this.clientRegistrationRepository.findByRegistrationId(clientRegistration) + .switchIfEmpty(Mono.defer(() -> { + exchange.getResponse().setStatusCode(HttpStatus.BAD_REQUEST); + return exchange.getResponse().setComplete().then(Mono.empty()); + })); + } + + private Mono sendRedirectForAuthorization(ServerWebExchange exchange, + ClientRegistration clientRegistration) { + return Mono.defer(() -> { + String redirectUriStr = this + .expandRedirectUri(exchange.getRequest(), clientRegistration); + + Map additionalParameters = new HashMap<>(); + additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, + clientRegistration.getRegistrationId()); + + OAuth2AuthorizationRequest.Builder builder; + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { + builder = OAuth2AuthorizationRequest.authorizationCode(); + } + else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) { + builder = OAuth2AuthorizationRequest.implicit(); + } + else { + throw new IllegalArgumentException( + "Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue() + + ") for Client Registration with Id: " + clientRegistration.getRegistrationId()); + } + OAuth2AuthorizationRequest authorizationRequest = builder + .clientId(clientRegistration.getClientId()) + .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) + .redirectUri(redirectUriStr).scopes(clientRegistration.getScopes()) + .state(this.stateGenerator.generateKey()) + .additionalParameters(additionalParameters).build(); + + Mono saveAuthorizationRequest = Mono.empty(); + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationRequest.getGrantType())) { + saveAuthorizationRequest = this.authorizationRequestRepository + .saveAuthorizationRequest(authorizationRequest, exchange); + } + + URI redirectUri = this.authorizationRequestUriBuilder.build(authorizationRequest); + return saveAuthorizationRequest + .then(this.authorizationRedirectStrategy.sendRedirect(exchange, redirectUri)); + }); + } + + private String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) { + // Supported URI variables -> baseUrl, action, registrationId + // Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}" + Map uriVariables = new HashMap<>(); + uriVariables.put("registrationId", clientRegistration.getRegistrationId()); + + String baseUrl = UriComponentsBuilder.fromHttpRequest(new ServerHttpRequestDecorator(request)) + .replacePath(request.getPath().contextPath().value()) + .build() + .toUriString(); + uriVariables.put("baseUrl", baseUrl); + + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { + String loginAction = "login"; + uriVariables.put("action", loginAction); + } + + return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate()) + .buildAndExpand(uriVariables) + .toUriString(); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilterTests.java new file mode 100644 index 0000000000..82fc51de2a --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilterTests.java @@ -0,0 +1,136 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.test.web.reactive.server.FluxExchangeResult; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.server.handler.FilteringWebHandler; +import reactor.core.publisher.Mono; + +import java.net.URI; +import java.util.Arrays; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + * @since 5.1 + */ +@RunWith(MockitoJUnitRunner.class) +public class OAuth2AuthorizationRequestRedirectWebFilterTests { + @Mock + private ReactiveClientRegistrationRepository clientRepository; + + @Mock + private ReactiveAuthorizationRequestRepository authzRequestRepository; + + private ClientRegistration github = ClientRegistration.withRegistrationId("github") + .redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .scope("read:user") + .authorizationUri("https://github.com/login/oauth/authorize") + .tokenUri("https://github.com/login/oauth/access_token") + .userInfoUri("https://api.github.com/user") + .userNameAttributeName("id") + .clientName("GitHub") + .clientId("clientId") + .clientSecret("clientSecret") + .build(); + + private OAuth2AuthorizationRequestRedirectWebFilter filter; + + private WebTestClient client; + + @Before + public void setup() { + this.filter = new OAuth2AuthorizationRequestRedirectWebFilter(this.clientRepository); + this.filter.setAuthorizationRequestRepository(this.authzRequestRepository); + FilteringWebHandler webHandler = new FilteringWebHandler(e -> e.getResponse().setComplete(), Arrays.asList(this.filter)); + + this.client = WebTestClient.bindToWebHandler(webHandler).build(); + when(this.clientRepository.findByRegistrationId(this.github.getRegistrationId())).thenReturn( + Mono.just(this.github)); + when(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).thenReturn( + Mono.empty()); + } + + @Test + public void constructorWhenClientRegistrationRepositoryNullThenIllegalArgumentException() { + this.clientRepository = null; + assertThatThrownBy(() -> new OAuth2AuthorizationRequestRedirectWebFilter(this.clientRepository)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorWhenAuthorizationRequestBaseUriNullThenIllegalArgumentException() { + String authorizationRequestBaseUri = null; + assertThatThrownBy(() -> new OAuth2AuthorizationRequestRedirectWebFilter(this.clientRepository, authorizationRequestBaseUri)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void constructorWhenAuthorizationRequestBaseUriEmptyThenIllegalArgumentException() { + String authorizationRequestBaseUri = ""; + assertThatThrownBy(() -> new OAuth2AuthorizationRequestRedirectWebFilter(this.clientRepository, authorizationRequestBaseUri)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void filterWhenDoesNotMatchThenClientRegistrationRepositoryNotSubscribed() { + this.client.get() + .exchange() + .expectStatus().isOk(); + + verifyZeroInteractions(this.clientRepository, this.authzRequestRepository); + } + + @Test + public void filterWhenDoesMatchThenClientRegistrationRepositoryNotSubscribed() { + FluxExchangeResult result = this.client.get() + .uri("https://example.com/oauth2/authorization/github").exchange() + .expectStatus().is3xxRedirection().returnResult(String.class); + result.assertWithDiagnostics(() -> { + URI location = result.getResponseHeaders().getLocation(); + assertThat(location) + .hasScheme("https") + .hasHost("github.com") + .hasPath("/login/oauth/authorize") + .hasParameter("response_type", "code") + .hasParameter("client_id", "clientId") + .hasParameter("scope", "read:user") + .hasParameter("state") + .hasParameter("redirect_uri", "https://example.com/login/oauth2/code/github"); + }); + verify(this.authzRequestRepository).saveAuthorizationRequest(any(), any()); + } +}