diff --git a/webflux/src/main/java/org/springframework/security/web/server/AuthenticationEntryPoint.java b/webflux/src/main/java/org/springframework/security/web/server/AuthenticationEntryPoint.java index de4305aaa5..bfb360d4a0 100644 --- a/webflux/src/main/java/org/springframework/security/web/server/AuthenticationEntryPoint.java +++ b/webflux/src/main/java/org/springframework/security/web/server/AuthenticationEntryPoint.java @@ -17,12 +17,11 @@ */ package org.springframework.security.web.server; +import reactor.core.publisher.Mono; + import org.springframework.security.core.AuthenticationException; import org.springframework.web.server.ServerWebExchange; - -import reactor.core.publisher.Mono; - /** * * @author Rob Winch @@ -30,5 +29,5 @@ import reactor.core.publisher.Mono; */ public interface AuthenticationEntryPoint { - Mono commence(ServerWebExchange exchange, AuthenticationException e); + Mono commence(ServerWebExchange exchange, AuthenticationException e); } diff --git a/webflux/src/main/java/org/springframework/security/web/server/authentication/www/HttpBasicAuthenticationEntryPoint.java b/webflux/src/main/java/org/springframework/security/web/server/authentication/www/HttpBasicAuthenticationEntryPoint.java index 040833e8bd..e5ce01a4c8 100644 --- a/webflux/src/main/java/org/springframework/security/web/server/authentication/www/HttpBasicAuthenticationEntryPoint.java +++ b/webflux/src/main/java/org/springframework/security/web/server/authentication/www/HttpBasicAuthenticationEntryPoint.java @@ -21,6 +21,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.security.core.AuthenticationException; import org.springframework.security.web.server.AuthenticationEntryPoint; +import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; @@ -31,12 +32,31 @@ import reactor.core.publisher.Mono; * @since 5.0 */ public class HttpBasicAuthenticationEntryPoint implements AuthenticationEntryPoint { + private static final String WWW_AUTHENTICATE = "WWW-Authenticate"; + private static final String DEFAULT_REALM = "Realm"; + private static String WWW_AUTHENTICATE_FORMAT = "Basic realm=\"%s\""; + + private String headerValue = createHeaderValue(DEFAULT_REALM); @Override - public Mono commence(ServerWebExchange exchange, AuthenticationException e) { - ServerHttpResponse response = exchange.getResponse(); - response.setStatusCode(HttpStatus.UNAUTHORIZED); - response.getHeaders().set("WWW-Authenticate", "Basic realm=\"Realm\""); - return Mono.empty(); + public Mono commence(ServerWebExchange exchange, AuthenticationException e) { + return Mono.fromRunnable(() -> { + ServerHttpResponse response = exchange.getResponse(); + response.setStatusCode(HttpStatus.UNAUTHORIZED); + response.getHeaders().set(WWW_AUTHENTICATE, this.headerValue); + }); + } + + /** + * Sets the realm to be used + * @param realm the realm. Default is "Realm" + */ + public void setRealm(String realm) { + this.headerValue = createHeaderValue(realm); + } + + private static String createHeaderValue(String realm) { + Assert.notNull(realm, "realm cannot be null"); + return String.format(WWW_AUTHENTICATE_FORMAT, realm); } } diff --git a/webflux/src/main/java/org/springframework/security/web/server/authorization/AccessDeniedHandler.java b/webflux/src/main/java/org/springframework/security/web/server/authorization/AccessDeniedHandler.java index d06d63746f..5116f75aec 100644 --- a/webflux/src/main/java/org/springframework/security/web/server/authorization/AccessDeniedHandler.java +++ b/webflux/src/main/java/org/springframework/security/web/server/authorization/AccessDeniedHandler.java @@ -29,5 +29,5 @@ import reactor.core.publisher.Mono; */ public interface AccessDeniedHandler { - Mono handle(ServerWebExchange exchange, AccessDeniedException denied); + Mono handle(ServerWebExchange exchange, AccessDeniedException denied); } diff --git a/webflux/src/main/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilter.java b/webflux/src/main/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilter.java index cf400875b7..4d2de04220 100644 --- a/webflux/src/main/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilter.java +++ b/webflux/src/main/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilter.java @@ -17,16 +17,17 @@ */ package org.springframework.security.web.server.authorization; +import reactor.core.publisher.Mono; import org.springframework.http.HttpStatus; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.web.server.AuthenticationEntryPoint; import org.springframework.security.web.server.authentication.www.HttpBasicAuthenticationEntryPoint; +import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; -import reactor.core.publisher.Mono; /** * @@ -34,18 +35,43 @@ import reactor.core.publisher.Mono; * @since 5.0 */ public class ExceptionTranslationWebFilter implements WebFilter { - private AuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint(); + private AuthenticationEntryPoint authenticationEntryPoint = new HttpBasicAuthenticationEntryPoint(); private AccessDeniedHandler accessDeniedHandler = new HttpStatusAccessDeniedHandler(HttpStatus.FORBIDDEN); @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { return chain.filter(exchange) - .onErrorResume(AccessDeniedException.class, denied -> { - return exchange.getPrincipal() - .switchIfEmpty( Mono.defer( () -> entryPoint.commence(exchange, new AuthenticationCredentialsNotFoundException("Not Authenticated", denied)))) - .flatMap( principal -> accessDeniedHandler.handle(exchange, denied)); - }); + .onErrorResume(AccessDeniedException.class, denied -> exchange.getPrincipal() + .switchIfEmpty( commenceAuthentication(exchange, denied)) + .flatMap( principal -> this.accessDeniedHandler.handle(exchange, denied)) + ); } + /** + * Sets the access denied handler. + * @param accessDeniedHandler the access denied handler to use. Default is + * HttpStatusAccessDeniedHandler with HttpStatus.FORBIDDEN + */ + public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) { + Assert.notNull(accessDeniedHandler, "accessDeniedHandler cannot be null"); + this.accessDeniedHandler = accessDeniedHandler; + } + + /** + * Sets the authentication entry point used when authentication is required + * @param authenticationEntryPoint the authentication entry point to use. Default is + * {@link HttpBasicAuthenticationEntryPoint} + */ + public void setAuthenticationEntryPoint( + AuthenticationEntryPoint authenticationEntryPoint) { + Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint cannot be null"); + this.authenticationEntryPoint = authenticationEntryPoint; + } + + private Mono commenceAuthentication(ServerWebExchange exchange, AccessDeniedException denied) { + return this.authenticationEntryPoint.commence(exchange, new AuthenticationCredentialsNotFoundException("Not Authenticated", denied)) + .then(Mono.empty()); + } } + diff --git a/webflux/src/main/java/org/springframework/security/web/server/authorization/HttpStatusAccessDeniedHandler.java b/webflux/src/main/java/org/springframework/security/web/server/authorization/HttpStatusAccessDeniedHandler.java index 2996952eaa..0d6f802d3b 100644 --- a/webflux/src/main/java/org/springframework/security/web/server/authorization/HttpStatusAccessDeniedHandler.java +++ b/webflux/src/main/java/org/springframework/security/web/server/authorization/HttpStatusAccessDeniedHandler.java @@ -18,13 +18,15 @@ package org.springframework.security.web.server.authorization; -import org.springframework.http.HttpStatus; -import org.springframework.security.access.AccessDeniedException; -import org.springframework.web.server.ServerWebExchange; -import org.springframework.web.server.WebFilter; import reactor.core.publisher.Mono; +import org.springframework.http.HttpStatus; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; + /** + * Sets an HTTP Status that is provided when * @author Rob Winch * @since 5.0 */ @@ -32,12 +34,12 @@ public class HttpStatusAccessDeniedHandler implements AccessDeniedHandler { private final HttpStatus httpStatus; public HttpStatusAccessDeniedHandler(HttpStatus httpStatus) { + Assert.notNull(httpStatus, "httpStatus cannot be null"); this.httpStatus = httpStatus; } @Override - public Mono handle(ServerWebExchange exchange, AccessDeniedException e) { - exchange.getResponse().setStatusCode(HttpStatus.FORBIDDEN); - return Mono.empty(); + public Mono handle(ServerWebExchange exchange, AccessDeniedException e) { + return Mono.fromRunnable(() -> exchange.getResponse().setStatusCode(HttpStatus.FORBIDDEN)); } } diff --git a/webflux/src/test/java/org/springframework/security/web/server/authentication/www/HttpBasicAuthenticationEntryPointTests.java b/webflux/src/test/java/org/springframework/security/web/server/authentication/www/HttpBasicAuthenticationEntryPointTests.java new file mode 100644 index 0000000000..a15f30b9a9 --- /dev/null +++ b/webflux/src/test/java/org/springframework/security/web/server/authentication/www/HttpBasicAuthenticationEntryPointTests.java @@ -0,0 +1,84 @@ +/* + * + * * Copyright 2002-2017 the original author or authors. + * * + * * Licensed under the Apache License, Version 2.0 (the "License"); + * * you may not use this file except in compliance with the License. + * * You may obtain a copy of the License at + * * + * * http://www.apache.org/licenses/LICENSE-2.0 + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, + * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * * See the License for the specific language governing permissions and + * * limitations under the License. + * + */ + +package org.springframework.security.web.server.authentication.www; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +import org.springframework.http.HttpStatus; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; +import org.springframework.security.core.AuthenticationException; +import org.springframework.web.server.ServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.verifyZeroInteractions; + +/** + * @author Rob Winch + * @since 5.0 + */ +@RunWith(MockitoJUnitRunner.class) +public class HttpBasicAuthenticationEntryPointTests { + @Mock + private ServerWebExchange exchange; + private HttpBasicAuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint(); + + private AuthenticationException exception = new AuthenticationCredentialsNotFoundException("Authenticate"); + + @Test + public void commenceWhenNoSubscribersThenNoActions() { + this.entryPoint.commence(this.exchange, + this.exception); + + verifyZeroInteractions(this.exchange); + } + + @Test + public void commenceWhenSubscribeThenStatusAndHeaderSet() { + this.exchange = MockServerHttpRequest.get("/").toExchange(); + + this.entryPoint.commence(this.exchange, this.exception).block(); + + assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo( + HttpStatus.UNAUTHORIZED); + assertThat(this.exchange.getResponse().getHeaders().get("WWW-Authenticate")).containsOnly( + "Basic realm=\"Realm\""); + } + + @Test + public void commenceWhenCustomRealmThenStatusAndHeaderSet() { + this.entryPoint.setRealm("Custom"); + this.exchange = MockServerHttpRequest.get("/").toExchange(); + + this.entryPoint.commence(this.exchange, this.exception).block(); + + assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo( + HttpStatus.UNAUTHORIZED); + assertThat(this.exchange.getResponse().getHeaders().get("WWW-Authenticate")).containsOnly( + "Basic realm=\"Custom\""); + } + + @Test(expected = IllegalArgumentException.class) + public void setRealmWhenNullThenException() { + this.entryPoint.setRealm(null); + } +} diff --git a/webflux/src/test/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilterTests.java b/webflux/src/test/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilterTests.java new file mode 100644 index 0000000000..3f8cc3baf2 --- /dev/null +++ b/webflux/src/test/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilterTests.java @@ -0,0 +1,178 @@ +/* + * + * * Copyright 2002-2017 the original author or authors. + * * + * * Licensed under the Apache License, Version 2.0 (the "License"); + * * you may not use this file except in compliance with the License. + * * You may obtain a copy of the License at + * * + * * http://www.apache.org/licenses/LICENSE-2.0 + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, + * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * * See the License for the specific language governing permissions and + * * limitations under the License. + * + */ + +package org.springframework.security.web.server.authorization; + +import java.security.Principal; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.http.HttpStatus; +import org.springframework.mock.http.server.reactive.MockServerHttpResponse; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.web.server.AuthenticationEntryPoint; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilterChain; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.when; + +/** + * @author Rob Winch + * @since 5.0 + */ +@RunWith(MockitoJUnitRunner.class) +public class ExceptionTranslationWebFilterTests { + @Mock + private Principal principal; + @Mock + private ServerWebExchange exchange; + @Mock + private WebFilterChain chain; + @Mock + private AccessDeniedHandler deniedHandler; + @Mock + private AuthenticationEntryPoint entryPoint; + + private TestMono deniedMono = TestMono.create(); + private TestMono entryPointMono = TestMono.create(); + + private ExceptionTranslationWebFilter filter = new ExceptionTranslationWebFilter(); + + @Before + public void setup() { + when(this.exchange.getResponse()).thenReturn(new MockServerHttpResponse()); + when(this.deniedHandler.handle(any(), any())).thenReturn(this.deniedMono.mono()); + when(this.entryPoint.commence(any(), any())).thenReturn(this.entryPointMono.mono()); + + this.filter.setAuthenticationEntryPoint(this.entryPoint); + this.filter.setAccessDeniedHandler(this.deniedHandler); + } + + @Test + public void filterWhenNoExceptionThenNotHandled() { + when(this.chain.filter(this.exchange)).thenReturn(Mono.empty()); + + StepVerifier.create(this.filter.filter(this.exchange, this.chain)) + .expectComplete() + .verify(); + + assertThat(this.deniedMono.isInvoked()).isFalse(); + assertThat(this.entryPointMono.isInvoked()).isFalse(); + } + + @Test + public void filterWhenNotAccessDeniedExceptionThenNotHandled() { + when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new IllegalArgumentException("oops"))); + + StepVerifier.create(this.filter.filter(this.exchange, this.chain)) + .expectError(IllegalArgumentException.class) + .verify(); + + assertThat(this.deniedMono.isInvoked()).isFalse(); + assertThat(this.entryPointMono.isInvoked()).isFalse(); + } + + @Test + public void filterWhenAccessDeniedExceptionAndNotAuthenticatedThenHandled() { + when(this.exchange.getPrincipal()).thenReturn(Mono.empty()); + when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized"))); + + StepVerifier.create(this.filter.filter(this.exchange, this.chain)) + .expectComplete() + .verify(); + + assertThat(this.deniedMono.isInvoked()).isFalse(); + assertThat(this.entryPointMono.isInvoked()).isTrue(); + } + + @Test + public void filterWhenDefaultsAndAccessDeniedExceptionAndAuthenticatedThenForbidden() { + this.filter = new ExceptionTranslationWebFilter(); + when(this.exchange.getPrincipal()).thenReturn(Mono.just(this.principal)); + when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized"))); + + StepVerifier.create(this.filter.filter(this.exchange, this.chain)) + .expectComplete() + .verify(); + + assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo( + HttpStatus.FORBIDDEN); + } + + @Test + public void filterWhenDefaultsAndAccessDeniedExceptionAndNotAuthenticatedThenUnauthorized() { + this.filter = new ExceptionTranslationWebFilter(); + when(this.exchange.getPrincipal()).thenReturn(Mono.empty()); + when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized"))); + + StepVerifier.create(this.filter.filter(this.exchange, this.chain)) + .expectComplete() + .verify(); + + assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo( + HttpStatus.UNAUTHORIZED); + } + + @Test + public void filterWhenAccessDeniedExceptionAndAuthenticatedThenHandled() { + when(this.exchange.getPrincipal()).thenReturn(Mono.just(this.principal)); + when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized"))); + + StepVerifier.create(this.filter.filter(this.exchange, this.chain)) + .expectComplete() + .verify(); + + assertThat(this.deniedMono.isInvoked()).isTrue(); + assertThat(this.entryPointMono.isInvoked()).isFalse(); + } + + @Test(expected = IllegalArgumentException.class) + public void setAccessDeniedHandlerWhenNullThenException() { + this.filter.setAccessDeniedHandler(null); + } + + @Test(expected = IllegalArgumentException.class) + public void setAuthenticationEntryPointWhenNullThenException() { + this.filter.setAuthenticationEntryPoint(null); + } + + static class TestMono { + private final AtomicBoolean invoked = new AtomicBoolean(); + + public Mono mono() { + return Mono.empty().doOnSubscribe(s -> this.invoked.set(true)); + } + + public boolean isInvoked() { + return this.invoked.get(); + } + + public static TestMono create() { + return new TestMono(); + } + } +} diff --git a/webflux/src/test/java/org/springframework/security/web/server/authorization/HttpStatusAccessDeniedHandlerTests.java b/webflux/src/test/java/org/springframework/security/web/server/authorization/HttpStatusAccessDeniedHandlerTests.java new file mode 100644 index 0000000000..3edf1d1e0e --- /dev/null +++ b/webflux/src/test/java/org/springframework/security/web/server/authorization/HttpStatusAccessDeniedHandlerTests.java @@ -0,0 +1,67 @@ +/* + * + * * Copyright 2002-2017 the original author or authors. + * * + * * Licensed under the Apache License, Version 2.0 (the "License"); + * * you may not use this file except in compliance with the License. + * * You may obtain a copy of the License at + * * + * * http://www.apache.org/licenses/LICENSE-2.0 + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, + * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * * See the License for the specific language governing permissions and + * * limitations under the License. + * + */ + +package org.springframework.security.web.server.authorization; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; + +import org.springframework.http.HttpStatus; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.web.server.ServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.verifyZeroInteractions; + +/** + * @author Rob Winch + * @since 5.0 + */ +@RunWith(MockitoJUnitRunner.class) +public class HttpStatusAccessDeniedHandlerTests { + @Mock + private ServerWebExchange exchange; + private final HttpStatus httpStatus = HttpStatus.FORBIDDEN; + private HttpStatusAccessDeniedHandler handler = new HttpStatusAccessDeniedHandler(this.httpStatus); + + private AccessDeniedException exception = new AccessDeniedException("Forbidden"); + + @Test(expected = IllegalArgumentException.class) + public void constructorHttpStatusWhenNullThenException() { + new HttpStatusAccessDeniedHandler((HttpStatus) null); + } + + @Test + public void commenceWhenNoSubscribersThenNoActions() { + this.handler.handle(this.exchange, this.exception); + + verifyZeroInteractions(this.exchange); + } + + @Test + public void commenceWhenSubscribeThenStatusSet() { + this.exchange = MockServerHttpRequest.get("/").toExchange(); + + this.handler.handle(this.exchange, this.exception).block(); + + assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(this.httpStatus); + } +}