Add authenticationFailureHandler

- To ServerHttpSecurity#httpBasic
- To ServerHttpSecurity#oauthResourceServer

Closes gh-12132
This commit is contained in:
Josh Cummings 2022-11-02 15:34:54 -06:00
parent 6622e0135a
commit 3192618220
No known key found for this signature in database
GPG Key ID: A306A51F43B8E5A5
7 changed files with 199 additions and 13 deletions

View File

@ -2023,6 +2023,8 @@ public class ServerHttpSecurity {
private ServerAuthenticationEntryPoint entryPoint;
private ServerAuthenticationFailureHandler authenticationFailureHandler;
private HttpBasicSpec() {
List<DelegateEntry> entryPoints = new ArrayList<>();
entryPoints
@ -2071,6 +2073,13 @@ public class ServerHttpSecurity {
return this;
}
public HttpBasicSpec authenticationFailureHandler(
ServerAuthenticationFailureHandler authenticationFailureHandler) {
Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
this.authenticationFailureHandler = authenticationFailureHandler;
return this;
}
/**
* Allows method chaining to continue configuring the {@link ServerHttpSecurity}
* @return the {@link ServerHttpSecurity} to continue configuring
@ -2102,13 +2111,19 @@ public class ServerHttpSecurity {
Arrays.asList(this.xhrMatcher, restNotHtmlMatcher));
ServerHttpSecurity.this.defaultEntryPoints.add(new DelegateEntry(preferredMatcher, this.entryPoint));
AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter(this.authenticationManager);
authenticationFilter
.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(this.entryPoint));
authenticationFilter.setAuthenticationFailureHandler(authenticationFailureHandler());
authenticationFilter.setAuthenticationConverter(new ServerHttpBasicAuthenticationConverter());
authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.HTTP_BASIC);
}
private ServerAuthenticationFailureHandler authenticationFailureHandler() {
if (this.authenticationFailureHandler != null) {
return this.authenticationFailureHandler;
}
return new ServerAuthenticationEntryPointFailureHandler(this.entryPoint);
}
}
/**
@ -3996,6 +4011,8 @@ public class ServerHttpSecurity {
private ServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint();
private ServerAuthenticationFailureHandler authenticationFailureHandler;
private ServerAccessDeniedHandler accessDeniedHandler = new BearerTokenServerAccessDeniedHandler();
private ServerAuthenticationConverter bearerTokenConverter = new ServerBearerTokenAuthenticationConverter();
@ -4038,6 +4055,12 @@ public class ServerHttpSecurity {
return this;
}
public OAuth2ResourceServerSpec authenticationFailureHandler(
ServerAuthenticationFailureHandler authenticationFailureHandler) {
this.authenticationFailureHandler = authenticationFailureHandler;
return this;
}
/**
* Configures the {@link ServerAuthenticationConverter} to use for requests
* authenticating with
@ -4127,8 +4150,7 @@ public class ServerHttpSecurity {
if (this.authenticationManagerResolver != null) {
AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(this.authenticationManagerResolver);
oauth2.setServerAuthenticationConverter(this.bearerTokenConverter);
oauth2.setAuthenticationFailureHandler(
new ServerAuthenticationEntryPointFailureHandler(this.entryPoint));
oauth2.setAuthenticationFailureHandler(authenticationFailureHandler());
http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
}
else if (this.jwt != null) {
@ -4181,6 +4203,13 @@ public class ServerHttpSecurity {
}
}
private ServerAuthenticationFailureHandler authenticationFailureHandler() {
if (this.authenticationFailureHandler != null) {
return this.authenticationFailureHandler;
}
return new ServerAuthenticationEntryPointFailureHandler(this.entryPoint);
}
public ServerHttpSecurity and() {
return ServerHttpSecurity.this;
}
@ -4262,8 +4291,7 @@ public class ServerHttpSecurity {
ReactiveAuthenticationManager authenticationManager = getAuthenticationManager();
AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager);
oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter);
oauth2.setAuthenticationFailureHandler(
new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint));
oauth2.setAuthenticationFailureHandler(authenticationFailureHandler());
http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
}
@ -4398,8 +4426,7 @@ public class ServerHttpSecurity {
ReactiveAuthenticationManager authenticationManager = getAuthenticationManager();
AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager);
oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter);
oauth2.setAuthenticationFailureHandler(
new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint));
oauth2.setAuthenticationFailureHandler(authenticationFailureHandler());
http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
}

View File

@ -21,6 +21,7 @@ import org.springframework.security.core.Authentication
import org.springframework.security.core.context.SecurityContext
import org.springframework.security.web.authentication.www.BasicAuthenticationFilter
import org.springframework.security.web.server.ServerAuthenticationEntryPoint
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
import org.springframework.security.web.server.context.ReactorContextWebFilter
import org.springframework.security.web.server.context.ServerSecurityContextRepository
@ -42,6 +43,7 @@ import org.springframework.security.web.server.context.ServerSecurityContextRepo
class ServerHttpBasicDsl {
var authenticationManager: ReactiveAuthenticationManager? = null
var securityContextRepository: ServerSecurityContextRepository? = null
var authenticationFailureHandler: ServerAuthenticationFailureHandler? = null
var authenticationEntryPoint: ServerAuthenticationEntryPoint? = null
private var disabled = false
@ -57,6 +59,7 @@ class ServerHttpBasicDsl {
return { httpBasic ->
authenticationManager?.also { httpBasic.authenticationManager(authenticationManager) }
securityContextRepository?.also { httpBasic.securityContextRepository(securityContextRepository) }
authenticationFailureHandler?.also { httpBasic.authenticationFailureHandler(authenticationFailureHandler) }
authenticationEntryPoint?.also { httpBasic.authenticationEntryPoint(authenticationEntryPoint) }
if (disabled) {
httpBasic.disable()

View File

@ -19,6 +19,7 @@ package org.springframework.security.config.web.server
import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver
import org.springframework.security.web.server.ServerAuthenticationEntryPoint
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler
import org.springframework.web.server.ServerWebExchange
@ -38,6 +39,7 @@ import org.springframework.web.server.ServerWebExchange
@ServerSecurityMarker
class ServerOAuth2ResourceServerDsl {
var accessDeniedHandler: ServerAccessDeniedHandler? = null
var authenticationFailureHandler: ServerAuthenticationFailureHandler? = null
var authenticationEntryPoint: ServerAuthenticationEntryPoint? = null
var bearerTokenConverter: ServerAuthenticationConverter? = null
var authenticationManagerResolver: ReactiveAuthenticationManagerResolver<ServerWebExchange>? = null
@ -107,6 +109,7 @@ class ServerOAuth2ResourceServerDsl {
internal fun get(): (ServerHttpSecurity.OAuth2ResourceServerSpec) -> Unit {
return { oauth2ResourceServer ->
accessDeniedHandler?.also { oauth2ResourceServer.accessDeniedHandler(accessDeniedHandler) }
authenticationFailureHandler?.also { oauth2ResourceServer.authenticationFailureHandler(authenticationFailureHandler) }
authenticationEntryPoint?.also { oauth2ResourceServer.authenticationEntryPoint(authenticationEntryPoint) }
bearerTokenConverter?.also { oauth2ResourceServer.bearerTokenConverter(bearerTokenConverter) }
authenticationManagerResolver?.also { oauth2ResourceServer.authenticationManagerResolver(authenticationManagerResolver!!) }

View File

@ -51,6 +51,7 @@ import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver;
import org.springframework.security.authentication.TestingAuthenticationToken;
@ -73,6 +74,7 @@ import org.springframework.security.oauth2.server.resource.introspection.Reactiv
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint;
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.springframework.test.web.reactive.server.WebTestClient;
@ -348,6 +350,25 @@ public class OAuth2ResourceServerSpecTests {
// @formatter:on
}
@Test
public void getWhenUsingCustomAuthenticationFailureHandlerThenUsesIsAccordingly() {
this.spring.register(CustomAuthenticationFailureHandlerConfig.class).autowire();
ServerAuthenticationFailureHandler handler = this.spring.getContext()
.getBean(ServerAuthenticationFailureHandler.class);
ReactiveAuthenticationManager authenticationManager = this.spring.getContext()
.getBean(ReactiveAuthenticationManager.class);
given(authenticationManager.authenticate(any()))
.willReturn(Mono.error(() -> new BadCredentialsException("bad")));
given(handler.onAuthenticationFailure(any(), any())).willReturn(Mono.empty());
// @formatter:off
this.client.get()
.headers((headers) -> headers.setBearerAuth(this.messageReadToken))
.exchange()
.expectStatus().isOk();
// @formatter:on
verify(handler).onAuthenticationFailure(any(), any());
}
@Test
public void postWhenSignedThenReturnsOk() {
this.spring.register(PublicKeyConfig.class, RootController.class).autowire();
@ -893,6 +914,35 @@ public class OAuth2ResourceServerSpecTests {
}
@EnableWebFlux
@EnableWebFluxSecurity
static class CustomAuthenticationFailureHandlerConfig {
@Bean
SecurityWebFilterChain springSecurity(ServerHttpSecurity http) {
// @formatter:off
http
.authorizeExchange((authorize) -> authorize.anyExchange().authenticated())
.oauth2ResourceServer((oauth2) -> oauth2
.authenticationFailureHandler(authenticationFailureHandler())
.jwt((jwt) -> jwt.authenticationManager(authenticationManager()))
);
// @formatter:on
return http.build();
}
@Bean
ReactiveAuthenticationManager authenticationManager() {
return mock(ReactiveAuthenticationManager.class);
}
@Bean
ServerAuthenticationFailureHandler authenticationFailureHandler() {
return mock(ServerAuthenticationFailureHandler.class);
}
}
@EnableWebFlux
@EnableWebFluxSecurity
static class CustomBearerTokenServerAuthenticationConverter {

View File

@ -35,6 +35,7 @@ import reactor.test.publisher.TestPublisher;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
@ -57,6 +58,7 @@ import org.springframework.security.web.server.WebFilterChainProxy;
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint;
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
@ -218,6 +220,27 @@ public class ServerHttpSecurityTests {
verify(authenticationEntryPoint).commence(any(), any());
}
@Test
public void basicWhenCustomAuthenticationFailureHandlerThenUses() {
ReactiveAuthenticationManager authenticationManager = mock(ReactiveAuthenticationManager.class);
ServerAuthenticationFailureHandler authenticationFailureHandler = mock(
ServerAuthenticationFailureHandler.class);
this.http.httpBasic().authenticationFailureHandler(authenticationFailureHandler);
this.http.httpBasic().authenticationManager(authenticationManager);
this.http.authorizeExchange().anyExchange().authenticated();
given(authenticationManager.authenticate(any()))
.willReturn(Mono.error(() -> new BadCredentialsException("bad")));
given(authenticationFailureHandler.onAuthenticationFailure(any(), any())).willReturn(Mono.empty());
WebTestClient client = buildClient();
// @formatter:off
client.get().uri("/")
.headers((headers) -> headers.setBasicAuth("user", "password"))
.exchange()
.expectStatus().isOk();
// @formatter:on
verify(authenticationFailureHandler).onAuthenticationFailure(any(), any());
}
@Test
public void buildWhenServerWebExchangeFromContextThenFound() {
SecurityWebFilterChain filter = this.http.build();

View File

@ -19,7 +19,6 @@ package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import java.util.Base64
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.springframework.beans.factory.annotation.Autowired
@ -36,6 +35,7 @@ import org.springframework.security.core.userdetails.MapReactiveUserDetailsServi
import org.springframework.security.core.userdetails.User
import org.springframework.security.web.server.SecurityWebFilterChain
import org.springframework.security.web.server.ServerAuthenticationEntryPoint
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
import org.springframework.security.web.server.context.ServerSecurityContextRepository
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository
import org.springframework.test.web.reactive.server.WebTestClient
@ -43,6 +43,7 @@ import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RestController
import org.springframework.web.reactive.config.EnableWebFlux
import reactor.core.publisher.Mono
import java.util.*
/**
* Tests for [ServerHttpBasicDsl]
@ -216,6 +217,43 @@ class ServerHttpBasicDslTests {
}
}
@Test
fun `http basic when custom authentication failure handler then failure handler used`() {
this.spring.register(CustomAuthenticationFailureHandlerConfig::class.java, UserDetailsConfig::class.java).autowire()
mockkObject(CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER)
every {
CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any())
} returns Mono.empty()
this.client.get()
.uri("/")
.header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:wrong".toByteArray()))
.exchange()
verify(exactly = 1) { CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) }
}
@EnableWebFluxSecurity
@EnableWebFlux
open class CustomAuthenticationFailureHandlerConfig {
companion object {
val FAILURE_HANDLER: ServerAuthenticationFailureHandler = ServerAuthenticationFailureHandler { _, _ -> Mono.empty() }
}
@Bean
open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
return http {
authorizeExchange {
authorize(anyExchange, authenticated)
}
httpBasic {
authenticationFailureHandler = FAILURE_HANDLER
}
}
}
}
@Configuration
open class UserDetailsConfig {
@Bean

View File

@ -19,10 +19,6 @@ package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import java.math.BigInteger
import java.security.KeyFactory
import java.security.interfaces.RSAPublicKey
import java.security.spec.RSAPublicKeySpec
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.springframework.beans.factory.annotation.Autowired
@ -36,11 +32,16 @@ import org.springframework.security.config.test.SpringTestContextExtension
import org.springframework.security.oauth2.server.resource.web.server.authentication.ServerBearerTokenAuthenticationConverter
import org.springframework.security.web.server.SecurityWebFilterChain
import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler
import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.reactive.config.EnableWebFlux
import org.springframework.web.server.ServerWebExchange
import reactor.core.publisher.Mono
import java.math.BigInteger
import java.security.KeyFactory
import java.security.interfaces.RSAPublicKey
import java.security.spec.RSAPublicKeySpec
/**
* Tests for [ServerOAuth2ResourceServerDsl]
@ -125,6 +126,47 @@ class ServerOAuth2ResourceServerDslTests {
}
}
@Test
fun `http basic when custom authentication failure handler then failure handler used`() {
this.spring.register(AuthenticationFailureHandlerConfig::class.java).autowire()
mockkObject(AuthenticationFailureHandlerConfig.FAILURE_HANDLER)
every {
AuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any())
} returns Mono.empty()
this.client.get()
.uri("/")
.header("Authorization", "Bearer token")
.exchange()
.expectStatus().isOk
verify(exactly = 1) { AuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) }
}
@EnableWebFluxSecurity
@EnableWebFlux
open class AuthenticationFailureHandlerConfig {
companion object {
val FAILURE_HANDLER: ServerAuthenticationFailureHandler = ServerAuthenticationFailureHandler { _, _ -> Mono.empty() }
}
@Bean
open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
return http {
authorizeExchange {
authorize(anyExchange, authenticated)
}
oauth2ResourceServer {
authenticationFailureHandler = FAILURE_HANDLER
jwt {
publicKey = publicKey()
}
}
}
}
}
@Test
fun `request when custom bearer token converter configured then custom converter used`() {
this.spring.register(BearerTokenConverterConfig::class.java).autowire()