Single ClientRegistration redirects by default
Fixes: gh-5339
This commit is contained in:
parent
f29e4cf91f
commit
32e368d9b7
|
@ -183,6 +183,8 @@ public class ServerHttpSecurity {
|
|||
|
||||
private LogoutSpec logout = new LogoutSpec();
|
||||
|
||||
private LoginPageSpec loginPage = new LoginPageSpec();
|
||||
|
||||
private ReactiveAuthenticationManager authenticationManager;
|
||||
|
||||
private ServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
|
||||
|
@ -387,6 +389,16 @@ public class ServerHttpSecurity {
|
|||
});
|
||||
authenticationFilter.setSecurityContextRepository(new WebSessionServerSecurityContextRepository());
|
||||
|
||||
MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher(
|
||||
MediaType.TEXT_HTML);
|
||||
htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
|
||||
Map<String, String> urlToText = http.oauth2Login.getLinks();
|
||||
if (urlToText.size() == 1) {
|
||||
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next())));
|
||||
} else {
|
||||
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login")));
|
||||
}
|
||||
|
||||
http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
|
||||
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
|
||||
}
|
||||
|
@ -610,31 +622,17 @@ public class ServerHttpSecurity {
|
|||
this.httpBasic.authenticationManager(this.authenticationManager);
|
||||
this.httpBasic.configure(this);
|
||||
}
|
||||
LoginPageGeneratingWebFilter loginPageFilter = null;
|
||||
if(this.formLogin != null) {
|
||||
this.formLogin.authenticationManager(this.authenticationManager);
|
||||
if(this.securityContextRepository != null) {
|
||||
this.formLogin.securityContextRepository(this.securityContextRepository);
|
||||
}
|
||||
if (this.authenticationEntryPoint == null) {
|
||||
loginPageFilter = new LoginPageGeneratingWebFilter();
|
||||
loginPageFilter.setFormLoginEnabled(true);
|
||||
this.authenticationEntryPoint = this.formLogin.authenticationEntryPoint;
|
||||
}
|
||||
this.formLogin.configure(this);
|
||||
}
|
||||
if (this.oauth2Login != null) {
|
||||
if (this.authenticationEntryPoint == null) {
|
||||
loginPageFilter = new LoginPageGeneratingWebFilter();
|
||||
loginPageFilter.setOauth2AuthenticationUrlToClientName(this.oauth2Login.getLinks());
|
||||
}
|
||||
this.oauth2Login.configure(this);
|
||||
}
|
||||
if (loginPageFilter != null) {
|
||||
this.authenticationEntryPoint = new RedirectServerAuthenticationEntryPoint("/login");
|
||||
this.webFilters.add(new OrderedWebFilter(loginPageFilter, SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING.getOrder()));
|
||||
this.webFilters.add(new OrderedWebFilter(new LogoutPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGOUT_PAGE_GENERATING.getOrder()));
|
||||
}
|
||||
this.loginPage.configure(this);
|
||||
if(this.logout != null) {
|
||||
this.logout.configure(this);
|
||||
}
|
||||
|
@ -1084,6 +1082,8 @@ public class ServerHttpSecurity {
|
|||
|
||||
private ServerAuthenticationEntryPoint authenticationEntryPoint;
|
||||
|
||||
private boolean isEntryPointExplicit;
|
||||
|
||||
private ServerWebExchangeMatcher requiresAuthenticationMatcher;
|
||||
|
||||
private ServerAuthenticationFailureHandler authenticationFailureHandler;
|
||||
|
@ -1206,7 +1206,10 @@ public class ServerHttpSecurity {
|
|||
|
||||
protected void configure(ServerHttpSecurity http) {
|
||||
if(this.authenticationEntryPoint == null) {
|
||||
this.isEntryPointExplicit = false;
|
||||
loginPage("/login");
|
||||
} else {
|
||||
this.isEntryPointExplicit = true;
|
||||
}
|
||||
if(http.requestCache != null) {
|
||||
ServerRequestCache requestCache = http.requestCache.requestCache;
|
||||
|
@ -1233,6 +1236,35 @@ public class ServerHttpSecurity {
|
|||
}
|
||||
}
|
||||
|
||||
private class LoginPageSpec {
|
||||
protected void configure(ServerHttpSecurity http) {
|
||||
if (http.authenticationEntryPoint != null) {
|
||||
return;
|
||||
}
|
||||
if (http.formLogin != null && http.formLogin.isEntryPointExplicit) {
|
||||
return;
|
||||
}
|
||||
LoginPageGeneratingWebFilter loginPage = null;
|
||||
if (http.formLogin != null && !http.formLogin.isEntryPointExplicit) {
|
||||
loginPage = new LoginPageGeneratingWebFilter();
|
||||
loginPage.setFormLoginEnabled(true);
|
||||
}
|
||||
if (http.oauth2Login != null) {
|
||||
Map<String, String> urlToText = http.oauth2Login.getLinks();
|
||||
if (loginPage == null) {
|
||||
loginPage = new LoginPageGeneratingWebFilter();
|
||||
}
|
||||
loginPage.setOauth2AuthenticationUrlToClientName(urlToText);
|
||||
}
|
||||
if (loginPage != null) {
|
||||
http.addFilterAt(loginPage, SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING);
|
||||
http.addFilterAt(new LogoutPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGOUT_PAGE_GENERATING);
|
||||
}
|
||||
}
|
||||
|
||||
private LoginPageSpec() {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Configures HTTP Response Headers.
|
||||
*
|
||||
|
|
|
@ -21,29 +21,20 @@ import static org.assertj.core.api.Assertions.assertThat;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.openqa.selenium.WebDriver;
|
||||
import org.openqa.selenium.WebElement;
|
||||
import org.openqa.selenium.support.FindBy;
|
||||
import org.openqa.selenium.support.PageFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
|
||||
import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
|
||||
import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
|
||||
import org.springframework.security.config.test.SpringTestRule;
|
||||
import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
|
||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||
import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
|
||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||
import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
|
||||
import org.springframework.security.web.server.SecurityWebFilterChain;
|
||||
import org.springframework.security.web.server.WebFilterChainProxy;
|
||||
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
|
||||
import org.springframework.security.web.server.csrf.CsrfToken;
|
||||
import org.springframework.stereotype.Controller;
|
||||
import org.springframework.test.web.reactive.server.WebTestClient;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.ResponseBody;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
import org.springframework.web.server.WebFilter;
|
||||
import org.springframework.web.server.WebFilterChain;
|
||||
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
|
@ -59,7 +50,7 @@ public class OAuth2LoginTests {
|
|||
@Autowired
|
||||
private WebFilterChainProxy springSecurity;
|
||||
|
||||
private ClientRegistration github = CommonOAuth2Provider.GITHUB
|
||||
private static ClientRegistration github = CommonOAuth2Provider.GITHUB
|
||||
.getBuilder("github")
|
||||
.clientId("client")
|
||||
.clientSecret("secret")
|
||||
|
@ -90,11 +81,6 @@ public class OAuth2LoginTests {
|
|||
static class OAuth2LoginWithMulitpleClientRegistrations {
|
||||
@Bean
|
||||
InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() {
|
||||
ClientRegistration github = CommonOAuth2Provider.GITHUB
|
||||
.getBuilder("github")
|
||||
.clientId("client")
|
||||
.clientSecret("secret")
|
||||
.build();
|
||||
ClientRegistration google = CommonOAuth2Provider.GOOGLE
|
||||
.getBuilder("google")
|
||||
.clientId("client")
|
||||
|
@ -103,4 +89,40 @@ public class OAuth2LoginTests {
|
|||
return new InMemoryReactiveClientRegistrationRepository(github, google);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void defaultLoginPageWithSingleClientRegistrationThenRedirect() {
|
||||
this.spring.register(OAuth2LoginWithSingleClientRegistrations.class).autowire();
|
||||
|
||||
WebTestClient webTestClient = WebTestClientBuilder
|
||||
.bindToWebFilters(new GitHubWebFilter(), this.springSecurity)
|
||||
.build();
|
||||
|
||||
WebDriver driver = WebTestClientHtmlUnitDriverBuilder
|
||||
.webTestClientSetup(webTestClient)
|
||||
.build();
|
||||
|
||||
driver.get("http://localhost/");
|
||||
|
||||
assertThat(driver.getCurrentUrl()).startsWith("https://github.com/login/oauth/authorize");
|
||||
}
|
||||
|
||||
@EnableWebFluxSecurity
|
||||
static class OAuth2LoginWithSingleClientRegistrations {
|
||||
@Bean
|
||||
InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() {
|
||||
return new InMemoryReactiveClientRegistrationRepository(github);
|
||||
}
|
||||
}
|
||||
|
||||
static class GitHubWebFilter implements WebFilter {
|
||||
|
||||
@Override
|
||||
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
|
||||
if (exchange.getRequest().getURI().getHost().equals("github.com")) {
|
||||
return exchange.getResponse().setComplete();
|
||||
}
|
||||
return chain.filter(exchange);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -151,8 +151,14 @@ final class HtmlUnitWebTestClient {
|
|||
|
||||
private Mono<ClientResponse> redirectIfNecessary(ClientRequest request, ExchangeFunction next, ClientResponse response) {
|
||||
URI location = response.headers().asHttpHeaders().getLocation();
|
||||
String host = request.url().getHost();
|
||||
String scheme = request.url().getScheme();
|
||||
if(location != null) {
|
||||
ClientRequest redirect = ClientRequest.method(HttpMethod.GET, URI.create("http://localhost" + location.toASCIIString()))
|
||||
String redirectUrl = location.toASCIIString();
|
||||
if (location.getHost() == null) {
|
||||
redirectUrl = scheme+ "://" + host + location.toASCIIString();
|
||||
}
|
||||
ClientRequest redirect = ClientRequest.method(HttpMethod.GET, URI.create(redirectUrl))
|
||||
.headers(headers -> headers.addAll(request.headers()))
|
||||
.cookies(cookies -> cookies.addAll(request.cookies()))
|
||||
.attributes(attributes -> attributes.putAll(request.attributes()))
|
||||
|
|
Loading…
Reference in New Issue