diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index d4325704cb..81d94c3436 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -16,10 +16,18 @@ package org.springframework.security.config.web.server; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + import org.junit.Rule; import org.junit.Test; import org.openqa.selenium.WebDriver; +import reactor.core.publisher.Mono; + import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.authentication.ReactiveAuthenticationManager; @@ -27,15 +35,22 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity; import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeReactiveAuthenticationManager; +import org.springframework.security.oauth2.client.web.server.oidc.logout.OidcClientInitiatedServerLogoutSuccessHandler; import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.core.OAuth2AccessToken; @@ -60,21 +75,21 @@ import org.springframework.security.test.web.reactive.server.WebTestClientBuilde import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; +import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.reactive.config.EnableWebFlux; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; -import reactor.core.publisher.Mono; - -import java.time.Instant; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; +import org.springframework.web.server.WebHandler; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -85,6 +100,8 @@ public class OAuth2LoginTests { @Rule public final SpringTestRule spring = new SpringTestRule(); + private WebTestClient client; + @Autowired private WebFilterChainProxy springSecurity; @@ -100,6 +117,14 @@ public class OAuth2LoginTests { .clientSecret("secret") .build(); + @Autowired + public void setApplicationContext(ApplicationContext context) { + if (context.getBeanNamesForType(WebHandler.class).length > 0) { + this.client = WebTestClient.bindToApplicationContext(context) + .build(); + } + } + @Test public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() { this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire(); @@ -326,6 +351,60 @@ public class OAuth2LoginTests { } } + + @Test + public void logoutWhenUsingOidcLogoutHandlerThenRedirects() throws Exception { + this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire(); + + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( + TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, + getBean(ClientRegistration.class).getRegistrationId()); + + ServerSecurityContextRepository repository = getBean(ServerSecurityContextRepository.class); + when(repository.load(any())).thenReturn(authentication(token)); + + this.client.post().uri("/logout") + .exchange() + .expectHeader().valueEquals("Location", "http://logout?id_token_hint=id-token"); + } + + @EnableWebFlux + @EnableWebFluxSecurity + static class OAuth2LoginConfigWithOidcLogoutSuccessHandler { + private final ServerSecurityContextRepository repository = + mock(ServerSecurityContextRepository.class); + private final ClientRegistration withLogout = + TestClientRegistrations.clientRegistration() + .providerConfigurationMetadata(Collections.singletonMap( + "end_session_endpoint", "http://logout")).build(); + + @Bean + public SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { + + http + .csrf().disable() + .logout() + .logoutSuccessHandler( + new OidcClientInitiatedServerLogoutSuccessHandler( + new InMemoryReactiveClientRegistrationRepository(this.withLogout))) + .and() + .securityContextRepository(this.repository); + + return http.build(); + } + + @Bean + ServerSecurityContextRepository securityContextRepository() { + return this.repository; + } + + @Bean + ClientRegistration clientRegistration() { + return this.withLogout; + } + } + static class GitHubWebFilter implements WebFilter { @Override @@ -336,4 +415,14 @@ public class OAuth2LoginTests { return chain.filter(exchange); } } + + Mono authentication(Authentication authentication) { + SecurityContext context = new SecurityContextImpl(); + context.setAuthentication(authentication); + return Mono.just(context); + } + + T getBean(Class beanClass) { + return this.spring.getContext().getBean(beanClass); + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/oidc/logout/OidcClientInitiatedServerLogoutSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/oidc/logout/OidcClientInitiatedServerLogoutSuccessHandler.java new file mode 100644 index 0000000000..6f77c5f49f --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/oidc/logout/OidcClientInitiatedServerLogoutSuccessHandler.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.server.oidc.logout; + +import java.net.URI; +import java.nio.charset.StandardCharsets; + +import reactor.core.publisher.Mono; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.web.server.DefaultServerRedirectStrategy; +import org.springframework.security.web.server.ServerRedirectStrategy; +import org.springframework.security.web.server.WebFilterExchange; +import org.springframework.security.web.server.authentication.logout.RedirectServerLogoutSuccessHandler; +import org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * A reactive logout success handler for initiating OIDC logout through the user agent. + * + * @author Josh Cummings + * @since 5.2 + * @see RP-Initiated Logout + * @see org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler + */ +public class OidcClientInitiatedServerLogoutSuccessHandler + implements ServerLogoutSuccessHandler { + + private final ServerRedirectStrategy redirectStrategy = new DefaultServerRedirectStrategy(); + private final RedirectServerLogoutSuccessHandler serverLogoutSuccessHandler + = new RedirectServerLogoutSuccessHandler(); + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + + private URI postLogoutRedirectUri; + + /** + * Constructs an {@link OidcClientInitiatedServerLogoutSuccessHandler} with the provided parameters + * + * @param clientRegistrationRepository The {@link ReactiveClientRegistrationRepository} to use to derive + * the end_session_endpoint value + */ + public OidcClientInitiatedServerLogoutSuccessHandler + (ReactiveClientRegistrationRepository clientRegistrationRepository) { + + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + } + + /** + * {@inheritDoc} + */ + @Override + public Mono onLogoutSuccess(WebFilterExchange exchange, Authentication authentication) { + return Mono.just(authentication) + .filter(OAuth2AuthenticationToken.class::isInstance) + .filter(token -> authentication.getPrincipal() instanceof OidcUser) + .map(OAuth2AuthenticationToken.class::cast) + .flatMap(this::endSessionEndpoint) + .map(endSessionEndpoint -> endpointUri(endSessionEndpoint, authentication)) + .switchIfEmpty(this.serverLogoutSuccessHandler + .onLogoutSuccess(exchange, authentication).then(Mono.empty())) + .flatMap(endpointUri -> this.redirectStrategy.sendRedirect(exchange.getExchange(), endpointUri)); + } + + private Mono endSessionEndpoint(OAuth2AuthenticationToken token) { + String registrationId = token.getAuthorizedClientRegistrationId(); + return this.clientRegistrationRepository.findByRegistrationId(registrationId) + .map(ClientRegistration::getProviderDetails) + .map(ClientRegistration.ProviderDetails::getConfigurationMetadata) + .flatMap(configurationMetadata -> Mono.justOrEmpty(configurationMetadata.get("end_session_endpoint"))) + .map(Object::toString) + .map(URI::create); + } + + private URI endpointUri(URI endSessionEndpoint, Authentication authentication) { + UriComponentsBuilder builder = UriComponentsBuilder.fromUri(endSessionEndpoint); + builder.queryParam("id_token_hint", idToken(authentication)); + if (this.postLogoutRedirectUri != null) { + builder.queryParam("post_logout_redirect_uri", this.postLogoutRedirectUri); + } + return builder.encode(StandardCharsets.UTF_8).build().toUri(); + } + + private String idToken(Authentication authentication) { + return ((OidcUser) authentication.getPrincipal()).getIdToken().getTokenValue(); + } + + /** + * Set the post logout redirect uri to use + * + * @param postLogoutRedirectUri - A valid URL to which the OP should redirect after logging out the user + */ + public void setPostLogoutRedirectUri(URI postLogoutRedirectUri) { + Assert.notNull(postLogoutRedirectUri, "postLogoutRedirectUri cannot be empty"); + this.postLogoutRedirectUri = postLogoutRedirectUri; + } + + /** + * The URL to redirect to after successfully logging out when not originally an OIDC login + * + * @param logoutSuccessUrl the url to redirect to. Default is "/login?logout". + */ + public void setLogoutSuccessUrl(URI logoutSuccessUrl) { + Assert.notNull(logoutSuccessUrl, "logoutSuccessUrl cannot be null"); + this.serverLogoutSuccessHandler.setLogoutSuccessUrl(logoutSuccessUrl); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/oidc/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/oidc/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java new file mode 100644 index 0000000000..c891a15893 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/oidc/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.server.oidc.logout; + +import java.net.URI; +import java.util.Collections; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpResponse; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; +import org.springframework.security.oauth2.core.user.TestOAuth2Users; +import org.springframework.security.web.server.WebFilterExchange; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilterChain; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class OidcClientInitiatedServerLogoutSuccessHandlerTests { + ClientRegistration registration = TestClientRegistrations + .clientRegistration() + .providerConfigurationMetadata( + Collections.singletonMap("end_session_endpoint", "http://endpoint")) + .build(); + ReactiveClientRegistrationRepository repository = new InMemoryReactiveClientRegistrationRepository(registration); + + ServerWebExchange exchange; + WebFilterChain chain; + + OidcClientInitiatedServerLogoutSuccessHandler handler; + + @Before + public void setup() { + this.exchange = mock(ServerWebExchange.class); + when(this.exchange.getResponse()).thenReturn(new MockServerHttpResponse()); + when(this.exchange.getRequest()).thenReturn(MockServerHttpRequest.get("/").build()); + this.chain = mock(WebFilterChain.class); + this.handler = new OidcClientInitiatedServerLogoutSuccessHandler(this.repository); + } + + @Test + public void logoutWhenOidcRedirectUrlConfiguredThenRedirects() { + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( + TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, + this.registration.getRegistrationId()); + + when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); + WebFilterExchange f = new WebFilterExchange(exchange, this.chain); + this.handler.onLogoutSuccess(f, token).block(); + + assertThat(redirectedUrl(this.exchange)).isEqualTo("http://endpoint?id_token_hint=id-token"); + } + + @Test + public void logoutWhenNotOAuth2AuthenticationThenDefaults() { + Authentication token = mock(Authentication.class); + + when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); + WebFilterExchange f = new WebFilterExchange(exchange, this.chain); + + this.handler.setLogoutSuccessUrl(URI.create("http://default")); + this.handler.onLogoutSuccess(f, token).block(); + + assertThat(redirectedUrl(this.exchange)).isEqualTo("http://default"); + } + + @Test + public void logoutWhenNotOidcUserThenDefaults() { + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( + TestOAuth2Users.create(), + AuthorityUtils.NO_AUTHORITIES, + this.registration.getRegistrationId()); + + when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); + WebFilterExchange f = new WebFilterExchange(exchange, this.chain); + + this.handler.setLogoutSuccessUrl(URI.create("http://default")); + this.handler.onLogoutSuccess(f, token).block(); + + assertThat(redirectedUrl(this.exchange)).isEqualTo("http://default"); + } + + @Test + public void logoutWhenClientRegistrationHasNoEndSessionEndpointThenDefaults() { + + ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); + ReactiveClientRegistrationRepository repository = + new InMemoryReactiveClientRegistrationRepository(registration); + OidcClientInitiatedServerLogoutSuccessHandler handler = + new OidcClientInitiatedServerLogoutSuccessHandler(repository); + + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( + TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, + registration.getRegistrationId()); + + when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); + WebFilterExchange f = new WebFilterExchange(exchange, this.chain); + + handler.setLogoutSuccessUrl(URI.create("http://default")); + handler.onLogoutSuccess(f, token).block(); + + assertThat(redirectedUrl(this.exchange)).isEqualTo("http://default"); + } + + @Test + public void logoutWhenUsingPostLogoutRedirectUriThenIncludesItInRedirect() { + + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken( + TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, + this.registration.getRegistrationId()); + + when(this.exchange.getPrincipal()).thenReturn(Mono.just(token)); + WebFilterExchange f = new WebFilterExchange(exchange, this.chain); + + this.handler.setPostLogoutRedirectUri(URI.create("http://postlogout?encodedparam=value")); + this.handler.onLogoutSuccess(f, token).block(); + + assertThat(redirectedUrl(this.exchange)) + .isEqualTo("http://endpoint?" + + "id_token_hint=id-token&" + + "post_logout_redirect_uri=http://postlogout?encodedparam%3Dvalue"); + } + + @Test + public void setPostLogoutRedirectUriWhenGivenNullThenThrowsException() { + assertThatThrownBy(() -> this.handler.setPostLogoutRedirectUri(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + private String redirectedUrl(ServerWebExchange exchange) { + return exchange.getResponse().getHeaders().getFirst("Location"); + } +}