AuthenticationEntryPoint & AccessDeniedHandler use Mono<Void>

This commit is contained in:
Rob Winch 2017-08-30 16:32:18 -05:00
parent 475f18174d
commit 8f5069053e
8 changed files with 400 additions and 24 deletions

View File

@ -17,12 +17,11 @@
*/ */
package org.springframework.security.web.server; package org.springframework.security.web.server;
import reactor.core.publisher.Mono;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
/** /**
* *
* @author Rob Winch * @author Rob Winch
@ -30,5 +29,5 @@ import reactor.core.publisher.Mono;
*/ */
public interface AuthenticationEntryPoint { public interface AuthenticationEntryPoint {
<T> Mono<T> commence(ServerWebExchange exchange, AuthenticationException e); Mono<Void> commence(ServerWebExchange exchange, AuthenticationException e);
} }

View File

@ -21,6 +21,7 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.server.AuthenticationEntryPoint; import org.springframework.security.web.server.AuthenticationEntryPoint;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -31,12 +32,31 @@ import reactor.core.publisher.Mono;
* @since 5.0 * @since 5.0
*/ */
public class HttpBasicAuthenticationEntryPoint implements AuthenticationEntryPoint { public class HttpBasicAuthenticationEntryPoint implements AuthenticationEntryPoint {
private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
private static final String DEFAULT_REALM = "Realm";
private static String WWW_AUTHENTICATE_FORMAT = "Basic realm=\"%s\"";
private String headerValue = createHeaderValue(DEFAULT_REALM);
@Override @Override
public <T> Mono<T> commence(ServerWebExchange exchange, AuthenticationException e) { public Mono<Void> commence(ServerWebExchange exchange, AuthenticationException e) {
return Mono.fromRunnable(() -> {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
response.setStatusCode(HttpStatus.UNAUTHORIZED); response.setStatusCode(HttpStatus.UNAUTHORIZED);
response.getHeaders().set("WWW-Authenticate", "Basic realm=\"Realm\""); response.getHeaders().set(WWW_AUTHENTICATE, this.headerValue);
return Mono.empty(); });
}
/**
* Sets the realm to be used
* @param realm the realm. Default is "Realm"
*/
public void setRealm(String realm) {
this.headerValue = createHeaderValue(realm);
}
private static String createHeaderValue(String realm) {
Assert.notNull(realm, "realm cannot be null");
return String.format(WWW_AUTHENTICATE_FORMAT, realm);
} }
} }

View File

@ -29,5 +29,5 @@ import reactor.core.publisher.Mono;
*/ */
public interface AccessDeniedHandler { public interface AccessDeniedHandler {
<T> Mono<T> handle(ServerWebExchange exchange, AccessDeniedException denied); Mono<Void> handle(ServerWebExchange exchange, AccessDeniedException denied);
} }

View File

@ -17,16 +17,17 @@
*/ */
package org.springframework.security.web.server.authorization; package org.springframework.security.web.server.authorization;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
import org.springframework.security.web.server.AuthenticationEntryPoint; import org.springframework.security.web.server.AuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.www.HttpBasicAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.www.HttpBasicAuthenticationEntryPoint;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
/** /**
* *
@ -34,18 +35,43 @@ import reactor.core.publisher.Mono;
* @since 5.0 * @since 5.0
*/ */
public class ExceptionTranslationWebFilter implements WebFilter { public class ExceptionTranslationWebFilter implements WebFilter {
private AuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint(); private AuthenticationEntryPoint authenticationEntryPoint = new HttpBasicAuthenticationEntryPoint();
private AccessDeniedHandler accessDeniedHandler = new HttpStatusAccessDeniedHandler(HttpStatus.FORBIDDEN); private AccessDeniedHandler accessDeniedHandler = new HttpStatusAccessDeniedHandler(HttpStatus.FORBIDDEN);
@Override @Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return chain.filter(exchange) return chain.filter(exchange)
.onErrorResume(AccessDeniedException.class, denied -> { .onErrorResume(AccessDeniedException.class, denied -> exchange.getPrincipal()
return exchange.getPrincipal() .switchIfEmpty( commenceAuthentication(exchange, denied))
.switchIfEmpty( Mono.defer( () -> entryPoint.commence(exchange, new AuthenticationCredentialsNotFoundException("Not Authenticated", denied)))) .flatMap( principal -> this.accessDeniedHandler.handle(exchange, denied))
.flatMap( principal -> accessDeniedHandler.handle(exchange, denied)); );
});
} }
/**
* Sets the access denied handler.
* @param accessDeniedHandler the access denied handler to use. Default is
* HttpStatusAccessDeniedHandler with HttpStatus.FORBIDDEN
*/
public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) {
Assert.notNull(accessDeniedHandler, "accessDeniedHandler cannot be null");
this.accessDeniedHandler = accessDeniedHandler;
}
/**
* Sets the authentication entry point used when authentication is required
* @param authenticationEntryPoint the authentication entry point to use. Default is
* {@link HttpBasicAuthenticationEntryPoint}
*/
public void setAuthenticationEntryPoint(
AuthenticationEntryPoint authenticationEntryPoint) {
Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint cannot be null");
this.authenticationEntryPoint = authenticationEntryPoint;
}
private <T> Mono<T> commenceAuthentication(ServerWebExchange exchange, AccessDeniedException denied) {
return this.authenticationEntryPoint.commence(exchange, new AuthenticationCredentialsNotFoundException("Not Authenticated", denied))
.then(Mono.empty());
}
} }

