diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 386af46663..37c3e78e46 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -621,20 +621,56 @@ public class ServerHttpSecurity { authenticationFilter.setAuthenticationFailureHandler(new RedirectServerAuthenticationFailureHandler("/login?error")); authenticationFilter.setSecurityContextRepository(new WebSessionServerSecurityContextRepository()); - MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher( - MediaType.TEXT_HTML); - htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); - Map urlToText = http.oauth2Login.getLinks(); - if (urlToText.size() == 1) { - http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next()))); - } else { - http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login"))); - } + setDefaultEntryPoints(http); http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION); } + private void setDefaultEntryPoints(ServerHttpSecurity http) { + String defaultLoginPage = "/login"; + Map urlToText = http.oauth2Login.getLinks(); + String providerLoginPage = null; + if (urlToText.size() == 1) { + providerLoginPage = urlToText.keySet().iterator().next(); + } + + MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher( + MediaType.APPLICATION_XHTML_XML, new MediaType("image", "*"), + MediaType.TEXT_HTML, MediaType.TEXT_PLAIN); + htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); + + ServerWebExchangeMatcher xhrMatcher = exchange -> { + if (exchange.getRequest().getHeaders().getOrDefault("X-Requested-With", Collections.emptyList()).contains("XMLHttpRequest")) { + return ServerWebExchangeMatcher.MatchResult.match(); + } + return ServerWebExchangeMatcher.MatchResult.notMatch(); + }; + ServerWebExchangeMatcher notXhrMatcher = new NegatedServerWebExchangeMatcher(xhrMatcher); + + ServerWebExchangeMatcher defaultEntryPointMatcher = new AndServerWebExchangeMatcher( + notXhrMatcher, htmlMatcher); + + if (providerLoginPage != null) { + ServerWebExchangeMatcher loginPageMatcher = new PathPatternParserServerWebExchangeMatcher(defaultLoginPage); + ServerWebExchangeMatcher faviconMatcher = new PathPatternParserServerWebExchangeMatcher("/favicon.ico"); + ServerWebExchangeMatcher defaultLoginPageMatcher = new AndServerWebExchangeMatcher( + new OrServerWebExchangeMatcher(loginPageMatcher, faviconMatcher), defaultEntryPointMatcher); + + ServerWebExchangeMatcher matcher = new AndServerWebExchangeMatcher( + notXhrMatcher, new NegatedServerWebExchangeMatcher(defaultLoginPageMatcher)); + RedirectServerAuthenticationEntryPoint entryPoint = + new RedirectServerAuthenticationEntryPoint(providerLoginPage); + entryPoint.setRequestCache(http.requestCache.requestCache); + http.defaultEntryPoints.add(new DelegateEntry(matcher, entryPoint)); + } + + RedirectServerAuthenticationEntryPoint defaultEntryPoint = + new RedirectServerAuthenticationEntryPoint(defaultLoginPage); + defaultEntryPoint.setRequestCache(http.requestCache.requestCache); + http.defaultEntryPoints.add(new DelegateEntry(defaultEntryPointMatcher, defaultEntryPoint)); + } + private ServerWebExchangeMatcher createAttemptAuthenticationRequestMatcher() { return new PathPatternParserServerWebExchangeMatcher("/login/oauth2/code/{registrationId}"); } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index 25a0758474..090fa7ac2e 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,8 +26,10 @@ import org.junit.Rule; import org.junit.Test; import org.openqa.selenium.WebDriver; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpHeaders; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity; @@ -61,10 +63,12 @@ import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.reactive.config.EnableWebFlux; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; +import org.springframework.web.server.WebHandler; import reactor.core.publisher.Mono; import java.time.Duration; @@ -79,6 +83,8 @@ public class OAuth2LoginTests { @Rule public final SpringTestRule spring = new SpringTestRule(); + private WebTestClient client; + @Autowired private WebFilterChainProxy springSecurity; @@ -94,6 +100,14 @@ public class OAuth2LoginTests { .clientSecret("secret") .build(); + @Autowired + public void setApplicationContext(ApplicationContext context) { + if (context.getBeanNamesForType(WebHandler.class).length > 0) { + this.client = WebTestClient.bindToApplicationContext(context) + .build(); + } + } + @Test public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() { this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire(); @@ -140,6 +154,22 @@ public class OAuth2LoginTests { assertThat(driver.getCurrentUrl()).startsWith("https://github.com/login/oauth/authorize"); } + // gh-8118 + @Test + public void defaultLoginPageWithSingleClientRegistrationAndXhrRequestThenDoesNotRedirectForAuthorization() { + this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, WebFluxConfig.class).autowire(); + + this.client.get() + .uri("/") + .header("X-Requested-With", "XMLHttpRequest") + .exchange() + .expectStatus().is3xxRedirection() + .expectHeader().valueEquals(HttpHeaders.LOCATION, "/login"); + } + + @EnableWebFlux + static class WebFluxConfig { } + @EnableWebFluxSecurity static class OAuth2LoginWithSingleClientRegistrations { @Bean