diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 7a5ec0c262..d2edb13414 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -43,6 +43,7 @@ import org.springframework.security.authorization.AuthenticatedReactiveAuthoriza import org.springframework.security.authorization.AuthorityReactiveAuthorizationManager; import org.springframework.security.authorization.AuthorizationDecision; import org.springframework.security.authorization.ReactiveAuthorizationManager; +import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; @@ -114,7 +115,9 @@ import org.springframework.security.web.server.savedrequest.ServerRequestCacheWe import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; import org.springframework.security.web.server.ui.LoginPageGeneratingWebFilter; import org.springframework.security.web.server.ui.LogoutPageGeneratingWebFilter; +import org.springframework.security.web.server.util.matcher.AndServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher; +import org.springframework.security.web.server.util.matcher.NegatedServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcherEntry; @@ -130,6 +133,8 @@ import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; import static org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry; +import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.match; +import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult.notMatch; /** * A {@link ServerHttpSecurity} is similar to Spring Security's {@code HttpSecurity} but for WebFlux. @@ -703,6 +708,9 @@ public class ServerHttpSecurity { public class JwtSpec { private ReactiveJwtDecoder jwtDecoder; + private BearerTokenServerWebExchangeMatcher bearerTokenServerWebExchangeMatcher = + new BearerTokenServerWebExchangeMatcher(); + /** * Configures the {@link ReactiveJwtDecoder} to use * @param jwtDecoder the decoder to use @@ -740,13 +748,20 @@ public class ServerHttpSecurity { } protected void configure(ServerHttpSecurity http) { + ServerBearerTokenAuthenticationConverter bearerTokenConverter = + new ServerBearerTokenAuthenticationConverter(); + this.bearerTokenServerWebExchangeMatcher.setBearerTokenConverter(bearerTokenConverter); + + registerDefaultCsrfOverride(http); + BearerTokenServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint(); ReactiveJwtDecoder jwtDecoder = getJwtDecoder(); JwtReactiveAuthenticationManager authenticationManager = new JwtReactiveAuthenticationManager( jwtDecoder); AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager); - oauth2.setServerAuthenticationConverter(new ServerBearerTokenAuthenticationConverter()); + oauth2.setServerAuthenticationConverter(bearerTokenConverter); oauth2.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(entryPoint)); + http .exceptionHandling() .authenticationEntryPoint(entryPoint) @@ -760,6 +775,38 @@ public class ServerHttpSecurity { } return this.jwtDecoder; } + + private void registerDefaultCsrfOverride(ServerHttpSecurity http) { + if ( http.csrf != null && !http.csrf.specifiedRequireCsrfProtectionMatcher ) { + http + .csrf() + .requireCsrfProtectionMatcher( + new AndServerWebExchangeMatcher( + CsrfWebFilter.DEFAULT_CSRF_MATCHER, + new NegatedServerWebExchangeMatcher( + this.bearerTokenServerWebExchangeMatcher))); + } + } + + private class BearerTokenServerWebExchangeMatcher implements ServerWebExchangeMatcher { + ServerBearerTokenAuthenticationConverter bearerTokenConverter; + + @Override + public Mono matches(ServerWebExchange exchange) { + return this.bearerTokenConverter.convert(exchange) + .flatMap(this::nullAuthentication) + .onErrorResume(e -> notMatch()); + } + + public void setBearerTokenConverter(ServerBearerTokenAuthenticationConverter bearerTokenConverter) { + Assert.notNull(bearerTokenConverter, "bearerTokenConverter cannot be null"); + this.bearerTokenConverter = bearerTokenConverter; + } + + private Mono nullAuthentication(Authentication authentication) { + return authentication == null ? notMatch() : match(); + } + } } public ServerHttpSecurity and() { @@ -1173,6 +1220,8 @@ public class ServerHttpSecurity { public class CsrfSpec { private CsrfWebFilter filter = new CsrfWebFilter(); + private boolean specifiedRequireCsrfProtectionMatcher; + /** * Configures the {@link ServerAccessDeniedHandler} used when a CSRF token is invalid. Default is * to send an {@link org.springframework.http.HttpStatus#FORBIDDEN}. @@ -1209,6 +1258,7 @@ public class ServerHttpSecurity { public CsrfSpec requireCsrfProtectionMatcher( ServerWebExchangeMatcher requireCsrfProtectionMatcher) { this.filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher); + this.specifiedRequireCsrfProtectionMatcher = true; return this; } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java index 83b0bf75f8..db35656753 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java @@ -48,6 +48,7 @@ import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.support.GenericWebApplicationContext; import org.springframework.web.reactive.DispatcherHandler; @@ -160,6 +161,25 @@ public class OAuth2ResourceServerSpecTests { .expectStatus().isOk(); } + @Test + public void postWhenSignedThenReturnsOk() { + this.spring.register(PublicKeyConfig.class, RootController.class).autowire(); + + this.client.post() + .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .exchange() + .expectStatus().isOk(); + } + + @Test + public void postWhenMissingTokenThenReturnsForbidden() { + this.spring.register(PublicKeyConfig.class, RootController.class).autowire(); + + this.client.post() + .exchange() + .expectStatus().isForbidden(); + } + @Test public void getJwtDecoderWhenBeanWiredAndDslWiredThenDslTakesPrecedence() { GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext(); @@ -301,7 +321,12 @@ public class OAuth2ResourceServerSpecTests { @RestController static class RootController { @GetMapping - Mono root() { + Mono get() { + return Mono.just("ok"); + } + + @PostMapping + Mono post() { return Mono.just("ok"); } }