View File

@ -18,13 +18,15 @@
package org.springframework.security.web.server.authorization; package org.springframework.security.web.server.authorization;
import org.springframework.http.HttpStatus;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.http.HttpStatus;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
/** /**
* Sets an HTTP Status that is provided when
* @author Rob Winch * @author Rob Winch
* @since 5.0 * @since 5.0
*/ */
@ -32,12 +34,12 @@ public class HttpStatusAccessDeniedHandler implements AccessDeniedHandler {
private final HttpStatus httpStatus; private final HttpStatus httpStatus;
public HttpStatusAccessDeniedHandler(HttpStatus httpStatus) { public HttpStatusAccessDeniedHandler(HttpStatus httpStatus) {
Assert.notNull(httpStatus, "httpStatus cannot be null");
this.httpStatus = httpStatus; this.httpStatus = httpStatus;
} }
@Override @Override
public <T> Mono<T> handle(ServerWebExchange exchange, AccessDeniedException e) { public Mono<Void> handle(ServerWebExchange exchange, AccessDeniedException e) {
exchange.getResponse().setStatusCode(HttpStatus.FORBIDDEN); return Mono.fromRunnable(() -> exchange.getResponse().setStatusCode(HttpStatus.FORBIDDEN));
return Mono.empty();
} }
} }

View File

@ -0,0 +1,84 @@
/*
*
* * Copyright 2002-2017 the original author or authors.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.springframework.security.web.server.authentication.www;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.verifyZeroInteractions;
/**
* @author Rob Winch
* @since 5.0
*/
@RunWith(MockitoJUnitRunner.class)
public class HttpBasicAuthenticationEntryPointTests {
@Mock
private ServerWebExchange exchange;
private HttpBasicAuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint();
private AuthenticationException exception = new AuthenticationCredentialsNotFoundException("Authenticate");
@Test
public void commenceWhenNoSubscribersThenNoActions() {
this.entryPoint.commence(this.exchange,
this.exception);
verifyZeroInteractions(this.exchange);
}
@Test
public void commenceWhenSubscribeThenStatusAndHeaderSet() {
this.exchange = MockServerHttpRequest.get("/").toExchange();
this.entryPoint.commence(this.exchange, this.exception).block();
assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(
HttpStatus.UNAUTHORIZED);
assertThat(this.exchange.getResponse().getHeaders().get("WWW-Authenticate")).containsOnly(
"Basic realm=\"Realm\"");
}
@Test
public void commenceWhenCustomRealmThenStatusAndHeaderSet() {
this.entryPoint.setRealm("Custom");
this.exchange = MockServerHttpRequest.get("/").toExchange();
this.entryPoint.commence(this.exchange, this.exception).block();
assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(
HttpStatus.UNAUTHORIZED);
assertThat(this.exchange.getResponse().getHeaders().get("WWW-Authenticate")).containsOnly(
"Basic realm=\"Custom\"");
}
@Test(expected = IllegalArgumentException.class)
public void setRealmWhenNullThenException() {
this.entryPoint.setRealm(null);
}
}

View File

