From 8acdb82e6a1f6e67fdde6b0276054140e524d641 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 10 Feb 2020 06:57:31 -0500 Subject: [PATCH] OAuth2AuthorizationCodeGrantWebFilter matches on query parameters Fixes gh-7966 --- ...OAuth2AuthorizationCodeGrantWebFilter.java | 59 ++++-- ...ationCodeAuthenticationTokenConverter.java | 9 +- ...2AuthorizationCodeGrantWebFilterTests.java | 197 ++++++++++++++---- 3 files changed, 201 insertions(+), 64 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java index de6e31d458..7ef55667cb 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,13 +37,20 @@ import org.springframework.security.web.server.authentication.ServerAuthenticati import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.util.Assert; -import org.springframework.util.MultiValueMap; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; +import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import reactor.core.publisher.Mono; +import java.net.URI; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + /** * A {@code Filter} for the OAuth 2.0 Authorization Code Grant, * which handles the processing of the OAuth 2.0 Authorization Response. @@ -165,10 +172,10 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { return this.requiresAuthenticationMatcher.matches(exchange) - .filter( matchResult -> matchResult.isMatch()) - .flatMap( matchResult -> this.authenticationConverter.convert(exchange)) + .filter(ServerWebExchangeMatcher.MatchResult::isMatch) + .flatMap(matchResult -> this.authenticationConverter.convert(exchange)) .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) - .flatMap( token -> authenticate(exchange, chain, token)); + .flatMap(token -> authenticate(exchange, chain, token)); } private Mono authenticate(ServerWebExchange exchange, @@ -198,20 +205,34 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { } private Mono matchesAuthorizationResponse(ServerWebExchange exchange) { - return this.authorizationRequestRepository.loadAuthorizationRequest(exchange) - .flatMap(authorizationRequest -> { - String requestUrl = UriComponentsBuilder.fromUri(exchange.getRequest().getURI()) - .query(null) - .build() - .toUriString(); - MultiValueMap queryParams = exchange.getRequest().getQueryParams(); - if (requestUrl.equals(authorizationRequest.getRedirectUri()) && - OAuth2AuthorizationResponseUtils.isAuthorizationResponse(queryParams)) { - return ServerWebExchangeMatcher.MatchResult.match(); - } - return ServerWebExchangeMatcher.MatchResult.notMatch(); - }) - .filter(ServerWebExchangeMatcher.MatchResult::isMatch) + return Mono.just(exchange) + .filter(exch -> OAuth2AuthorizationResponseUtils.isAuthorizationResponse(exch.getRequest().getQueryParams())) + .flatMap(exch -> this.authorizationRequestRepository.loadAuthorizationRequest(exchange) + .flatMap(authorizationRequest -> + matchesRedirectUri(exch.getRequest().getURI(), authorizationRequest.getRedirectUri()))) .switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch()); } + + private static Mono matchesRedirectUri( + URI authorizationResponseUri, String authorizationRequestRedirectUri) { + UriComponents requestUri = UriComponentsBuilder.fromUri(authorizationResponseUri).build(); + UriComponents redirectUri = UriComponentsBuilder.fromUriString(authorizationRequestRedirectUri).build(); + Set>> requestUriParameters = + new LinkedHashSet<>(requestUri.getQueryParams().entrySet()); + Set>> redirectUriParameters = + new LinkedHashSet<>(redirectUri.getQueryParams().entrySet()); + // Remove the additional request parameters (if any) from the authorization response (request) + // before doing an exact comparison with the authorizationRequest.getRedirectUri() parameters (if any) + requestUriParameters.retainAll(redirectUriParameters); + + if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) && + Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) && + Objects.equals(requestUri.getHost(), redirectUri.getHost()) && + Objects.equals(requestUri.getPort(), redirectUri.getPort()) && + Objects.equals(requestUri.getPath(), redirectUri.getPath()) && + Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) { + return ServerWebExchangeMatcher.MatchResult.match(); + } + return ServerWebExchangeMatcher.MatchResult.notMatch(); + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java index 38a453f009..2e9bd68c17 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.util.Assert; -import org.springframework.util.MultiValueMap; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.util.UriComponentsBuilder; import reactor.core.publisher.Mono; @@ -103,14 +102,10 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter } private static OAuth2AuthorizationResponse convertResponse(ServerWebExchange exchange) { - MultiValueMap queryParams = exchange.getRequest() - .getQueryParams(); String redirectUri = UriComponentsBuilder.fromUri(exchange.getRequest().getURI()) - .query(null) .build() .toUriString(); - return OAuth2AuthorizationResponseUtils - .convert(queryParams, redirectUri); + .convert(exchange.getRequest().getQueryParams(), redirectUri); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java index f86256cbe0..db8f1b80c1 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,25 +25,28 @@ import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.ReactiveAuthenticationManager; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthorizationCodeAuthenticationTokens; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; -import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; -import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; +import org.springframework.util.CollectionUtils; import org.springframework.web.server.handler.DefaultWebFilterChain; import reactor.core.publisher.Mono; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; /** * @author Rob Winch @@ -102,7 +105,7 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { MockServerWebExchange exchange = MockServerWebExchange .from(MockServerHttpRequest.get("/")); DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete()); + e -> e.getResponse().setComplete(), Collections.emptyList()); this.filter.filter(exchange, chain).block(); @@ -111,43 +114,161 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests { @Test public void filterWhenMatchThenAuthorizedClientSaved() { - OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() - .redirectUri("/authorize/registration-id") - .build(); - OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() - .redirectUri("/authorize/registration-id") - .build(); - OAuth2AuthorizationExchange authorizationExchange = - new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse); - ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); - Mono authentication = Mono.just( - new OAuth2AuthorizationCodeAuthenticationToken(registration, authorizationExchange)); - OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens - .authenticated(); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + when(this.clientRegistrationRepository.findByRegistrationId(any())) + .thenReturn(Mono.just(clientRegistration)); - when(this.authenticationManager.authenticate(any())).thenReturn( - Mono.just(authenticated)); + MockServerHttpRequest authorizationRequest = + createAuthorizationRequest("/authorization/callback"); + OAuth2AuthorizationRequest oauth2AuthorizationRequest = + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) - .thenReturn(Mono.just(authorizationRequest)); + .thenReturn(Mono.just(oauth2AuthorizationRequest)); + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .thenReturn(Mono.just(oauth2AuthorizationRequest)); + when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) .thenReturn(Mono.empty()); - ServerAuthenticationConverter converter = e -> authentication; + when(this.authenticationManager.authenticate(any())) + .thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated())); - this.filter = new OAuth2AuthorizationCodeGrantWebFilter( - this.authenticationManager, converter, this.authorizedClientRepository); - this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository); - - MockServerHttpRequest request = MockServerHttpRequest - .get("/authorize/registration-id") - .queryParam(OAuth2ParameterNames.CODE, "code") - .queryParam(OAuth2ParameterNames.STATE, "state") - .build(); - MockServerWebExchange exchange = MockServerWebExchange.from(request); + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); DefaultWebFilterChain chain = new DefaultWebFilterChain( - e -> e.getResponse().setComplete()); + e -> e.getResponse().setComplete(), Collections.emptyList()); this.filter.filter(exchange, chain).block(); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), any()); } + + // gh-7966 + @Test + public void filterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + when(this.clientRegistrationRepository.findByRegistrationId(any())) + .thenReturn(Mono.just(clientRegistration)); + when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())) + .thenReturn(Mono.empty()); + when(this.authenticationManager.authenticate(any())) + .thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated())); + + // 1) redirect_uri with query parameters + Map parameters = new LinkedHashMap<>(); + parameters.put("param1", "value1"); + parameters.put("param2", "value2"); + MockServerHttpRequest authorizationRequest = + createAuthorizationRequest("/authorization/callback", parameters); + OAuth2AuthorizationRequest oauth2AuthorizationRequest = + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .thenReturn(Mono.just(oauth2AuthorizationRequest)); + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) + .thenReturn(Mono.just(oauth2AuthorizationRequest)); + + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); + DefaultWebFilterChain chain = new DefaultWebFilterChain( + e -> e.getResponse().setComplete(), Collections.emptyList()); + + this.filter.filter(exchange, chain).block(); + verify(this.authenticationManager, times(1)).authenticate(any()); + + // 2) redirect_uri with query parameters AND authorization response additional parameters + Map additionalParameters = new LinkedHashMap<>(); + additionalParameters.put("auth-param1", "value1"); + additionalParameters.put("auth-param2", "value2"); + authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters); + exchange = MockServerWebExchange.from(authorizationResponse); + + this.filter.filter(exchange, chain).block(); + verify(this.authenticationManager, times(2)).authenticate(any()); + } + + // gh-7966 + @Test + public void filterWhenAuthorizationRequestRedirectUriParametersNotMatchThenNotProcessed() { + String requestUri = "/authorization/callback"; + Map parameters = new LinkedHashMap<>(); + parameters.put("param1", "value1"); + parameters.put("param2", "value2"); + MockServerHttpRequest authorizationRequest = + createAuthorizationRequest(requestUri, parameters); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizationRequest oauth2AuthorizationRequest = + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) + .thenReturn(Mono.just(oauth2AuthorizationRequest)); + + // 1) Parameter value + Map parametersNotMatch = new LinkedHashMap<>(parameters); + parametersNotMatch.put("param2", "value8"); + MockServerHttpRequest authorizationResponse = createAuthorizationResponse( + createAuthorizationRequest(requestUri, parametersNotMatch)); + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); + DefaultWebFilterChain chain = new DefaultWebFilterChain( + e -> e.getResponse().setComplete(), Collections.emptyList()); + + this.filter.filter(exchange, chain).block(); + verifyZeroInteractions(this.authenticationManager); + + // 2) Parameter order + parametersNotMatch = new LinkedHashMap<>(); + parametersNotMatch.put("param2", "value2"); + parametersNotMatch.put("param1", "value1"); + authorizationResponse = createAuthorizationResponse( + createAuthorizationRequest(requestUri, parametersNotMatch)); + exchange = MockServerWebExchange.from(authorizationResponse); + + this.filter.filter(exchange, chain).block(); + verifyZeroInteractions(this.authenticationManager); + + // 3) Parameter missing + parametersNotMatch = new LinkedHashMap<>(parameters); + parametersNotMatch.remove("param2"); + authorizationResponse = createAuthorizationResponse( + createAuthorizationRequest(requestUri, parametersNotMatch)); + exchange = MockServerWebExchange.from(authorizationResponse); + + this.filter.filter(exchange, chain).block(); + verifyZeroInteractions(this.authenticationManager); + } + + private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest( + MockServerHttpRequest authorizationRequest, ClientRegistration registration) { + Map attributes = new HashMap<>(); + attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId()); + return request() + .attributes(attributes) + .redirectUri(authorizationRequest.getURI().toString()) + .build(); + } + + private static MockServerHttpRequest createAuthorizationRequest(String requestUri) { + return createAuthorizationRequest(requestUri, new LinkedHashMap<>()); + } + + private static MockServerHttpRequest createAuthorizationRequest(String requestUri, Map parameters) { + MockServerHttpRequest.BaseBuilder builder = MockServerHttpRequest + .get(requestUri); + if (!CollectionUtils.isEmpty(parameters)) { + parameters.forEach(builder::queryParam); + } + return builder.build(); + } + + private static MockServerHttpRequest createAuthorizationResponse(MockServerHttpRequest authorizationRequest) { + return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>()); + } + + private static MockServerHttpRequest createAuthorizationResponse( + MockServerHttpRequest authorizationRequest, Map additionalParameters) { + MockServerHttpRequest.BaseBuilder builder = MockServerHttpRequest + .get(authorizationRequest.getURI().toString()); + builder.queryParam(OAuth2ParameterNames.CODE, "code"); + builder.queryParam(OAuth2ParameterNames.STATE, "state"); + additionalParameters.forEach(builder::queryParam); + builder.cookies(authorizationRequest.getCookies()); + return builder.build(); + } }