diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index e9e3e09381..e2b19f0340 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -2023,6 +2023,8 @@ public class ServerHttpSecurity { private ServerAuthenticationEntryPoint entryPoint; + private ServerAuthenticationFailureHandler authenticationFailureHandler; + private HttpBasicSpec() { List entryPoints = new ArrayList<>(); entryPoints @@ -2071,6 +2073,13 @@ public class ServerHttpSecurity { return this; } + public HttpBasicSpec authenticationFailureHandler( + ServerAuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; + return this; + } + /** * Allows method chaining to continue configuring the {@link ServerHttpSecurity} * @return the {@link ServerHttpSecurity} to continue configuring @@ -2102,13 +2111,19 @@ public class ServerHttpSecurity { Arrays.asList(this.xhrMatcher, restNotHtmlMatcher)); ServerHttpSecurity.this.defaultEntryPoints.add(new DelegateEntry(preferredMatcher, this.entryPoint)); AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter(this.authenticationManager); - authenticationFilter - .setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(this.entryPoint)); + authenticationFilter.setAuthenticationFailureHandler(authenticationFailureHandler()); authenticationFilter.setAuthenticationConverter(new ServerHttpBasicAuthenticationConverter()); authenticationFilter.setSecurityContextRepository(this.securityContextRepository); http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.HTTP_BASIC); } + private ServerAuthenticationFailureHandler authenticationFailureHandler() { + if (this.authenticationFailureHandler != null) { + return this.authenticationFailureHandler; + } + return new ServerAuthenticationEntryPointFailureHandler(this.entryPoint); + } + } /** @@ -3996,6 +4011,8 @@ public class ServerHttpSecurity { private ServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint(); + private ServerAuthenticationFailureHandler authenticationFailureHandler; + private ServerAccessDeniedHandler accessDeniedHandler = new BearerTokenServerAccessDeniedHandler(); private ServerAuthenticationConverter bearerTokenConverter = new ServerBearerTokenAuthenticationConverter(); @@ -4038,6 +4055,12 @@ public class ServerHttpSecurity { return this; } + public OAuth2ResourceServerSpec authenticationFailureHandler( + ServerAuthenticationFailureHandler authenticationFailureHandler) { + this.authenticationFailureHandler = authenticationFailureHandler; + return this; + } + /** * Configures the {@link ServerAuthenticationConverter} to use for requests * authenticating with @@ -4127,8 +4150,7 @@ public class ServerHttpSecurity { if (this.authenticationManagerResolver != null) { AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(this.authenticationManagerResolver); oauth2.setServerAuthenticationConverter(this.bearerTokenConverter); - oauth2.setAuthenticationFailureHandler( - new ServerAuthenticationEntryPointFailureHandler(this.entryPoint)); + oauth2.setAuthenticationFailureHandler(authenticationFailureHandler()); http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); } else if (this.jwt != null) { @@ -4181,6 +4203,13 @@ public class ServerHttpSecurity { } } + private ServerAuthenticationFailureHandler authenticationFailureHandler() { + if (this.authenticationFailureHandler != null) { + return this.authenticationFailureHandler; + } + return new ServerAuthenticationEntryPointFailureHandler(this.entryPoint); + } + public ServerHttpSecurity and() { return ServerHttpSecurity.this; } @@ -4262,8 +4291,7 @@ public class ServerHttpSecurity { ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager); oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter); - oauth2.setAuthenticationFailureHandler( - new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint)); + oauth2.setAuthenticationFailureHandler(authenticationFailureHandler()); http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); } @@ -4398,8 +4426,7 @@ public class ServerHttpSecurity { ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager); oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter); - oauth2.setAuthenticationFailureHandler( - new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint)); + oauth2.setAuthenticationFailureHandler(authenticationFailureHandler()); http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); } diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDsl.kt index 91b157c264..7aa73ff0ed 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDsl.kt @@ -21,6 +21,7 @@ import org.springframework.security.core.Authentication import org.springframework.security.core.context.SecurityContext import org.springframework.security.web.authentication.www.BasicAuthenticationFilter import org.springframework.security.web.server.ServerAuthenticationEntryPoint +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler import org.springframework.security.web.server.context.ReactorContextWebFilter import org.springframework.security.web.server.context.ServerSecurityContextRepository @@ -42,6 +43,7 @@ import org.springframework.security.web.server.context.ServerSecurityContextRepo class ServerHttpBasicDsl { var authenticationManager: ReactiveAuthenticationManager? = null var securityContextRepository: ServerSecurityContextRepository? = null + var authenticationFailureHandler: ServerAuthenticationFailureHandler? = null var authenticationEntryPoint: ServerAuthenticationEntryPoint? = null private var disabled = false @@ -57,6 +59,7 @@ class ServerHttpBasicDsl { return { httpBasic -> authenticationManager?.also { httpBasic.authenticationManager(authenticationManager) } securityContextRepository?.also { httpBasic.securityContextRepository(securityContextRepository) } + authenticationFailureHandler?.also { httpBasic.authenticationFailureHandler(authenticationFailureHandler) } authenticationEntryPoint?.also { httpBasic.authenticationEntryPoint(authenticationEntryPoint) } if (disabled) { httpBasic.disable() diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDsl.kt index 52992780c3..3c76817d67 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDsl.kt @@ -19,6 +19,7 @@ package org.springframework.security.config.web.server import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver import org.springframework.security.web.server.ServerAuthenticationEntryPoint import org.springframework.security.web.server.authentication.ServerAuthenticationConverter +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler import org.springframework.web.server.ServerWebExchange @@ -38,6 +39,7 @@ import org.springframework.web.server.ServerWebExchange @ServerSecurityMarker class ServerOAuth2ResourceServerDsl { var accessDeniedHandler: ServerAccessDeniedHandler? = null + var authenticationFailureHandler: ServerAuthenticationFailureHandler? = null var authenticationEntryPoint: ServerAuthenticationEntryPoint? = null var bearerTokenConverter: ServerAuthenticationConverter? = null var authenticationManagerResolver: ReactiveAuthenticationManagerResolver? = null @@ -109,6 +111,7 @@ class ServerOAuth2ResourceServerDsl { internal fun get(): (ServerHttpSecurity.OAuth2ResourceServerSpec) -> Unit { return { oauth2ResourceServer -> accessDeniedHandler?.also { oauth2ResourceServer.accessDeniedHandler(accessDeniedHandler) } + authenticationFailureHandler?.also { oauth2ResourceServer.authenticationFailureHandler(authenticationFailureHandler) } authenticationEntryPoint?.also { oauth2ResourceServer.authenticationEntryPoint(authenticationEntryPoint) } bearerTokenConverter?.also { oauth2ResourceServer.bearerTokenConverter(bearerTokenConverter) } authenticationManagerResolver?.also { oauth2ResourceServer.authenticationManagerResolver(authenticationManagerResolver!!) } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java index e5b96965f2..40793cdde0 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java @@ -51,6 +51,7 @@ import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -73,6 +74,7 @@ import org.springframework.security.oauth2.server.resource.introspection.Reactiv import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; @@ -347,6 +349,25 @@ public class OAuth2ResourceServerSpecTests { // @formatter:on } + @Test + public void getWhenUsingCustomAuthenticationFailureHandlerThenUsesIsAccordingly() { + this.spring.register(CustomAuthenticationFailureHandlerConfig.class).autowire(); + ServerAuthenticationFailureHandler handler = this.spring.getContext() + .getBean(ServerAuthenticationFailureHandler.class); + ReactiveAuthenticationManager authenticationManager = this.spring.getContext() + .getBean(ReactiveAuthenticationManager.class); + given(authenticationManager.authenticate(any())) + .willReturn(Mono.error(() -> new BadCredentialsException("bad"))); + given(handler.onAuthenticationFailure(any(), any())).willReturn(Mono.empty()); + // @formatter:off + this.client.get() + .headers((headers) -> headers.setBearerAuth(this.messageReadToken)) + .exchange() + .expectStatus().isOk(); + // @formatter:on + verify(handler).onAuthenticationFailure(any(), any()); + } + @Test public void postWhenSignedThenReturnsOk() { this.spring.register(PublicKeyConfig.class, RootController.class).autowire(); @@ -903,6 +924,35 @@ public class OAuth2ResourceServerSpecTests { } @Configuration + @EnableWebFlux + @EnableWebFluxSecurity + static class CustomAuthenticationFailureHandlerConfig { + + @Bean + SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { + // @formatter:off + http + .authorizeExchange((authorize) -> authorize.anyExchange().authenticated()) + .oauth2ResourceServer((oauth2) -> oauth2 + .authenticationFailureHandler(authenticationFailureHandler()) + .jwt((jwt) -> jwt.authenticationManager(authenticationManager())) + ); + // @formatter:on + return http.build(); + } + + @Bean + ReactiveAuthenticationManager authenticationManager() { + return mock(ReactiveAuthenticationManager.class); + } + + @Bean + ServerAuthenticationFailureHandler authenticationFailureHandler() { + return mock(ServerAuthenticationFailureHandler.class); + } + + } + @EnableWebFlux @EnableWebFluxSecurity static class CustomBearerTokenServerAuthenticationConverter { diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index 26d7a293be..5acbe6327e 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -35,6 +35,7 @@ import reactor.test.publisher.TestPublisher; import org.springframework.http.HttpStatus; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; @@ -57,6 +58,7 @@ import org.springframework.security.web.server.WebFilterChainProxy; import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint; +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.LogoutWebFilter; @@ -218,6 +220,27 @@ public class ServerHttpSecurityTests { verify(authenticationEntryPoint).commence(any(), any()); } + @Test + public void basicWhenCustomAuthenticationFailureHandlerThenUses() { + ReactiveAuthenticationManager authenticationManager = mock(ReactiveAuthenticationManager.class); + ServerAuthenticationFailureHandler authenticationFailureHandler = mock( + ServerAuthenticationFailureHandler.class); + this.http.httpBasic().authenticationFailureHandler(authenticationFailureHandler); + this.http.httpBasic().authenticationManager(authenticationManager); + this.http.authorizeExchange().anyExchange().authenticated(); + given(authenticationManager.authenticate(any())) + .willReturn(Mono.error(() -> new BadCredentialsException("bad"))); + given(authenticationFailureHandler.onAuthenticationFailure(any(), any())).willReturn(Mono.empty()); + WebTestClient client = buildClient(); + // @formatter:off + client.get().uri("/") + .headers((headers) -> headers.setBasicAuth("user", "password")) + .exchange() + .expectStatus().isOk(); + // @formatter:on + verify(authenticationFailureHandler).onAuthenticationFailure(any(), any()); + } + @Test public void buildWhenServerWebExchangeFromContextThenFound() { SecurityWebFilterChain filter = this.http.build(); diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt index 6b0721fce8..8094fd884c 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt @@ -19,7 +19,6 @@ package org.springframework.security.config.web.server import io.mockk.every import io.mockk.mockkObject import io.mockk.verify -import java.util.Base64 import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith import org.springframework.beans.factory.annotation.Autowired @@ -38,6 +37,7 @@ import org.springframework.security.core.userdetails.User import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.ServerAuthenticationEntryPoint import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler import org.springframework.security.web.server.context.ServerSecurityContextRepository import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository import org.springframework.test.web.reactive.server.WebTestClient @@ -45,6 +45,7 @@ import org.springframework.web.bind.annotation.RequestMapping import org.springframework.web.bind.annotation.RestController import org.springframework.web.reactive.config.EnableWebFlux import reactor.core.publisher.Mono +import java.util.* /** * Tests for [ServerHttpBasicDsl] @@ -228,6 +229,43 @@ class ServerHttpBasicDslTests { } } + @Test + fun `http basic when custom authentication failure handler then failure handler used`() { + this.spring.register(CustomAuthenticationFailureHandlerConfig::class.java, UserDetailsConfig::class.java).autowire() + mockkObject(CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER) + every { + CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) + } returns Mono.empty() + + this.client.get() + .uri("/") + .header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:wrong".toByteArray())) + .exchange() + + verify(exactly = 1) { CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) } + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class CustomAuthenticationFailureHandlerConfig { + + companion object { + val FAILURE_HANDLER: ServerAuthenticationFailureHandler = ServerAuthenticationFailureHandler { _, _ -> Mono.empty() } + } + + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + authorizeExchange { + authorize(anyExchange, authenticated) + } + httpBasic { + authenticationFailureHandler = FAILURE_HANDLER + } + } + } + } + @Configuration open class UserDetailsConfig { @Bean diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt index b517ec2959..13ede18898 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt @@ -19,10 +19,6 @@ package org.springframework.security.config.web.server import io.mockk.every import io.mockk.mockkObject import io.mockk.verify -import java.math.BigInteger -import java.security.KeyFactory -import java.security.interfaces.RSAPublicKey -import java.security.spec.RSAPublicKeySpec import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith import org.springframework.beans.factory.annotation.Autowired @@ -38,11 +34,16 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtIss import org.springframework.security.oauth2.server.resource.web.server.authentication.ServerBearerTokenAuthenticationConverter import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.web.reactive.config.EnableWebFlux import org.springframework.web.server.ServerWebExchange import reactor.core.publisher.Mono +import java.math.BigInteger +import java.security.KeyFactory +import java.security.interfaces.RSAPublicKey +import java.security.spec.RSAPublicKeySpec /** * Tests for [ServerOAuth2ResourceServerDsl] @@ -129,6 +130,47 @@ class ServerOAuth2ResourceServerDslTests { } } + @Test + fun `http basic when custom authentication failure handler then failure handler used`() { + this.spring.register(AuthenticationFailureHandlerConfig::class.java).autowire() + mockkObject(AuthenticationFailureHandlerConfig.FAILURE_HANDLER) + every { + AuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) + } returns Mono.empty() + + this.client.get() + .uri("/") + .header("Authorization", "Bearer token") + .exchange() + .expectStatus().isOk + + verify(exactly = 1) { AuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) } + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class AuthenticationFailureHandlerConfig { + + companion object { + val FAILURE_HANDLER: ServerAuthenticationFailureHandler = ServerAuthenticationFailureHandler { _, _ -> Mono.empty() } + } + + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + authorizeExchange { + authorize(anyExchange, authenticated) + } + oauth2ResourceServer { + authenticationFailureHandler = FAILURE_HANDLER + jwt { + publicKey = publicKey() + } + } + } + } + } + @Test fun `request when custom bearer token converter configured then custom converter used`() { this.spring.register(BearerTokenConverterConfig::class.java).autowire()