diff --git a/web/src/main/java/org/springframework/security/web/access/intercept/AuthorizationFilter.java b/web/src/main/java/org/springframework/security/web/access/intercept/AuthorizationFilter.java index 19bcbfdc11..0dc292c2ff 100644 --- a/web/src/main/java/org/springframework/security/web/access/intercept/AuthorizationFilter.java +++ b/web/src/main/java/org/springframework/security/web/access/intercept/AuthorizationFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +18,11 @@ package org.springframework.security.web.access.intercept; import java.io.IOException; +import javax.servlet.DispatcherType; import javax.servlet.FilterChain; import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -28,7 +31,7 @@ import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.util.Assert; -import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.filter.GenericFilterBean; /** * An authorization filter that restricts access to the URL using @@ -37,10 +40,16 @@ import org.springframework.web.filter.OncePerRequestFilter; * @author Evgeniy Cheban * @since 5.5 */ -public class AuthorizationFilter extends OncePerRequestFilter { +public class AuthorizationFilter extends GenericFilterBean { private final AuthorizationManager authorizationManager; + private boolean observeOncePerRequest = true; + + private boolean filterErrorDispatch = false; + + private boolean filterAsyncDispatch = false; + /** * Creates an instance. * @param authorizationManager the {@link AuthorizationManager} to use @@ -51,11 +60,53 @@ public class AuthorizationFilter extends OncePerRequestFilter { } @Override - protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws ServletException, IOException { - this.authorizationManager.verify(this::getAuthentication, request); - filterChain.doFilter(request, response); + HttpServletRequest request = (HttpServletRequest) servletRequest; + HttpServletResponse response = (HttpServletResponse) servletResponse; + + if (this.observeOncePerRequest && isApplied(request)) { + chain.doFilter(request, response); + return; + } + + if (skipDispatch(request)) { + chain.doFilter(request, response); + return; + } + + String alreadyFilteredAttributeName = getAlreadyFilteredAttributeName(); + request.setAttribute(alreadyFilteredAttributeName, Boolean.TRUE); + try { + this.authorizationManager.verify(this::getAuthentication, request); + chain.doFilter(request, response); + } + finally { + request.removeAttribute(alreadyFilteredAttributeName); + } + } + + private boolean skipDispatch(HttpServletRequest request) { + if (DispatcherType.ERROR.equals(request.getDispatcherType()) && !this.filterErrorDispatch) { + return true; + } + if (DispatcherType.ASYNC.equals(request.getDispatcherType()) && !this.filterAsyncDispatch) { + return true; + } + return false; + } + + private boolean isApplied(HttpServletRequest request) { + return request.getAttribute(getAlreadyFilteredAttributeName()) != null; + } + + private String getAlreadyFilteredAttributeName() { + String name = getFilterName(); + if (name == null) { + name = getClass().getName(); + } + return name + ".APPLIED"; } private Authentication getAuthentication() { @@ -75,4 +126,38 @@ public class AuthorizationFilter extends OncePerRequestFilter { return this.authorizationManager; } + public boolean isObserveOncePerRequest() { + return this.observeOncePerRequest; + } + + /** + * Sets whether this filter apply only once per request. By default, this is + * true, meaning the filter will only execute once per request. Sometimes + * users may wish it to execute more than once per request, such as when JSP forwards + * are being used and filter security is desired on each included fragment of the HTTP + * request. + * @param observeOncePerRequest whether the filter should only be applied once per + * request + */ + public void setObserveOncePerRequest(boolean observeOncePerRequest) { + this.observeOncePerRequest = observeOncePerRequest; + } + + /** + * If set to true, the filter will be applied to error dispatcher. Defaults to false. + * @param filterErrorDispatch whether the filter should be applied to error dispatcher + */ + public void setFilterErrorDispatch(boolean filterErrorDispatch) { + this.filterErrorDispatch = filterErrorDispatch; + } + + /** + * If set to true, the filter will be applied to the async dispatcher. Defaults to + * false. + * @param filterAsyncDispatch whether the filter should be applied to async dispatch + */ + public void setFilterAsyncDispatch(boolean filterAsyncDispatch) { + this.filterAsyncDispatch = filterAsyncDispatch; + } + } diff --git a/web/src/test/java/org/springframework/security/web/access/intercept/AuthorizationFilterTests.java b/web/src/test/java/org/springframework/security/web/access/intercept/AuthorizationFilterTests.java index 0923605216..c8dc1d7b5c 100644 --- a/web/src/test/java/org/springframework/security/web/access/intercept/AuthorizationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/access/intercept/AuthorizationFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,15 +16,20 @@ package org.springframework.security.web.access.intercept; +import java.io.IOException; import java.util.function.Supplier; +import javax.servlet.DispatcherType; import javax.servlet.FilterChain; +import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.access.AccessDeniedException; @@ -36,6 +41,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.test.util.ReflectionTestUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -43,6 +49,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -53,6 +60,24 @@ import static org.mockito.Mockito.verifyNoInteractions; */ public class AuthorizationFilterTests { + private static final String ALREADY_FILTERED_ATTRIBUTE_NAME = "org.springframework.security.web.access.intercept.AuthorizationFilter.APPLIED"; + + private AuthorizationFilter filter; + + private AuthorizationManager authorizationManager; + + private MockHttpServletRequest request = new MockHttpServletRequest(); + + private final MockHttpServletResponse response = new MockHttpServletResponse(); + + private final FilterChain chain = new MockFilterChain(); + + @BeforeEach + public void setup() { + this.authorizationManager = mock(AuthorizationManager.class); + this.filter = new AuthorizationFilter(this.authorizationManager); + } + @AfterEach public void tearDown() { SecurityContextHolder.clearContext(); @@ -132,4 +157,102 @@ public class AuthorizationFilterTests { assertThat(authorizationFilter.getAuthorizationManager()).isSameAs(authorizationManager); } + @Test + public void doFilterWhenObserveOncePerRequestTrueAndIsAppliedThenNotInvoked() throws ServletException, IOException { + setIsAppliedTrue(); + this.filter.setObserveOncePerRequest(true); + this.filter.doFilter(this.request, this.response, this.chain); + verifyNoInteractions(this.authorizationManager); + } + + @Test + public void doFilterWhenObserveOncePerRequestTrueAndNotAppliedThenInvoked() throws ServletException, IOException { + this.filter.setObserveOncePerRequest(true); + this.filter.doFilter(this.request, this.response, this.chain); + verify(this.authorizationManager).verify(any(), any()); + } + + @Test + public void doFilterWhenObserveOncePerRequestFalseAndIsAppliedThenInvoked() throws ServletException, IOException { + setIsAppliedTrue(); + this.filter.setObserveOncePerRequest(false); + this.filter.doFilter(this.request, this.response, this.chain); + verify(this.authorizationManager).verify(any(), any()); + } + + @Test + public void doFilterWhenObserveOncePerRequestFalseAndNotAppliedThenInvoked() throws ServletException, IOException { + this.filter.setObserveOncePerRequest(false); + this.filter.doFilter(this.request, this.response, this.chain); + verify(this.authorizationManager).verify(any(), any()); + } + + @Test + public void doFilterWhenFilterErrorDispatchFalseAndIsErrorThenNotInvoked() throws ServletException, IOException { + this.request.setDispatcherType(DispatcherType.ERROR); + this.filter.setFilterErrorDispatch(false); + this.filter.doFilter(this.request, this.response, this.chain); + verifyNoInteractions(this.authorizationManager); + } + + @Test + public void doFilterWhenFilterErrorDispatchTrueAndIsErrorThenInvoked() throws ServletException, IOException { + this.request.setDispatcherType(DispatcherType.ERROR); + this.filter.setFilterErrorDispatch(true); + this.filter.doFilter(this.request, this.response, this.chain); + verify(this.authorizationManager).verify(any(), any()); + } + + @Test + public void doFilterWhenFilterThenSetAlreadyFilteredAttribute() throws ServletException, IOException { + this.request = mock(MockHttpServletRequest.class); + this.filter.doFilter(this.request, this.response, this.chain); + verify(this.request).setAttribute(ALREADY_FILTERED_ATTRIBUTE_NAME, Boolean.TRUE); + } + + @Test + public void doFilterWhenFilterThenRemoveAlreadyFilteredAttribute() throws ServletException, IOException { + this.request = spy(MockHttpServletRequest.class); + this.filter.doFilter(this.request, this.response, this.chain); + verify(this.request).setAttribute(ALREADY_FILTERED_ATTRIBUTE_NAME, Boolean.TRUE); + assertThat(this.request.getAttribute(ALREADY_FILTERED_ATTRIBUTE_NAME)).isNull(); + } + + @Test + public void doFilterWhenFilterAsyncDispatchTrueAndIsAsyncThenInvoked() throws ServletException, IOException { + this.request.setDispatcherType(DispatcherType.ASYNC); + this.filter.setFilterAsyncDispatch(true); + this.filter.doFilter(this.request, this.response, this.chain); + verify(this.authorizationManager).verify(any(), any()); + } + + @Test + public void doFilterWhenFilterAsyncDispatchFalseAndIsAsyncThenNotInvoked() throws ServletException, IOException { + this.request.setDispatcherType(DispatcherType.ASYNC); + this.filter.setFilterAsyncDispatch(false); + this.filter.doFilter(this.request, this.response, this.chain); + verifyNoInteractions(this.authorizationManager); + } + + @Test + public void filterWhenFilterErrorDispatchDefaultThenFalse() { + Boolean filterErrorDispatch = (Boolean) ReflectionTestUtils.getField(this.filter, "filterErrorDispatch"); + assertThat(filterErrorDispatch).isFalse(); + } + + @Test + public void filterWhenFilterAsyncDispatchDefaultThenFalse() { + Boolean filterAsyncDispatch = (Boolean) ReflectionTestUtils.getField(this.filter, "filterAsyncDispatch"); + assertThat(filterAsyncDispatch).isFalse(); + } + + @Test + public void filterWhenObserveOncePerRequestDefaultThenTrue() { + assertThat(this.filter.isObserveOncePerRequest()).isTrue(); + } + + private void setIsAppliedTrue() { + this.request.setAttribute(ALREADY_FILTERED_ATTRIBUTE_NAME, Boolean.TRUE); + } + }