diff --git a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java index ffa9701ad3..fb89e4b4f2 100644 --- a/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java +++ b/oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java @@ -38,6 +38,8 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.context.NullSecurityContextRepository; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; @@ -75,6 +77,8 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter private AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); + private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository(); + /** * Construct a {@code BearerTokenAuthenticationFilter} using the provided parameter(s) * @param authenticationManagerResolver @@ -131,6 +135,7 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter SecurityContext context = SecurityContextHolder.createEmptyContext(); context.setAuthentication(authenticationResult); SecurityContextHolder.setContext(context); + this.securityContextRepository.saveContext(context, request, response); if (this.logger.isDebugEnabled()) { this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authenticationResult)); } @@ -143,6 +148,18 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter } } + /** + * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on + * authentication success. The default action is not to save the + * {@link SecurityContext}. + * @param securityContextRepository the {@link SecurityContextRepository} to use. + * Cannot be null. + */ + public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) { + Assert.notNull(securityContextRepository, "securityContextRepository cannot be null"); + this.securityContextRepository = securityContextRepository; + } + /** * Set the {@link BearerTokenResolver} to use. Defaults to * {@link DefaultBearerTokenResolver}. diff --git a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java index 418ca85c0a..5f42377410 100644 --- a/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java +++ b/oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java @@ -36,18 +36,23 @@ import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManagerResolver; import org.springframework.security.authentication.AuthenticationServiceException; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken; import org.springframework.security.oauth2.server.resource.BearerTokenError; import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.context.SecurityContextRepository; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -102,6 +107,26 @@ public class BearerTokenAuthenticationFilterTests { assertThat(captor.getValue().getPrincipal()).isEqualTo("token"); } + @Test + public void doFilterWhenSecurityContextRepositoryThenSaves() throws ServletException, IOException { + SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class); + String token = "token"; + given(this.bearerTokenResolver.resolve(this.request)).willReturn(token); + TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("test", "password"); + given(this.authenticationManager.authenticate(any())).willReturn(expectedAuthentication); + BearerTokenAuthenticationFilter filter = addMocks( + new BearerTokenAuthenticationFilter(this.authenticationManager)); + filter.setSecurityContextRepository(securityContextRepository); + filter.doFilter(this.request, this.response, this.filterChain); + ArgumentCaptor captor = ArgumentCaptor + .forClass(BearerTokenAuthenticationToken.class); + verify(this.authenticationManager).authenticate(captor.capture()); + assertThat(captor.getValue().getPrincipal()).isEqualTo(token); + ArgumentCaptor contextArg = ArgumentCaptor.forClass(SecurityContext.class); + verify(securityContextRepository).saveContext(contextArg.capture(), eq(this.request), eq(this.response)); + assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo(expectedAuthentication.getName()); + } + @Test public void doFilterWhenUsingAuthenticationManagerResolverThenAuthenticates() throws Exception { BearerTokenAuthenticationFilter filter = addMocks(