ServerHttpSecurity ReactiveJwtDecoder discovery

This makes so that WebFlux OAuth 2.0 Resource Server configuration
will pick up a ReactiveJwtDecoder exposed as a bean.

Fixes: gh-5720
This commit is contained in:
Josh Cummings 2018-08-23 12:17:33 -06:00 committed by Rob Winch
parent 0fdc081ab5
commit cba2444e1a
2 changed files with 136 additions and 3 deletions

View File

@ -741,8 +741,9 @@ public class ServerHttpSecurity {
protected void configure(ServerHttpSecurity http) { protected void configure(ServerHttpSecurity http) {
BearerTokenServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint(); BearerTokenServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint();
ReactiveJwtDecoder jwtDecoder = this.getJwtDecoder();
JwtReactiveAuthenticationManager authenticationManager = new JwtReactiveAuthenticationManager( JwtReactiveAuthenticationManager authenticationManager = new JwtReactiveAuthenticationManager(
this.jwtDecoder); jwtDecoder);
AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager); AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager);
oauth2.setServerAuthenticationConverter(new ServerBearerTokenAuthenticationConverter()); oauth2.setServerAuthenticationConverter(new ServerBearerTokenAuthenticationConverter());
oauth2.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(entryPoint)); oauth2.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(entryPoint));
@ -752,6 +753,13 @@ public class ServerHttpSecurity {
.and() .and()
.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); .addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
} }
protected ReactiveJwtDecoder getJwtDecoder() {
if (this.jwtDecoder == null) {
return getBean(ReactiveJwtDecoder.class);
}
return this.jwtDecoder;
}
} }
public ServerHttpSecurity and() { public ServerHttpSecurity and() {
@ -2014,6 +2022,13 @@ public class ServerHttpSecurity {
private LogoutSpec() {} private LogoutSpec() {}
} }
private <T> T getBean(Class<T> beanClass) {
if (this.context == null) {
return null;
}
return this.context.getBean(beanClass);
}
private <T> T getBeanOrNull(Class<T> beanClass) { private <T> T getBeanOrNull(Class<T> beanClass) {
return getBeanOrNull(ResolvableType.forClass(beanClass)); return getBeanOrNull(ResolvableType.forClass(beanClass));
} }

View File

@ -34,6 +34,8 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
@ -41,13 +43,24 @@ import org.springframework.security.config.annotation.web.reactive.EnableWebFlux
import org.springframework.security.config.test.SpringTestRule; import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.context.support.GenericWebApplicationContext;
import org.springframework.web.reactive.DispatcherHandler;
import org.springframework.web.reactive.config.EnableWebFlux; import org.springframework.web.reactive.config.EnableWebFlux;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.hamcrest.core.StringStartsWith.startsWith;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/** /**
* Tests for {@link org.springframework.security.config.web.server.ServerHttpSecurity.OAuth2ResourceServerSpec} * Tests for {@link org.springframework.security.config.web.server.ServerHttpSecurity.OAuth2ResourceServerSpec}
*/ */
@ -105,7 +118,7 @@ public class OAuth2ResourceServerSpecTests {
.headers(headers -> headers.setBearerAuth(this.expired)) .headers(headers -> headers.setBearerAuth(this.expired))
.exchange() .exchange()
.expectStatus().isUnauthorized() .expectStatus().isUnauthorized()
.expectHeader().exists(HttpHeaders.WWW_AUTHENTICATE); .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"invalid_token\""));
} }
@Test @Test
@ -116,7 +129,22 @@ public class OAuth2ResourceServerSpecTests {
.headers(headers -> headers.setBearerAuth(this.unsignedToken)) .headers(headers -> headers.setBearerAuth(this.unsignedToken))
.exchange() .exchange()
.expectStatus().isUnauthorized() .expectStatus().isUnauthorized()
.expectHeader().exists(HttpHeaders.WWW_AUTHENTICATE); .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"invalid_token\""));
}
@Test
public void getWhenCustomDecoderThenAuthenticatesAccordingly() {
this.spring.register(CustomDecoderConfig.class, RootController.class).autowire();
ReactiveJwtDecoder jwtDecoder = this.spring.getContext().getBean(ReactiveJwtDecoder.class);
when(jwtDecoder.decode(anyString())).thenReturn(Mono.just(this.jwt));
this.client.get()
.headers(headers -> headers.setBearerAuth("token"))
.exchange()
.expectStatus().isOk();
verify(jwtDecoder).decode(anyString());
} }
@Test @Test
@ -132,6 +160,67 @@ public class OAuth2ResourceServerSpecTests {
.expectStatus().isOk(); .expectStatus().isOk();
} }
@Test
public void getJwtDecoderWhenBeanWiredAndDslWiredThenDslTakesPrecedence() {
GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext();
ServerHttpSecurity http = new ServerHttpSecurity();
http.setApplicationContext(context);
ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class);
ReactiveJwtDecoder dslWiredJwtDecoder = mock(ReactiveJwtDecoder.class);
context.registerBean(ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder);
ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt();
jwt.jwtDecoder(dslWiredJwtDecoder);
assertThat(jwt.getJwtDecoder()).isEqualTo(dslWiredJwtDecoder);
}
@Test
public void getJwtDecoderWhenTwoBeansWiredAndDslWiredThenDslTakesPrecedence() {
GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext();
ServerHttpSecurity http = new ServerHttpSecurity();
http.setApplicationContext(context);
ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class);
ReactiveJwtDecoder dslWiredJwtDecoder = mock(ReactiveJwtDecoder.class);
context.registerBean("firstJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder);
context.registerBean("secondJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder);
ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt();
jwt.jwtDecoder(dslWiredJwtDecoder);
assertThat(jwt.getJwtDecoder()).isEqualTo(dslWiredJwtDecoder);
}
@Test
public void getJwtDecoderWhenTwoBeansWiredThenThrowsWiringException() {
GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext();
ServerHttpSecurity http = new ServerHttpSecurity();
http.setApplicationContext(context);
ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class);
context.registerBean("firstJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder);
context.registerBean("secondJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder);
ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt();
assertThatCode(() -> jwt.getJwtDecoder())
.isInstanceOf(NoUniqueBeanDefinitionException.class);
}
@Test
public void getJwtDecoderWhenNoBeansAndNoDslWiredThenWiringException() {
GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext();
ServerHttpSecurity http = new ServerHttpSecurity();
http.setApplicationContext(context);
ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt();
assertThatCode(() -> jwt.getJwtDecoder())
.isInstanceOf(NoSuchBeanDefinitionException.class);
}
@EnableWebFlux @EnableWebFlux
@EnableWebFluxSecurity @EnableWebFluxSecurity
static class PublicKeyConfig { static class PublicKeyConfig {
@ -187,6 +276,28 @@ public class OAuth2ResourceServerSpecTests {
} }
} }
@EnableWebFlux
@EnableWebFluxSecurity
static class CustomDecoderConfig {
ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
@Bean
SecurityWebFilterChain springSecurity(ServerHttpSecurity http) {
// @formatter:off
http
.oauth2ResourceServer()
.jwt();
// @formatter:on
return http.build();
}
@Bean
ReactiveJwtDecoder jwtDecoder() {
return this.jwtDecoder;
}
}
@RestController @RestController
static class RootController { static class RootController {
@GetMapping @GetMapping
@ -194,4 +305,11 @@ public class OAuth2ResourceServerSpecTests {
return Mono.just("ok"); return Mono.just("ok");
} }
} }
private GenericWebApplicationContext autowireWebServerGenericWebApplicationContext() {
GenericWebApplicationContext context = new GenericWebApplicationContext();
context.registerBean("webHandler", DispatcherHandler.class);
this.spring.context(context).autowire();
return (GenericWebApplicationContext) this.spring.getContext();
}
} }