diff --git a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationConverter.java b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationConverter.java new file mode 100644 index 0000000000..c425aa8878 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationConverter.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.authentication; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; + +/** + * A strategy used for converting from a {@link HttpServletRequest} to an + * {@link Authentication} of particular type. Used to authenticate with + * appropriate {@link AuthenticationManager}. If the result is null, then it + * signals that no authentication attempt should be made. It is also possible to + * throw {@link AuthenticationException} within the + * {@link #convert(HttpServletRequest)} if there was invalid Authentication + * scheme value. + * + * @author Sergey Bespalov + * @since 5.2.0 + */ +public interface AuthenticationConverter { + + Authentication convert(HttpServletRequest request); + +} diff --git a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationEntryPointFailureHandler.java b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationEntryPointFailureHandler.java new file mode 100644 index 0000000000..226f96ad8b --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationEntryPointFailureHandler.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.authentication; + +import java.io.IOException; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.util.Assert; + +/** + * Adapts a {@link AuthenticationEntryPoint} into a {@link AuthenticationFailureHandler} + * + * @author sbespalov + * @since 5.2.0 + */ +public class AuthenticationEntryPointFailureHandler implements AuthenticationFailureHandler { + + private final AuthenticationEntryPoint authenticationEntryPoint; + + public AuthenticationEntryPointFailureHandler(AuthenticationEntryPoint authenticationEntryPoint) { + Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint cannot be null"); + this.authenticationEntryPoint = authenticationEntryPoint; + } + + @Override + public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, + AuthenticationException exception) throws IOException, ServletException { + this.authenticationEntryPoint.commence(request, response, exception); + } +} diff --git a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java new file mode 100644 index 0000000000..35aea92af9 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java @@ -0,0 +1,186 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.authentication; + +import java.io.IOException; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.springframework.http.HttpStatus; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationManagerResolver; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.util.matcher.AnyRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; +import org.springframework.web.filter.OncePerRequestFilter; + +/** + * A {@link Filter} that performs authentication of a particular request. An + * outline of the logic: + * + * + * + * @author Sergey Bespalov + * @since 5.2.0 + */ +public class AuthenticationFilter extends OncePerRequestFilter { + + private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE; + private AuthenticationConverter authenticationConverter; + private AuthenticationSuccessHandler successHandler = new SavedRequestAwareAuthenticationSuccessHandler(); + private AuthenticationFailureHandler failureHandler = new AuthenticationEntryPointFailureHandler( + new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED)); + private AuthenticationManagerResolver authenticationManagerResolver; + + public AuthenticationFilter(AuthenticationManager authenticationManager, + AuthenticationConverter authenticationConverter) { + this((AuthenticationManagerResolver) r -> authenticationManager, authenticationConverter); + } + + public AuthenticationFilter(AuthenticationManagerResolver authenticationManagerResolver, + AuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationManagerResolver, "authenticationResolverManager cannot be null"); + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + + this.authenticationManagerResolver = authenticationManagerResolver; + this.authenticationConverter = authenticationConverter; + } + + public RequestMatcher getRequestMatcher() { + return requestMatcher; + } + + public void setRequestMatcher(RequestMatcher requestMatcher) { + Assert.notNull(requestMatcher, "requestMatcher cannot be null"); + this.requestMatcher = requestMatcher; + } + + public AuthenticationConverter getAuthenticationConverter() { + return authenticationConverter; + } + + public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + this.authenticationConverter = authenticationConverter; + } + + public AuthenticationSuccessHandler getSuccessHandler() { + return successHandler; + } + + public void setSuccessHandler(AuthenticationSuccessHandler successHandler) { + Assert.notNull(successHandler, "successHandler cannot be null"); + this.successHandler = successHandler; + } + + public AuthenticationFailureHandler getFailureHandler() { + return failureHandler; + } + + public void setFailureHandler(AuthenticationFailureHandler failureHandler) { + Assert.notNull(failureHandler, "failureHandler cannot be null"); + this.failureHandler = failureHandler; + } + + public AuthenticationManagerResolver getAuthenticationManagerResolver() { + return authenticationManagerResolver; + } + + public void setAuthenticationManagerResolver( + AuthenticationManagerResolver authenticationManagerResolver) { + Assert.notNull(authenticationManagerResolver, "authenticationManagerResolver cannot be null"); + this.authenticationManagerResolver = authenticationManagerResolver; + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + if (!requestMatcher.matches(request)) { + filterChain.doFilter(request, response); + return; + } + + try { + Authentication authenticationResult = attemptAuthentication(request, response); + if (authenticationResult == null) { + filterChain.doFilter(request, response); + return; + } + + successfulAuthentication(request, response, filterChain, authenticationResult); + } catch (AuthenticationException e) { + unsuccessfulAuthentication(request, response, e); + } + } + + private void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, + AuthenticationException failed) throws IOException, ServletException { + SecurityContextHolder.clearContext(); + failureHandler.onAuthenticationFailure(request, response, failed); + } + + protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, + Authentication authentication) throws IOException, ServletException { + SecurityContext context = SecurityContextHolder.createEmptyContext(); + context.setAuthentication(authentication); + SecurityContextHolder.setContext(context); + + successHandler.onAuthenticationSuccess(request, response, chain, authentication); + } + + public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) + throws AuthenticationException, IOException, ServletException { + Authentication authentication = authenticationConverter.convert(request); + if (authentication == null) { + return null; + } + + AuthenticationManager authenticationManager = authenticationManagerResolver.resolve(request); + Authentication authenticationResult = authenticationManager.authenticate(authentication); + if (authenticationResult == null) { + throw new ServletException("AuthenticationManager should not return null Authentication object."); + } + + return authenticationResult; + } + +} diff --git a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationSuccessHandler.java b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationSuccessHandler.java index 2aab8b8bc3..0909d8c191 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/AuthenticationSuccessHandler.java +++ b/web/src/main/java/org/springframework/security/web/authentication/AuthenticationSuccessHandler.java @@ -17,6 +17,7 @@ package org.springframework.security.web.authentication; import java.io.IOException; +import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -38,6 +39,23 @@ import org.springframework.security.core.Authentication; */ public interface AuthenticationSuccessHandler { + /** + * Called when a user has been successfully authenticated. + * + * @param request the request which caused the successful authentication + * @param response the response + * @param chain the {@link FilterChain} which can be used to proceed other filters in the chain + * @param authentication the Authentication object which was created during + * the authentication process. + * @since 5.2.0 + */ + default void onAuthenticationSuccess(HttpServletRequest request, + HttpServletResponse response, FilterChain chain, Authentication authentication) + throws IOException, ServletException{ + onAuthenticationSuccess(request, response, authentication); + chain.doFilter(request, response); + } + /** * Called when a user has been successfully authenticated. * diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationConverter.java b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationConverter.java new file mode 100644 index 0000000000..4da6fb5adf --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationConverter.java @@ -0,0 +1,121 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.authentication.www; + +import static org.springframework.http.HttpHeaders.AUTHORIZATION; + +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.security.authentication.AuthenticationDetailsSource; +import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.authentication.InternalAuthenticationServiceException; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.util.Assert; + +/** + * Converts from a HttpServletRequest to + * {@link UsernamePasswordAuthenticationToken} that can be authenticated. Null + * authentication possible if there was no Authorization header with Basic + * authentication scheme. + * + * @author Sergey Bespalov + * @since 5.2.0 + */ +public class BasicAuthenticationConverter implements AuthenticationConverter { + + public static final String AUTHENTICATION_SCHEME_BASIC = "Basic"; + + private AuthenticationDetailsSource authenticationDetailsSource; + + private String credentialsCharset = StandardCharsets.UTF_8.name(); + + public BasicAuthenticationConverter() { + this(new WebAuthenticationDetailsSource()); + } + + public BasicAuthenticationConverter( + AuthenticationDetailsSource authenticationDetailsSource) { + this.authenticationDetailsSource = authenticationDetailsSource; + } + + public String getCredentialsCharset() { + return credentialsCharset; + } + + public void setCredentialsCharset(String credentialsCharset) { + this.credentialsCharset = credentialsCharset; + } + + public AuthenticationDetailsSource getAuthenticationDetailsSource() { + return authenticationDetailsSource; + } + + public void setAuthenticationDetailsSource( + AuthenticationDetailsSource authenticationDetailsSource) { + Assert.notNull(authenticationDetailsSource, "AuthenticationDetailsSource required"); + this.authenticationDetailsSource = authenticationDetailsSource; + } + + @Override + public UsernamePasswordAuthenticationToken convert(HttpServletRequest request) { + String header = request.getHeader(AUTHORIZATION); + if (header == null) { + return null; + } + + header = header.trim(); + if (!header.startsWith(AUTHENTICATION_SCHEME_BASIC) && !header.startsWith(AUTHENTICATION_SCHEME_BASIC.toLowerCase())) { + return null; + } + + byte[] base64Token = header.substring(6).getBytes(); + byte[] decoded; + try { + decoded = Base64.getDecoder().decode(base64Token); + } catch (IllegalArgumentException e) { + throw new BadCredentialsException("Failed to decode basic authentication token"); + } + + String token; + try { + token = new String(decoded, getCredentialsCharset(request)); + } catch (UnsupportedEncodingException e) { + throw new InternalAuthenticationServiceException(e.getMessage(), e); + } + + String[] tokens = token.split(":"); + if (tokens.length != 2) { + throw new BadCredentialsException("Invalid basic authentication token"); + } + + UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken(tokens[0], + tokens[1]); + authentication.setDetails(authenticationDetailsSource.buildDetails(request)); + + return authentication; + } + + protected String getCredentialsCharset(HttpServletRequest request) { + return getCredentialsCharset(); + } + +} diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java index 1872d36793..dfc9bb3703 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java @@ -17,7 +17,6 @@ package org.springframework.security.web.authentication.www; import java.io.IOException; -import java.util.Base64; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -27,7 +26,6 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; -import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -35,9 +33,7 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.AuthenticationEntryPoint; import org.springframework.security.web.authentication.NullRememberMeServices; import org.springframework.security.web.authentication.RememberMeServices; -import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; /** @@ -95,12 +91,12 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { // ~ Instance fields // ================================================================================================ - private AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); private AuthenticationEntryPoint authenticationEntryPoint; private AuthenticationManager authenticationManager; private RememberMeServices rememberMeServices = new NullRememberMeServices(); private boolean ignoreFailure = false; private String credentialsCharset = "UTF-8"; + private BasicAuthenticationConverter authenticationConverter = new BasicAuthenticationConverter(); /** * Creates an instance which will authenticate against the supplied @@ -152,19 +148,14 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { HttpServletResponse response, FilterChain chain) throws IOException, ServletException { final boolean debug = this.logger.isDebugEnabled(); - - String header = request.getHeader("Authorization"); - - if (!StringUtils.startsWithIgnoreCase(header, "basic ")) { - chain.doFilter(request, response); - return; - } - try { - String[] tokens = extractAndDecodeHeader(header, request); - assert tokens.length == 2; + UsernamePasswordAuthenticationToken authRequest = authenticationConverter.convert(request); + if (authRequest == null) { + chain.doFilter(request, response); + return; + } - String username = tokens[0]; + String username = authRequest.getName(); if (debug) { this.logger @@ -173,10 +164,6 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { } if (authenticationIsRequired(username)) { - UsernamePasswordAuthenticationToken authRequest = new UsernamePasswordAuthenticationToken( - username, tokens[1]); - authRequest.setDetails( - this.authenticationDetailsSource.buildDetails(request)); Authentication authResult = this.authenticationManager .authenticate(authRequest); @@ -216,35 +203,6 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { chain.doFilter(request, response); } - /** - * Decodes the header into a username and password. - * - * @throws BadCredentialsException if the Basic header is not present or is not valid - * Base64 - */ - private String[] extractAndDecodeHeader(String header, HttpServletRequest request) - throws IOException { - - byte[] base64Token = header.substring(6).getBytes("UTF-8"); - byte[] decoded; - try { - decoded = Base64.getDecoder().decode(base64Token); - } - catch (IllegalArgumentException e) { - throw new BadCredentialsException( - "Failed to decode basic authentication token"); - } - - String token = new String(decoded, getCredentialsCharset(request)); - - int delim = token.indexOf(":"); - - if (delim == -1) { - throw new BadCredentialsException("Invalid basic authentication token"); - } - return new String[] { token.substring(0, delim), token.substring(delim + 1) }; - } - private boolean authenticationIsRequired(String username) { // Only reauthenticate if username doesn't match SecurityContextHolder and user // isn't authenticated @@ -308,9 +266,7 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter { public void setAuthenticationDetailsSource( AuthenticationDetailsSource authenticationDetailsSource) { - Assert.notNull(authenticationDetailsSource, - "AuthenticationDetailsSource required"); - this.authenticationDetailsSource = authenticationDetailsSource; + authenticationConverter.setAuthenticationDetailsSource(authenticationDetailsSource); } public void setRememberMeServices(RememberMeServices rememberMeServices) { diff --git a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java index ccab24af43..e97da82689 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java @@ -99,7 +99,7 @@ public class AbstractAuthenticationProcessingFilterTests { } @Test - public void testDefaultProcessesFilterUrlMatchesWithPathParameter() { + public void testDefaultProcessesFilterUrlMatchesWithPathParameter() throws Exception { MockHttpServletRequest request = createMockAuthenticationRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); MockAuthenticationFilter filter = new MockAuthenticationFilter(); diff --git a/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java new file mode 100644 index 0000000000..85e3608bed --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java @@ -0,0 +1,249 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.authentication; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.http.HttpStatus; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationManagerResolver; +import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.util.matcher.RequestMatcher; + +/** + * @author Sergey Bespalov + * @since 5.2.0 + */ +@RunWith(MockitoJUnitRunner.class) +public class AuthenticationFilterTests { + + @Mock + private AuthenticationSuccessHandler successHandler; + @Mock + private AuthenticationConverter authenticationConverter; + @Mock + private AuthenticationManager authenticationManager; + @Mock + private AuthenticationFailureHandler failureHandler; + @Mock + private AuthenticationManagerResolver authenticationManagerResolver; + @Mock + private RequestMatcher requestMatcher; + + @Before + public void setup() { + when(this.authenticationManagerResolver.resolve(any())).thenReturn(this.authenticationManager); + } + + @After + public void clearContext() throws Exception { + SecurityContextHolder.clearContext(); + } + + @Test + public void filterWhenDefaultsAndNoAuthenticationThenContinues() throws Exception { + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager, this.authenticationConverter); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + + verifyZeroInteractions(this.authenticationManager); + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + @Test + public void filterWhenAuthenticationManagerResolverDefaultsAndNoAuthenticationThenContinues() throws Exception { + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManagerResolver, this.authenticationConverter); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + + verifyZeroInteractions(this.authenticationManagerResolver); + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + @Test + public void filterWhenDefaultsAndAuthenticationSuccessThenContinues() throws Exception { + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE"); + when(this.authenticationConverter.convert(any())).thenReturn(authentication); + when(this.authenticationManager.authenticate(any())).thenReturn(authentication); + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager, this.authenticationConverter); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull(); + } + + @Test + public void filterWhenAuthenticationManagerResolverDefaultsAndAuthenticationSuccessThenContinues() + throws Exception { + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE"); + when(this.authenticationConverter.convert(any())).thenReturn(authentication); + when(this.authenticationManager.authenticate(any())).thenReturn(authentication); + + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManagerResolver, this.authenticationConverter); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + + verify(this.authenticationManager).authenticate(any(Authentication.class)); + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull(); + } + + @Test + public void filterWhenDefaultsAndAuthenticationFailThenUnauthorized() throws Exception { + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE"); + when(this.authenticationConverter.convert(any())).thenReturn(authentication); + when(this.authenticationManager.authenticate(any())).thenThrow(new BadCredentialsException("failed")); + + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager, this.authenticationConverter); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value()); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + @Test + public void filterWhenAuthenticationManagerResolverDefaultsAndAuthenticationFailThenUnauthorized() + throws Exception { + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE"); + when(this.authenticationConverter.convert(any())).thenReturn(authentication); + when(this.authenticationManager.authenticate(any())).thenThrow(new BadCredentialsException("failed")); + + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManagerResolver, this.authenticationConverter); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value()); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + @Test + public void filterWhenConvertEmptyThenOk() throws Exception { + when(this.authenticationConverter.convert(any())).thenReturn(null); + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManagerResolver, this.authenticationConverter); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, new MockHttpServletResponse(), chain); + + verifyZeroInteractions(this.authenticationManagerResolver); + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + @Test + public void filterWhenConvertAndAuthenticationSuccessThenSuccess() throws Exception { + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER"); + when(this.authenticationConverter.convert(any())).thenReturn(authentication); + when(this.authenticationManager.authenticate(any())).thenReturn(authentication); + + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManagerResolver, this.authenticationConverter); + filter.setSuccessHandler(successHandler); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + + verify(this.successHandler).onAuthenticationSuccess(any(), any(), any(), eq(authentication)); + verifyZeroInteractions(this.failureHandler); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull(); + } + + @Test(expected = ServletException.class) + public void filterWhenConvertAndAuthenticationEmptyThenServerError() throws Exception { + Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER"); + when(this.authenticationConverter.convert(any())).thenReturn(authentication); + when(this.authenticationManager.authenticate(any())).thenReturn(null); + + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManagerResolver, this.authenticationConverter); + filter.setSuccessHandler(successHandler); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + try { + filter.doFilter(request, response, chain); + } catch (ServletException e) { + verifyZeroInteractions(this.successHandler); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + + throw e; + } + } + + @Test + public void filterWhenNotMatchAndConvertAndAuthenticationSuccessThenContinues() throws Exception { + when(this.requestMatcher.matches(any())).thenReturn(false); + + AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManagerResolver, this.authenticationConverter); + filter.setRequestMatcher(this.requestMatcher); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + + verifyZeroInteractions(this.authenticationConverter, this.authenticationManagerResolver, this.successHandler); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + +} diff --git a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationConverterTests.java b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationConverterTests.java new file mode 100644 index 0000000000..abfdaf548f --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationConverterTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.authentication.www; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; + +import javax.servlet.http.HttpServletRequest; + +import org.apache.commons.codec.binary.Base64; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.authentication.AuthenticationDetailsSource; +import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; + +/** + * @author Sergey Bespalov + * @since 5.2.0 + */ +@RunWith(MockitoJUnitRunner.class) +public class BasicAuthenticationConverterTests { + + @Mock + private AuthenticationDetailsSource authenticationDetailsSource; + private BasicAuthenticationConverter converter; + + @Before + public void setup() { + converter = new BasicAuthenticationConverter(authenticationDetailsSource); + } + + @Test + public void testNormalOperation() throws Exception { + String token = "rod:koala"; + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes()))); + UsernamePasswordAuthenticationToken authentication = converter.convert(request); + + verify(authenticationDetailsSource).buildDetails(any()); + assertThat(authentication).isNotNull(); + assertThat(authentication.getName()).isEqualTo("rod"); + } + + @Test + public void testWhenUnsupportedAuthorizationHeaderThenIgnored() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Bearer someOtherToken"); + UsernamePasswordAuthenticationToken authentication = converter.convert(request); + + verifyZeroInteractions(authenticationDetailsSource); + assertThat(authentication).isNull(); + } + + @Test(expected = BadCredentialsException.class) + public void testWhenInvalidBasicAuthorizationTokenThenError() throws Exception { + String token = "NOT_A_VALID_TOKEN_AS_MISSING_COLON"; + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes()))); + converter.convert(request); + } + + @Test(expected = BadCredentialsException.class) + public void testWhenInvalidBase64ThenError() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Basic NOT_VALID_BASE64"); + + converter.convert(request); + } + +}