diff --git a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java index 6418564abd..866e6e3487 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java @@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -114,18 +115,16 @@ public class BasicAuthenticationFilter extends GenericFilterBean { String header = request.getHeader("Authorization"); - if ((header != null) && header.startsWith("Basic ")) { - byte[] base64Token = header.substring(6).getBytes("UTF-8"); - String token = new String(Base64.decode(base64Token), getCredentialsCharset(request)); + if (header == null || !header.startsWith("Basic ")) { + chain.doFilter(request, response); + return; + } - String username = ""; - String password = ""; - int delim = token.indexOf(":"); + try { + String[] tokens = extractAndDecodeHeader(header, request); + assert tokens.length == 2; - if (delim != -1) { - username = token.substring(0, delim); - password = token.substring(delim + 1); - } + String username = tokens[0]; if (debug) { logger.debug("Basic Authentication Authorization header found for user '" + username + "'"); @@ -133,37 +132,12 @@ public class BasicAuthenticationFilter extends GenericFilterBean { if (authenticationIsRequired(username)) { UsernamePasswordAuthenticationToken authRequest = - new UsernamePasswordAuthenticationToken(username, password); + new UsernamePasswordAuthenticationToken(username, tokens[1]); authRequest.setDetails(authenticationDetailsSource.buildDetails(request)); + Authentication authResult = authenticationManager.authenticate(authRequest); - Authentication authResult; - - try { - authResult = authenticationManager.authenticate(authRequest); - } catch (AuthenticationException failed) { - // Authentication failed - if (debug) { - logger.debug("Authentication request for user: " + username + " failed: " + failed.toString()); - } - - SecurityContextHolder.getContext().setAuthentication(null); - - rememberMeServices.loginFail(request, response); - - onUnsuccessfulAuthentication(request, response, failed); - - if (ignoreFailure) { - chain.doFilter(request, response); - } else { - authenticationEntryPoint.commence(request, response, failed); - } - - return; - } - - // Authentication success if (debug) { - logger.debug("Authentication success: " + authResult.toString()); + logger.debug("Authentication success: " + authResult); } SecurityContextHolder.getContext().setAuthentication(authResult); @@ -172,11 +146,55 @@ public class BasicAuthenticationFilter extends GenericFilterBean { onSuccessfulAuthentication(request, response, authResult); } + + } catch (AuthenticationException failed) { + SecurityContextHolder.clearContext(); + + if (debug) { + logger.debug("Authentication request for failed: " + failed); + } + + rememberMeServices.loginFail(request, response); + + onUnsuccessfulAuthentication(request, response, failed); + + if (ignoreFailure) { + chain.doFilter(request, response); + } else { + authenticationEntryPoint.commence(request, response, failed); + } + + return; } chain.doFilter(request, response); } + /** + * Decodes the header into a username and password. + *
+ * @throws BadCredentialsException if the Basic header is not present or is not valid Base64 + */ + private String[] extractAndDecodeHeader(String header, HttpServletRequest request) throws IOException { + + byte[] base64Token = header.substring(6).getBytes("UTF-8"); + byte[] decoded; + try { + decoded = Base64.decode(base64Token); + } catch (IllegalArgumentException e) { + throw new BadCredentialsException("Failed to decode basic authentication token"); + } + + String token = new String(decoded, getCredentialsCharset(request)); + + int delim = token.indexOf(":"); + + if (delim == -1) { + throw new BadCredentialsException("Invalid basic authentication token"); + } + return new String[] {token.substring(0, delim), token.substring(delim + 1)}; + } + private boolean authenticationIsRequired(String username) { // Only reauthenticate if username doesn't match SecurityContextHolder and user isn't authenticated // (see SEC-53) diff --git a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java index a9d4b20ba3..b866304352 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java @@ -20,11 +20,7 @@ import static org.mockito.AdditionalMatchers.not; import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; -import java.io.IOException; - -import javax.servlet.Filter; import javax.servlet.FilterChain; -import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; @@ -55,24 +51,9 @@ public class BasicAuthenticationFilterTests { private BasicAuthenticationFilter filter; private AuthenticationManager manager; -// private Mockery jmock = new JUnit4Mockery(); //~ Methods ======================================================================================================== - private MockHttpServletResponse executeFilterInContainerSimulator(Filter filter, final ServletRequest request, - final boolean expectChainToProceed) throws ServletException, IOException { -// filter.init(mock(FilterConfig.class)); - - final MockHttpServletResponse response = new MockHttpServletResponse(); - - FilterChain chain = mock(FilterChain.class); - filter.doFilter(request, response, chain); -// filter.destroy(); - - verify(chain, expectChainToProceed ? times(1) : never()).doFilter(any(ServletRequest.class), any(ServletResponse.class)); - return response; - } - @Before public void setUp() throws Exception { SecurityContextHolder.clearContext(); @@ -97,13 +78,17 @@ public class BasicAuthenticationFilterTests { @Test public void testFilterIgnoresRequestsContainingNoAuthorizationHeader() throws Exception { - // Setup our HTTP request + MockHttpServletRequest request = new MockHttpServletRequest(); request.setServletPath("/some_file.html"); + final MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); // Test - executeFilterInContainerSimulator(filter, request, true); - assertNull(SecurityContextHolder.getContext().getAuthentication()); } @@ -119,47 +104,64 @@ public class BasicAuthenticationFilterTests { @Test public void testInvalidBasicAuthorizationTokenIsIgnored() throws Exception { - // Setup our HTTP request String token = "NOT_A_VALID_TOKEN_AS_MISSING_COLON"; MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes()))); request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); + final MockHttpServletResponse response = new MockHttpServletResponse(); - // The filter chain shouldn't proceed - executeFilterInContainerSimulator(filter, request, false); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class)); assertNull(SecurityContextHolder.getContext().getAuthentication()); + assertEquals(401, response.getStatus()); + } + + @Test + public void invalidBase64IsIgnored() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader("Authorization", "Basic NOT_VALID_BASE64"); + request.setServletPath("/some_file.html"); + request.setSession(new MockHttpSession()); + final MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + // The filter chain shouldn't proceed + verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + assertNull(SecurityContextHolder.getContext().getAuthentication()); + assertEquals(401, response.getStatus()); } @Test public void testNormalOperation() throws Exception { - // Setup our HTTP request String token = "rod:koala"; MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes()))); request.setServletPath("/some_file.html"); -// request.setSession(new MockHttpSession()); // Test assertNull(SecurityContextHolder.getContext().getAuthentication()); - executeFilterInContainerSimulator(filter, request, true); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, new MockHttpServletResponse(), chain); + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); assertNotNull(SecurityContextHolder.getContext().getAuthentication()); assertEquals("rod", SecurityContextHolder.getContext().getAuthentication().getName()); - } @Test public void testOtherAuthorizationSchemeIsIgnored() throws Exception { - // Setup our HTTP request + MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "SOME_OTHER_AUTHENTICATION_SCHEME"); request.setServletPath("/some_file.html"); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, new MockHttpServletResponse(), chain); - // Test - executeFilterInContainerSimulator(filter, request, true); - + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); assertNull(SecurityContextHolder.getContext().getAuthentication()); } @@ -179,27 +181,36 @@ public class BasicAuthenticationFilterTests { @Test public void testSuccessLoginThenFailureLoginResultsInSessionLosingToken() throws Exception { - // Setup our HTTP request String token = "rod:koala"; MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes()))); request.setServletPath("/some_file.html"); + final MockHttpServletResponse response1 = new MockHttpServletResponse(); + + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response1, chain); + + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); // Test - executeFilterInContainerSimulator(filter, request, true); - assertNotNull(SecurityContextHolder.getContext().getAuthentication()); assertEquals("rod", SecurityContextHolder.getContext().getAuthentication().getName()); // NOW PERFORM FAILED AUTHENTICATION - // Setup our HTTP request + token = "otherUser:WRONG_PASSWORD"; request = new MockHttpServletRequest(); request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes()))); + final MockHttpServletResponse response2 = new MockHttpServletResponse(); + + chain = mock(FilterChain.class); + filter.doFilter(request, response2, chain); + + verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class)); request.setServletPath("/some_file.html"); - // Test - the filter chain will not be invoked, as we get a 403 forbidden response - MockHttpServletResponse response = executeFilterInContainerSimulator(filter, request, false); + // Test - the filter chain will not be invoked, as we get a 401 forbidden response + MockHttpServletResponse response = response2; assertNull(SecurityContextHolder.getContext().getAuthentication()); assertEquals(401, response.getStatus()); @@ -207,7 +218,6 @@ public class BasicAuthenticationFilterTests { @Test public void testWrongPasswordContinuesFilterChainIfIgnoreFailureIsTrue() throws Exception { - // Setup our HTTP request String token = "rod:WRONG_PASSWORD"; MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes()))); @@ -216,26 +226,30 @@ public class BasicAuthenticationFilterTests { filter.setIgnoreFailure(true); assertTrue(filter.isIgnoreFailure()); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, new MockHttpServletResponse(), chain); + + verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); // Test - the filter chain will be invoked, as we've set ignoreFailure = true - executeFilterInContainerSimulator(filter, request, true); - assertNull(SecurityContextHolder.getContext().getAuthentication()); } @Test public void testWrongPasswordReturnsForbiddenIfIgnoreFailureIsFalse() throws Exception { - // Setup our HTTP request String token = "rod:WRONG_PASSWORD"; MockHttpServletRequest request = new MockHttpServletRequest(); request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes()))); request.setServletPath("/some_file.html"); request.setSession(new MockHttpSession()); assertFalse(filter.isIgnoreFailure()); + final MockHttpServletResponse response = new MockHttpServletResponse(); - // Test - the filter chain will not be invoked, as we get a 403 forbidden response - MockHttpServletResponse response = executeFilterInContainerSimulator(filter, request, false); + FilterChain chain = mock(FilterChain.class); + filter.doFilter(request, response, chain); + // Test - the filter chain will not be invoked, as we get a 401 forbidden response + verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class)); assertNull(SecurityContextHolder.getContext().getAuthentication()); assertEquals(401, response.getStatus()); }