@ -0,0 +1,178 @@
/*
*
* * Copyright 2002-2017 the original author or authors.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.springframework.security.web.server.authorization;
import java.security.Principal;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.web.server.AuthenticationEntryPoint;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.when;
/**
* @author Rob Winch
* @since 5.0
*/
@RunWith(MockitoJUnitRunner.class)
public class ExceptionTranslationWebFilterTests {
@Mock
private Principal principal;
@Mock
private ServerWebExchange exchange;
@Mock
private WebFilterChain chain;
@Mock
private AccessDeniedHandler deniedHandler;
@Mock
private AuthenticationEntryPoint entryPoint;
private TestMono<Void> deniedMono = TestMono.create();
private TestMono<Void> entryPointMono = TestMono.create();
private ExceptionTranslationWebFilter filter = new ExceptionTranslationWebFilter();
@Before
public void setup() {
when(this.exchange.getResponse()).thenReturn(new MockServerHttpResponse());
when(this.deniedHandler.handle(any(), any())).thenReturn(this.deniedMono.mono());
when(this.entryPoint.commence(any(), any())).thenReturn(this.entryPointMono.mono());
this.filter.setAuthenticationEntryPoint(this.entryPoint);
this.filter.setAccessDeniedHandler(this.deniedHandler);
}
@Test
public void filterWhenNoExceptionThenNotHandled() {
when(this.chain.filter(this.exchange)).thenReturn(Mono.empty());
StepVerifier.create(this.filter.filter(this.exchange, this.chain))
.expectComplete()
.verify();
assertThat(this.deniedMono.isInvoked()).isFalse();
assertThat(this.entryPointMono.isInvoked()).isFalse();
}
@Test
public void filterWhenNotAccessDeniedExceptionThenNotHandled() {
when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new IllegalArgumentException("oops")));
StepVerifier.create(this.filter.filter(this.exchange, this.chain))
.expectError(IllegalArgumentException.class)
.verify();
assertThat(this.deniedMono.isInvoked()).isFalse();
assertThat(this.entryPointMono.isInvoked()).isFalse();
}
@Test
public void filterWhenAccessDeniedExceptionAndNotAuthenticatedThenHandled() {
when(this.exchange.getPrincipal()).thenReturn(Mono.empty());
when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized")));
StepVerifier.create(this.filter.filter(this.exchange, this.chain))
.expectComplete()
.verify();
assertThat(this.deniedMono.isInvoked()).isFalse();
assertThat(this.entryPointMono.isInvoked()).isTrue();
}
@Test
public void filterWhenDefaultsAndAccessDeniedExceptionAndAuthenticatedThenForbidden() {
this.filter = new ExceptionTranslationWebFilter();
when(this.exchange.getPrincipal()).thenReturn(Mono.just(this.principal));
when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized")));
StepVerifier.create(this.filter.filter(this.exchange, this.chain))
.expectComplete()
.verify();
assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(
HttpStatus.FORBIDDEN);
}
@Test
public void filterWhenDefaultsAndAccessDeniedExceptionAndNotAuthenticatedThenUnauthorized() {
this.filter = new ExceptionTranslationWebFilter();
when(this.exchange.getPrincipal()).thenReturn(Mono.empty());
when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized")));
StepVerifier.create(this.filter.filter(this.exchange, this.chain))
.expectComplete()
.verify();
assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(
HttpStatus.UNAUTHORIZED);
}
@Test
public void filterWhenAccessDeniedExceptionAndAuthenticatedThenHandled() {
when(this.exchange.getPrincipal()).thenReturn(Mono.just(this.principal));
when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized")));
StepVerifier.create(this.filter.filter(this.exchange, this.chain))
.expectComplete()
.verify();
assertThat(this.deniedMono.isInvoked()).isTrue();
assertThat(this.entryPointMono.isInvoked()).isFalse();
}
@Test(expected = IllegalArgumentException.class)
public void setAccessDeniedHandlerWhenNullThenException() {
this.filter.setAccessDeniedHandler(null);
}
@Test(expected = IllegalArgumentException.class)
public void setAuthenticationEntryPointWhenNullThenException() {
this.filter.setAuthenticationEntryPoint(null);
}
static class TestMono<T> {
private final AtomicBoolean invoked = new AtomicBoolean();
public Mono<T> mono() {
return Mono.<T>empty().doOnSubscribe(s -> this.invoked.set(true));
}
public boolean isInvoked() {
return this.invoked.get();
}
public static <T> TestMono<T> create() {
return new TestMono<T>();
}
}
}

View File

@ -0,0 +1,67 @@
/*
*
* * Copyright 2002-2017 the original author or authors.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.springframework.security.web.server.authorization;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.verifyZeroInteractions;
/**
* @author Rob Winch
* @since 5.0
*/
@RunWith(MockitoJUnitRunner.class)
public class HttpStatusAccessDeniedHandlerTests {
@Mock
private ServerWebExchange exchange;
private final HttpStatus httpStatus = HttpStatus.FORBIDDEN;
private HttpStatusAccessDeniedHandler handler = new HttpStatusAccessDeniedHandler(this.httpStatus);
private AccessDeniedException exception = new AccessDeniedException("Forbidden");
@Test(expected = IllegalArgumentException.class)
public void constructorHttpStatusWhenNullThenException() {
new HttpStatusAccessDeniedHandler((HttpStatus) null);
}
@Test
public void commenceWhenNoSubscribersThenNoActions() {
this.handler.handle(this.exchange, this.exception);
verifyZeroInteractions(this.exchange);
}
@Test
public void commenceWhenSubscribeThenStatusSet() {
this.exchange = MockServerHttpRequest.get("/").toExchange();
this.handler.handle(this.exchange, this.exception).block();
assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(this.httpStatus);
}
}