diff --git a/config/spring-security-config.gradle b/config/spring-security-config.gradle index 23fa96f840..fc56abf254 100644 --- a/config/spring-security-config.gradle +++ b/config/spring-security-config.gradle @@ -35,6 +35,7 @@ dependencies { testCompile powerMock2Dependencies testCompile spockDependencies testCompile 'ch.qos.logback:logback-classic' + testCompile 'io.projectreactor.ipc:reactor-netty' testCompile 'javax.annotation:jsr250-api:1.0' testCompile 'javax.xml.bind:jaxb-api' testCompile 'ldapsdk:ldapsdk:4.1' diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfiguration.java index bed91d4f97..30b2b44b7f 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfiguration.java @@ -87,13 +87,14 @@ class WebFluxSecurityConfiguration { private SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) { http .authorizeExchange() - .anyExchange().authenticated() - .and() - .httpBasic().and() - .formLogin(); + .anyExchange().authenticated(); - if (isOAuth2Present) { + if (isOAuth2Present && OAuth2ClasspathGuard.shouldConfigure(this.context)) { OAuth2ClasspathGuard.configure(this.context, http); + } else { + http + .httpBasic().and() + .formLogin(); } SecurityWebFilterChain result = http.build(); @@ -102,11 +103,13 @@ class WebFluxSecurityConfiguration { private static class OAuth2ClasspathGuard { static void configure(ApplicationContext context, ServerHttpSecurity http) { + http.oauth2Login(); + } + + static boolean shouldConfigure(ApplicationContext context) { ClassLoader loader = context.getClassLoader(); Class reactiveClientRegistrationRepositoryClass = ClassUtils.resolveClassName(REACTIVE_CLIENT_REGISTRATION_REPOSITORY_CLASSNAME, loader); - if (context.getBeanNamesForType(reactiveClientRegistrationRepositoryClass).length == 1) { - http.oauth2Login(); - } + return context.getBeanNamesForType(reactiveClientRegistrationRepositoryClass).length == 1; } } } diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 6c8c40622c..d54aed8296 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -41,6 +41,7 @@ import org.springframework.security.authorization.AuthorityReactiveAuthorization import org.springframework.security.authorization.AuthorizationDecision; import org.springframework.security.authorization.ReactiveAuthorizationManager; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.authentication.OAuth2LoginReactiveAuthenticationManager; import org.springframework.security.oauth2.client.endpoint.NimbusReactiveAuthorizationCodeTokenResponseClient; @@ -361,11 +362,7 @@ public class ServerHttpSecurity { return this; } - protected void configure(LoginPageGeneratingWebFilter loginPageFilter, ServerHttpSecurity http) { - if (loginPageFilter != null) { - loginPageFilter.setOauth2AuthenticationUrlToClientName(getLinks()); - } - + protected void configure(ServerHttpSecurity http) { ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(); ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService(); OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(clientRegistrationRepository); @@ -417,6 +414,9 @@ public class ServerHttpSecurity { if (this.authorizedClientService == null) { this.authorizedClientService = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class); } + if (this.authorizedClientService == null) { + this.authorizedClientService = new InMemoryReactiveOAuth2AuthorizedClientService(getClientRegistrationRepository()); + } return this.authorizedClientService; } @@ -616,15 +616,24 @@ public class ServerHttpSecurity { if(this.securityContextRepository != null) { this.formLogin.securityContextRepository(this.securityContextRepository); } - if(this.formLogin.authenticationEntryPoint == null) { + if (this.authenticationEntryPoint == null) { loginPageFilter = new LoginPageGeneratingWebFilter(); - this.webFilters.add(new OrderedWebFilter(loginPageFilter, SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING.getOrder())); - this.webFilters.add(new OrderedWebFilter(new LogoutPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGOUT_PAGE_GENERATING.getOrder())); + loginPageFilter.setFormLoginEnabled(true); + this.authenticationEntryPoint = this.formLogin.authenticationEntryPoint; } this.formLogin.configure(this); } if (this.oauth2Login != null) { - this.oauth2Login.configure(loginPageFilter, this); + if (this.authenticationEntryPoint == null) { + loginPageFilter = new LoginPageGeneratingWebFilter(); + loginPageFilter.setOauth2AuthenticationUrlToClientName(this.oauth2Login.getLinks()); + } + this.oauth2Login.configure(this); + } + if (loginPageFilter != null) { + this.authenticationEntryPoint = new RedirectServerAuthenticationEntryPoint("/login"); + this.webFilters.add(new OrderedWebFilter(loginPageFilter, SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING.getOrder())); + this.webFilters.add(new OrderedWebFilter(new LogoutPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGOUT_PAGE_GENERATING.getOrder())); } if(this.logout != null) { this.logout.configure(this); @@ -638,8 +647,8 @@ public class ServerHttpSecurity { exceptionTranslationWebFilter.setAuthenticationEntryPoint( authenticationEntryPoint); } - if(accessDeniedHandler != null) { - exceptionTranslationWebFilter.setAccessDeniedHandler(accessDeniedHandler); + if(this.accessDeniedHandler != null) { + exceptionTranslationWebFilter.setAccessDeniedHandler(this.accessDeniedHandler); } this.addFilterAt(exceptionTranslationWebFilter, SecurityWebFiltersOrder.EXCEPTION_TRANSLATION); this.authorizeExchange.configure(this); diff --git a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java index 70b0fa60f8..501c9fa22b 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java @@ -17,6 +17,8 @@ package org.springframework.security.config.web.server; import org.junit.Test; +import org.openqa.selenium.By; +import org.openqa.selenium.NoSuchElementException; import org.openqa.selenium.WebDriver; import org.openqa.selenium.WebElement; import org.openqa.selenium.support.FindBy; @@ -36,6 +38,8 @@ import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * @author Rob Winch @@ -204,9 +208,10 @@ public class FormLoginTests { private LoginForm loginForm; + private OAuth2Login oauth2Login = new OAuth2Login(); + public DefaultLoginPage(WebDriver webDriver) { this.driver = webDriver; - this.loginForm = PageFactory.initElements(webDriver, LoginForm.class); } static DefaultLoginPage create(WebDriver driver) { @@ -228,10 +233,23 @@ public class FormLoginTests { return this; } + public DefaultLoginPage assertLoginFormNotPresent() { + assertThatThrownBy(() -> loginForm().username("")) + .isInstanceOf(NoSuchElementException.class); + return this; + } + public LoginForm loginForm() { + if (this.loginForm == null) { + this.loginForm = PageFactory.initElements(this.driver, LoginForm.class); + } return this.loginForm; } + public OAuth2Login oauth2Login() { + return this.oauth2Login; + } + static DefaultLoginPage to(WebDriver driver) { driver.get("http://localhost/login"); return PageFactory.initElements(driver, DefaultLoginPage.class); @@ -263,6 +281,22 @@ public class FormLoginTests { return PageFactory.initElements(this.driver, page); } } + + public class OAuth2Login { + public WebElement findClientRegistrationByName(String clientName) { + return DefaultLoginPage.this.driver.findElement(By.linkText(clientName)); + } + + public OAuth2Login assertClientRegistrationByName(String clientName) { + assertThatCode(() -> findClientRegistrationByName(clientName)) + .doesNotThrowAnyException(); + return this; + } + + public DefaultLoginPage and() { + return DefaultLoginPage.this; + } + } } public static class DefaultLogoutPage { diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java new file mode 100644 index 0000000000..2ff2724f6c --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.web.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.Rule; +import org.junit.Test; +import org.openqa.selenium.WebDriver; +import org.openqa.selenium.WebElement; +import org.openqa.selenium.support.FindBy; +import org.openqa.selenium.support.PageFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity; +import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; +import org.springframework.security.config.oauth2.client.CommonOAuth2Provider; +import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; +import org.springframework.security.web.server.SecurityWebFilterChain; +import org.springframework.security.web.server.WebFilterChainProxy; +import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler; +import org.springframework.security.web.server.csrf.CsrfToken; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.server.ServerWebExchange; + +import reactor.core.publisher.Mono; + +/** + * @author Rob Winch + * @since 5.1 + */ +public class OAuth2LoginTests { + + @Rule + public final SpringTestRule spring = new SpringTestRule(); + + @Autowired + private WebFilterChainProxy springSecurity; + + private ClientRegistration github = CommonOAuth2Provider.GITHUB + .getBuilder("github") + .clientId("client") + .clientSecret("secret") + .build(); + + @Test + public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() { + this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire(); + + WebTestClient webTestClient = WebTestClientBuilder + .bindToWebFilters(this.springSecurity) + .build(); + + WebDriver driver = WebTestClientHtmlUnitDriverBuilder + .webTestClientSetup(webTestClient) + .build(); + + FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage + .to(driver, FormLoginTests.DefaultLoginPage.class) + .assertAt() + .assertLoginFormNotPresent() + .oauth2Login() + .assertClientRegistrationByName(this.github.getClientName()) + .and(); + } + + @EnableWebFluxSecurity + static class OAuth2LoginWithMulitpleClientRegistrations { + @Bean + InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() { + ClientRegistration github = CommonOAuth2Provider.GITHUB + .getBuilder("github") + .clientId("client") + .clientSecret("secret") + .build(); + ClientRegistration google = CommonOAuth2Provider.GOOGLE + .getBuilder("google") + .clientId("client") + .clientSecret("secret") + .build(); + return new InMemoryReactiveClientRegistrationRepository(github, google); + } + } +} diff --git a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java index 373ee21e4b..74f63f806e 100644 --- a/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java @@ -50,6 +50,12 @@ public class LoginPageGeneratingWebFilter implements WebFilter { private Map oauth2AuthenticationUrlToClientName = new HashMap<>(); + private boolean formLoginEnabled; + + public void setFormLoginEnabled(boolean enabled) { + this.formLoginEnabled = enabled; + } + public void setOauth2AuthenticationUrlToClientName( Map oauth2AuthenticationUrlToClientName) { Assert.notNull(oauth2AuthenticationUrlToClientName, "oauth2AuthenticationUrlToClientName cannot be null"); @@ -87,45 +93,47 @@ public class LoginPageGeneratingWebFilter implements WebFilter { private byte[] createPage(ServerWebExchange exchange, String csrfTokenHtmlInput) { MultiValueMap queryParams = exchange.getRequest() .getQueryParams(); - boolean isError = queryParams.containsKey("error"); - boolean isLogoutSuccess = queryParams.containsKey("logout"); String contextPath = exchange.getRequest().getPath().contextPath().value(); - String page = "\n" - + "\n" - + " \n" - + " \n" - + " \n" - + " \n" - + " \n" - + " Please sign in\n" - + " \n" - + " \n" - + " \n" - + " \n" - + "
\n" - + "
\n" - + " \n" - + createError(isError) - + createLogoutSuccess(isLogoutSuccess) - + "

\n" - + " \n" - + " \n" - + "

\n" - + "

\n" - + " \n" - + " \n" - + "

\n" - + csrfTokenHtmlInput - + " \n" - + "
\n" - + oauth2LoginLinks(contextPath, this.oauth2AuthenticationUrlToClientName) - + "
\n" - + " \n" - + ""; + String page = "\n" + "\n" + " \n" + + " \n" + + " \n" + + " \n" + + " \n" + + " Please sign in\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "
\n" + + formLogin(queryParams, csrfTokenHtmlInput) + + oauth2LoginLinks(contextPath, this.oauth2AuthenticationUrlToClientName) + + "
\n" + + " \n" + + ""; return page.getBytes(Charset.defaultCharset()); } + private String formLogin(MultiValueMap queryParams, String csrfTokenHtmlInput) { + if (!this.formLoginEnabled) { + return ""; + } + boolean isError = queryParams.containsKey("error"); + boolean isLogoutSuccess = queryParams.containsKey("logout"); + return "
\n" + + " \n" + + createError(isError) + createLogoutSuccess(isLogoutSuccess) + + "

\n" + + " \n" + + " \n" + + "

\n" + "

\n" + + " \n" + + " \n" + + "

\n" + csrfTokenHtmlInput + + " \n" + + "
\n"; + } + private static String oauth2LoginLinks(String contextPath, Map oauth2AuthenticationUrlToClientName) { if (oauth2AuthenticationUrlToClientName.isEmpty()) { return "";