From cba2444e1abf7a29a3a1adca507bc88277d1b9ad Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Thu, 23 Aug 2018 12:17:33 -0600 Subject: [PATCH] ServerHttpSecurity ReactiveJwtDecoder discovery This makes so that WebFlux OAuth 2.0 Resource Server configuration will pick up a ReactiveJwtDecoder exposed as a bean. Fixes: gh-5720 --- .../config/web/server/ServerHttpSecurity.java | 17 ++- .../server/OAuth2ResourceServerSpecTests.java | 122 +++++++++++++++++- 2 files changed, 136 insertions(+), 3 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 71ace7dac2..91f0a24ba7 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -741,8 +741,9 @@ public class ServerHttpSecurity { protected void configure(ServerHttpSecurity http) { BearerTokenServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint(); + ReactiveJwtDecoder jwtDecoder = this.getJwtDecoder(); JwtReactiveAuthenticationManager authenticationManager = new JwtReactiveAuthenticationManager( - this.jwtDecoder); + jwtDecoder); AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager); oauth2.setServerAuthenticationConverter(new ServerBearerTokenAuthenticationConverter()); oauth2.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(entryPoint)); @@ -752,6 +753,13 @@ public class ServerHttpSecurity { .and() .addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); } + + protected ReactiveJwtDecoder getJwtDecoder() { + if (this.jwtDecoder == null) { + return getBean(ReactiveJwtDecoder.class); + } + return this.jwtDecoder; + } } public ServerHttpSecurity and() { @@ -2014,6 +2022,13 @@ public class ServerHttpSecurity { private LogoutSpec() {} } + private T getBean(Class beanClass) { + if (this.context == null) { + return null; + } + return this.context.getBean(beanClass); + } + private T getBeanOrNull(Class beanClass) { return getBeanOrNull(ResolvableType.forClass(beanClass)); } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java index d83c195a73..83b0bf75f8 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java @@ -34,6 +34,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import reactor.core.publisher.Mono; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; @@ -41,13 +43,24 @@ import org.springframework.security.config.annotation.web.reactive.EnableWebFlux import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.support.GenericWebApplicationContext; +import org.springframework.web.reactive.DispatcherHandler; import org.springframework.web.reactive.config.EnableWebFlux; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.hamcrest.core.StringStartsWith.startsWith; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + /** * Tests for {@link org.springframework.security.config.web.server.ServerHttpSecurity.OAuth2ResourceServerSpec} */ @@ -105,7 +118,7 @@ public class OAuth2ResourceServerSpecTests { .headers(headers -> headers.setBearerAuth(this.expired)) .exchange() .expectStatus().isUnauthorized() - .expectHeader().exists(HttpHeaders.WWW_AUTHENTICATE); + .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"invalid_token\"")); } @Test @@ -116,7 +129,22 @@ public class OAuth2ResourceServerSpecTests { .headers(headers -> headers.setBearerAuth(this.unsignedToken)) .exchange() .expectStatus().isUnauthorized() - .expectHeader().exists(HttpHeaders.WWW_AUTHENTICATE); + .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"invalid_token\"")); + } + + @Test + public void getWhenCustomDecoderThenAuthenticatesAccordingly() { + this.spring.register(CustomDecoderConfig.class, RootController.class).autowire(); + + ReactiveJwtDecoder jwtDecoder = this.spring.getContext().getBean(ReactiveJwtDecoder.class); + when(jwtDecoder.decode(anyString())).thenReturn(Mono.just(this.jwt)); + + this.client.get() + .headers(headers -> headers.setBearerAuth("token")) + .exchange() + .expectStatus().isOk(); + + verify(jwtDecoder).decode(anyString()); } @Test @@ -132,6 +160,67 @@ public class OAuth2ResourceServerSpecTests { .expectStatus().isOk(); } + @Test + public void getJwtDecoderWhenBeanWiredAndDslWiredThenDslTakesPrecedence() { + GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext(); + ServerHttpSecurity http = new ServerHttpSecurity(); + http.setApplicationContext(context); + + ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class); + ReactiveJwtDecoder dslWiredJwtDecoder = mock(ReactiveJwtDecoder.class); + context.registerBean(ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder); + + ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt(); + jwt.jwtDecoder(dslWiredJwtDecoder); + + assertThat(jwt.getJwtDecoder()).isEqualTo(dslWiredJwtDecoder); + } + + @Test + public void getJwtDecoderWhenTwoBeansWiredAndDslWiredThenDslTakesPrecedence() { + GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext(); + ServerHttpSecurity http = new ServerHttpSecurity(); + http.setApplicationContext(context); + + ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class); + ReactiveJwtDecoder dslWiredJwtDecoder = mock(ReactiveJwtDecoder.class); + context.registerBean("firstJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder); + context.registerBean("secondJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder); + + ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt(); + jwt.jwtDecoder(dslWiredJwtDecoder); + + assertThat(jwt.getJwtDecoder()).isEqualTo(dslWiredJwtDecoder); + } + + @Test + public void getJwtDecoderWhenTwoBeansWiredThenThrowsWiringException() { + GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext(); + ServerHttpSecurity http = new ServerHttpSecurity(); + http.setApplicationContext(context); + + ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class); + context.registerBean("firstJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder); + context.registerBean("secondJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder); + + ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt(); + + assertThatCode(() -> jwt.getJwtDecoder()) + .isInstanceOf(NoUniqueBeanDefinitionException.class); + } + + @Test + public void getJwtDecoderWhenNoBeansAndNoDslWiredThenWiringException() { + GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext(); + ServerHttpSecurity http = new ServerHttpSecurity(); + http.setApplicationContext(context); + + ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt(); + + assertThatCode(() -> jwt.getJwtDecoder()) + .isInstanceOf(NoSuchBeanDefinitionException.class); + } + @EnableWebFlux @EnableWebFluxSecurity static class PublicKeyConfig { @@ -187,6 +276,28 @@ public class OAuth2ResourceServerSpecTests { } } + @EnableWebFlux + @EnableWebFluxSecurity + static class CustomDecoderConfig { + ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class); + + @Bean + SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { + // @formatter:off + http + .oauth2ResourceServer() + .jwt(); + // @formatter:on + + return http.build(); + } + + @Bean + ReactiveJwtDecoder jwtDecoder() { + return this.jwtDecoder; + } + } + @RestController static class RootController { @GetMapping @@ -194,4 +305,11 @@ public class OAuth2ResourceServerSpecTests { return Mono.just("ok"); } } + + private GenericWebApplicationContext autowireWebServerGenericWebApplicationContext() { + GenericWebApplicationContext context = new GenericWebApplicationContext(); + context.registerBean("webHandler", DispatcherHandler.class); + this.spring.context(context).autowire(); + return (GenericWebApplicationContext) this.spring.getContext(); + } }