diff --git a/config/src/main/java/org/springframework/security/config/web/server/HttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/HttpSecurity.java index 2e06168e28..8acec59f90 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/HttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/HttpSecurity.java @@ -31,7 +31,6 @@ import org.springframework.security.web.server.HttpBasicAuthenticationConverter; import org.springframework.security.web.server.MatcherSecurityWebFilterChain; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.authentication.AuthenticationWebFilter; -import org.springframework.security.web.server.authentication.DefaultAuthenticationSuccessHandler; import org.springframework.security.web.server.authentication.www.HttpBasicAuthenticationEntryPoint; import org.springframework.security.web.server.authorization.AuthorizationContext; import org.springframework.security.web.server.authorization.AuthorizationWebFilter; @@ -40,6 +39,7 @@ import org.springframework.security.web.server.context.AuthenticationReactorCont import org.springframework.security.web.server.context.SecurityContextRepositoryWebFilter; import org.springframework.security.web.server.authorization.ExceptionTranslationWebFilter; import org.springframework.security.web.server.context.SecurityContextRepository; +import org.springframework.security.web.server.context.ServerWebExchangeAttributeSecurityContextRepository; import org.springframework.security.web.server.header.CacheControlHttpHeadersWriter; import org.springframework.security.web.server.header.CompositeHttpHeadersWriter; import org.springframework.security.web.server.header.ContentTypeOptionsHttpHeadersWriter; @@ -232,7 +232,7 @@ public class HttpSecurity { public class HttpBasicBuilder { private ReactiveAuthenticationManager authenticationManager; - private SecurityContextRepository securityContextRepository; + private SecurityContextRepository securityContextRepository = new ServerWebExchangeAttributeSecurityContextRepository(); private AuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint(); @@ -261,9 +261,7 @@ public class HttpSecurity { authenticationFilter.setEntryPoint(this.entryPoint); authenticationFilter.setAuthenticationConverter(new HttpBasicAuthenticationConverter()); if(this.securityContextRepository != null) { - DefaultAuthenticationSuccessHandler handler = new DefaultAuthenticationSuccessHandler(); - handler.setSecurityContextRepository(this.securityContextRepository); - authenticationFilter.setAuthenticationSuccessHandler(handler); + authenticationFilter.setSecurityContextRepository(this.securityContextRepository); } return authenticationFilter; } diff --git a/webflux/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java b/webflux/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java index 363278becd..b4733c3c4a 100644 --- a/webflux/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java +++ b/webflux/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java @@ -20,9 +20,12 @@ package org.springframework.security.web.server.authentication; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.web.server.AuthenticationEntryPoint; import org.springframework.security.web.server.HttpBasicAuthenticationConverter; import org.springframework.security.web.server.authentication.www.HttpBasicAuthenticationEntryPoint; +import org.springframework.security.web.server.context.SecurityContextRepository; +import org.springframework.security.web.server.context.ServerWebExchangeAttributeSecurityContextRepository; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; @@ -40,12 +43,14 @@ public class AuthenticationWebFilter implements WebFilter { private final ReactiveAuthenticationManager authenticationManager; - private AuthenticationSuccessHandler authenticationSuccessHandler = new DefaultAuthenticationSuccessHandler(); + private AuthenticationSuccessHandler authenticationSuccessHandler = new WebFilterChainAuthenticationSuccessHandler(); private Function> authenticationConverter = new HttpBasicAuthenticationConverter(); private AuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint(); + private SecurityContextRepository securityContextRepository = new ServerWebExchangeAttributeSecurityContextRepository(); + public AuthenticationWebFilter(ReactiveAuthenticationManager authenticationManager) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); this.authenticationManager = authenticationManager; @@ -56,11 +61,24 @@ public class AuthenticationWebFilter implements WebFilter { return this.authenticationConverter.apply(exchange) .switchIfEmpty(Mono.defer(() -> chain.filter(exchange).cast(Authentication.class))) .flatMap( token -> this.authenticationManager.authenticate(token) - .flatMap(authentication -> this.authenticationSuccessHandler.success(authentication, exchange, chain)) + .flatMap(authentication -> onAuthenticationSuccess(authentication, exchange, chain)) .onErrorResume( AuthenticationException.class, t -> this.entryPoint.commence(exchange, t)) ); } + private Mono onAuthenticationSuccess(Authentication authentication, ServerWebExchange exchange, WebFilterChain chain) { + SecurityContextImpl securityContext = new SecurityContextImpl(); + securityContext.setAuthentication(authentication); + return this.securityContextRepository.save(exchange, securityContext) + .flatMap( wrappedExchange -> this.authenticationSuccessHandler.success(authentication, wrappedExchange, chain)); + } + + public void setSecurityContextRepository( + SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) { this.authenticationSuccessHandler = authenticationSuccessHandler; } diff --git a/webflux/src/main/java/org/springframework/security/web/server/authentication/DefaultAuthenticationSuccessHandler.java b/webflux/src/main/java/org/springframework/security/web/server/authentication/DefaultAuthenticationSuccessHandler.java deleted file mode 100644 index 097e9bce5f..0000000000 --- a/webflux/src/main/java/org/springframework/security/web/server/authentication/DefaultAuthenticationSuccessHandler.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * - * * Copyright 2002-2017 the original author or authors. - * * - * * Licensed under the Apache License, Version 2.0 (the "License"); - * * you may not use this file except in compliance with the License. - * * You may obtain a copy of the License at - * * - * * http://www.apache.org/licenses/LICENSE-2.0 - * * - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, - * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * * See the License for the specific language governing permissions and - * * limitations under the License. - * - */ - -package org.springframework.security.web.server.authentication; - -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContextImpl; -import org.springframework.security.web.server.context.SecurityContextRepository; -import org.springframework.security.web.server.context.ServerWebExchangeAttributeSecurityContextRepository; -import org.springframework.util.Assert; -import org.springframework.web.server.ServerWebExchange; -import org.springframework.web.server.WebFilterChain; -import reactor.core.publisher.Mono; - -/** - * @author Rob Winch - * @since 5.0 - */ -public class DefaultAuthenticationSuccessHandler implements AuthenticationSuccessHandler { - private SecurityContextRepository securityContextRepository = new ServerWebExchangeAttributeSecurityContextRepository(); - - private AuthenticationSuccessHandler delegate = new WebFilterChainAuthenticationSuccessHandler(); - - @Override - public Mono success(Authentication authentication, ServerWebExchange exchange, WebFilterChain chain) { - SecurityContextImpl securityContext = new SecurityContextImpl(); - securityContext.setAuthentication(authentication); - return securityContextRepository.save(exchange, securityContext) - .flatMap( wrappedExchange -> delegate.success(authentication, wrappedExchange, chain)); - } - - public void setDelegate(AuthenticationSuccessHandler delegate) { - Assert.notNull(delegate, "delegate cannot be null"); - this.delegate = delegate; - } - - public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { - Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); - this.securityContextRepository = securityContextRepository; - } -} diff --git a/webflux/src/test/java/org/springframework/security/web/server/authentication/AuthenticationWebFilterTests.java b/webflux/src/test/java/org/springframework/security/web/server/authentication/AuthenticationWebFilterTests.java index afc4b98747..7093879543 100644 --- a/webflux/src/test/java/org/springframework/security/web/server/authentication/AuthenticationWebFilterTests.java +++ b/webflux/src/test/java/org/springframework/security/web/server/authentication/AuthenticationWebFilterTests.java @@ -18,27 +18,30 @@ package org.springframework.security.web.server.authentication; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.server.AuthenticationEntryPoint; +import org.springframework.security.web.server.context.SecurityContextRepository; import org.springframework.test.web.reactive.server.EntityExchangeResult; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; - -import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; @@ -59,6 +62,8 @@ public class AuthenticationWebFilterTests { private ReactiveAuthenticationManager authenticationManager; @Mock private AuthenticationEntryPoint entryPoint; + @Mock + private SecurityContextRepository securityContextRepository; private AuthenticationWebFilter filter; @@ -68,6 +73,7 @@ public class AuthenticationWebFilterTests { this.filter.setAuthenticationSuccessHandler(this.successHandler); this.filter.setAuthenticationConverter(this.authenticationConverter); this.filter.setEntryPoint(this.entryPoint); + this.filter.setSecurityContextRepository(this.securityContextRepository); } @Test @@ -151,6 +157,7 @@ public class AuthenticationWebFilterTests { .expectBody(String.class).consumeWith(b -> assertThat(b.getResponseBody()).isEqualTo("ok")) .returnResult(); + verify(this.securityContextRepository, never()).save(any(), any()); verifyZeroInteractions(this.authenticationManager, this.successHandler, this.entryPoint); } @@ -170,16 +177,18 @@ public class AuthenticationWebFilterTests { .expectStatus().is5xxServerError() .expectBody().isEmpty(); + verify(this.securityContextRepository, never()).save(any(), any()); verifyZeroInteractions(this.authenticationManager, this.successHandler, this.entryPoint); } @Test - public void filterWhenConvertAndAuthenticationSuccessThenSuccessHandler() { + public void filterWhenConvertAndAuthenticationSuccessThenSuccess() { Mono authentication = Mono.just(new TestingAuthenticationToken("test", "this", "ROLE_USER")); when(this.authenticationConverter.apply(any())).thenReturn(authentication); when(this.authenticationManager.authenticate(any())).thenReturn(authentication); when(this.successHandler.success(any(),any(),any())).thenReturn(Mono.empty()); + when(this.securityContextRepository.save(any(),any())).thenAnswer( a -> Mono.just(a.getArguments()[0])); WebTestClient client = WebTestClientBuilder .bindToWebFilters(this.filter) @@ -193,6 +202,7 @@ public class AuthenticationWebFilterTests { .expectBody().isEmpty(); verify(this.successHandler).success(eq(authentication.block()), any(), any()); + verify(this.securityContextRepository).save(any(), any()); verifyZeroInteractions(this.entryPoint); } @@ -215,6 +225,7 @@ public class AuthenticationWebFilterTests { .expectBody().isEmpty(); verify(this.entryPoint).commence(any(),any()); + verify(this.securityContextRepository, never()).save(any(), any()); verifyZeroInteractions(this.successHandler); } @@ -236,6 +247,7 @@ public class AuthenticationWebFilterTests { .expectStatus().is5xxServerError() .expectBody().isEmpty(); + verify(this.securityContextRepository, never()).save(any(), any()); verifyZeroInteractions(this.successHandler, this.entryPoint); } }