diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index c37b664da1..848f6e53cd 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -881,11 +881,40 @@ public class ServerHttpSecurity { * Configures OAuth2 Resource Server Support */ public class OAuth2ResourceServerSpec { - private BearerTokenServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint(); - private BearerTokenServerAccessDeniedHandler accessDeniedHandler = new BearerTokenServerAccessDeniedHandler(); + private ServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint(); + private ServerAccessDeniedHandler accessDeniedHandler = new BearerTokenServerAccessDeniedHandler(); private JwtSpec jwt; + /** + * Configures the {@link ServerAccessDeniedHandler} to use for requests authenticating with + * Bearer Tokens. + * requests. + * + * @param accessDeniedHandler the {@link ServerAccessDeniedHandler} to use + * @return the {@link OAuth2ResourceServerSpec} for additional configuration + * @since 5.2 + */ + public OAuth2ResourceServerSpec accessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) { + Assert.notNull(accessDeniedHandler, "accessDeniedHandler cannot be null"); + this.accessDeniedHandler = accessDeniedHandler; + return this; + } + + /** + * Configures the {@link ServerAuthenticationEntryPoint} to use for requests authenticating with + * Bearer Tokens. + * + * @param entryPoint the {@link ServerAuthenticationEntryPoint} to use + * @return the {@link OAuth2ResourceServerSpec} for additional configuration + * @since 5.2 + */ + public OAuth2ResourceServerSpec authenticationEntryPoint(ServerAuthenticationEntryPoint entryPoint) { + Assert.notNull(entryPoint, "entryPoint cannot be null"); + this.entryPoint = entryPoint; + return this; + } + public JwtSpec jwt() { if (this.jwt == null) { this.jwt = new JwtSpec(); @@ -1024,7 +1053,7 @@ public class ServerHttpSecurity { http.defaultAccessDeniedHandlers.add( new ServerWebExchangeDelegatingServerAccessDeniedHandler.DelegateEntry( this.bearerTokenServerWebExchangeMatcher, - new BearerTokenServerAccessDeniedHandler() + OAuth2ResourceServerSpec.this.accessDeniedHandler ) ); } @@ -1035,7 +1064,7 @@ public class ServerHttpSecurity { http.defaultEntryPoints.add( new DelegateEntry( this.bearerTokenServerWebExchangeMatcher, - new BearerTokenServerAuthenticationEntryPoint() + OAuth2ResourceServerSpec.this.entryPoint ) ); } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java index 4aaa74d885..b92ba77064 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java @@ -43,6 +43,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpStatus; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity; @@ -58,6 +59,10 @@ import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter; import org.springframework.security.oauth2.server.resource.authentication.ReactiveJwtAuthenticationConverterAdapter; import org.springframework.security.web.server.SecurityWebFilterChain; +import org.springframework.security.web.server.ServerAuthenticationEntryPoint; +import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint; +import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; +import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; @@ -232,6 +237,27 @@ public class OAuth2ResourceServerSpecTests { .expectStatus().isOk(); } + @Test + public void getWhenCustomBearerTokenEntryPointThenResponds() { + this.spring.register(CustomErrorHandlingConfig.class).autowire(); + + this.client.get() + .uri("/authenticated") + .exchange() + .expectStatus().isEqualTo(HttpStatus.I_AM_A_TEAPOT); + } + + @Test + public void getWhenCustomBearerTokenDeniedHandlerThenResponds() { + this.spring.register(CustomErrorHandlingConfig.class).autowire(); + + this.client.get() + .uri("/unobtainable") + .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .exchange() + .expectStatus().isEqualTo(HttpStatus.BANDWIDTH_LIMIT_EXCEEDED); + } + @Test public void getJwtDecoderWhenBeanWiredAndDslWiredThenDslTakesPrecedence() { GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext(); @@ -438,7 +464,30 @@ public class OAuth2ResourceServerSpecTests { } } + @EnableWebFlux + @EnableWebFluxSecurity + static class CustomErrorHandlingConfig { + private ServerAccessDeniedHandler accessDeniedHandler = mock(ServerAccessDeniedHandler.class); + private ServerAuthenticationEntryPoint entryPoint = mock(ServerAuthenticationEntryPoint.class); + @Bean + SecurityWebFilterChain springSecurity(ServerHttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeExchange() + .pathMatchers("/authenticated").authenticated() + .pathMatchers("/unobtainable").hasAuthority("unobtainable") + .and() + .oauth2ResourceServer() + .accessDeniedHandler(new HttpStatusServerAccessDeniedHandler(HttpStatus.BANDWIDTH_LIMIT_EXCEEDED)) + .authenticationEntryPoint(new HttpStatusServerEntryPoint(HttpStatus.I_AM_A_TEAPOT)) + .jwt() + .publicKey(publicKey()); + // @formatter:on + + return http.build(); + } + } @RestController static class RootController {