diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java index 6ac7d6546c..9ab9795d0f 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java @@ -42,6 +42,7 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.FilterInvocation; import org.springframework.security.web.WebAttributes; import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator; +import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils; import org.springframework.util.StringUtils; import org.springframework.web.context.support.WebApplicationContextUtils; @@ -312,8 +313,7 @@ public abstract class AbstractAuthorizeTag { @SuppressWarnings({ "unchecked", "rawtypes" }) private SecurityExpressionHandler getExpressionHandler() throws IOException { - ApplicationContext appContext = WebApplicationContextUtils - .getRequiredWebApplicationContext(getServletContext()); + ApplicationContext appContext = SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(getServletContext()); Map handlers = appContext .getBeansOfType(SecurityExpressionHandler.class); @@ -335,8 +335,9 @@ public abstract class AbstractAuthorizeTag { return privEvaluatorFromRequest; } - ApplicationContext ctx = WebApplicationContextUtils.getRequiredWebApplicationContext(getServletContext()); - Map wipes = ctx.getBeansOfType(WebInvocationPrivilegeEvaluator.class); + ApplicationContext ctx = SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(getServletContext()); + Map wipes = ctx + .getBeansOfType(WebInvocationPrivilegeEvaluator.class); if (wipes.size() == 0) { throw new IOException( diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java index c16aa1535c..49a9c87301 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java @@ -21,6 +21,7 @@ import org.springframework.security.access.PermissionEvaluator; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.taglibs.TagLibConfig; +import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils; import org.springframework.web.context.support.WebApplicationContextUtils; import javax.servlet.ServletContext; @@ -136,7 +137,7 @@ public class AccessControlListTag extends TagSupport { protected ApplicationContext getContext(PageContext pageContext) { ServletContext servletContext = pageContext.getServletContext(); - return WebApplicationContextUtils.getRequiredWebApplicationContext(servletContext); + return SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(servletContext); } public Object getDomainObject() { diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java index 83770533ff..69ec40e7c9 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java @@ -12,12 +12,16 @@ */ package org.springframework.security.taglibs.authz; +import static org.fest.assertions.Assertions.*; + import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Collections; import javax.servlet.ServletContext; import javax.servlet.ServletRequest; @@ -29,10 +33,14 @@ import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockServletContext; +import org.springframework.security.access.expression.SecurityExpressionHandler; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.web.WebAttributes; import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator; +import org.springframework.security.web.access.expression.DefaultWebSecurityExpressionHandler; +import org.springframework.web.context.WebApplicationContext; /** * @@ -63,13 +71,41 @@ public class AbstractAuthorizeTagTests { String uri = "/something"; WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class); tag.setUrl(uri); - request.setAttribute(WebAttributes.WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE, expected); + request.setAttribute(WebAttributes.WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE, + expected); tag.authorizeUsingUrlCheck(); verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any(Authentication.class)); } + @Test + public void privilegeEvaluatorFromChildContext() throws IOException { + String uri = "/something"; + WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class); + tag.setUrl(uri); + WebApplicationContext wac = mock(WebApplicationContext.class); + when(wac.getBeansOfType(WebInvocationPrivilegeEvaluator.class)).thenReturn(Collections.singletonMap("wipe", expected)); + servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); + + tag.authorizeUsingUrlCheck(); + + verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any(Authentication.class)); + } + + @Test + @SuppressWarnings("rawtypes") + public void expressionFromChildContext() throws IOException { + SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass","USER")); + DefaultWebSecurityExpressionHandler expected = new DefaultWebSecurityExpressionHandler(); + tag.setAccess("permitAll"); + WebApplicationContext wac = mock(WebApplicationContext.class); + when(wac.getBeansOfType(SecurityExpressionHandler.class)).thenReturn(Collections.singletonMap("wipe", expected)); + servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); + + assertThat(tag.authorize()).isTrue(); + } + private class AuthzTag extends AbstractAuthorizeTag { @Override diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java index a0168780d8..5b3da2676b 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java @@ -26,6 +26,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.web.context.WebApplicationContext; +import javax.servlet.ServletContext; import javax.servlet.jsp.tagext.Tag; import java.util.*; @@ -40,7 +41,7 @@ public class AccessControlListTagTests { AccessControlListTag tag; PermissionEvaluator pe; MockPageContext pageContext; - Authentication bob = new TestingAuthenticationToken("bob","bobspass","A"); + Authentication bob = new TestingAuthenticationToken("bob", "bobspass", "A"); @Before @SuppressWarnings("rawtypes") @@ -56,8 +57,10 @@ public class AccessControlListTagTests { when(ctx.getBeansOfType(PermissionEvaluator.class)).thenReturn(beanMap); MockServletContext servletCtx = new MockServletContext(); - servletCtx.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ctx); - pageContext = new MockPageContext(servletCtx, new MockHttpServletRequest(), new MockHttpServletResponse()); + servletCtx.setAttribute( + WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ctx); + pageContext = new MockPageContext(servletCtx, new MockHttpServletRequest(), + new MockHttpServletResponse()); tag.setPageContext(pageContext); } @@ -78,7 +81,28 @@ public class AccessControlListTagTests { assertEquals("READ", tag.getHasPermission()); assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag()); - assertTrue((Boolean)pageContext.getAttribute("allowed")); + assertTrue((Boolean) pageContext.getAttribute("allowed")); + } + + @Test + public void childContext() throws Exception { + ServletContext servletContext = pageContext.getServletContext(); + WebApplicationContext wac = (WebApplicationContext) servletContext + .getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); + servletContext.removeAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE); + servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); + + Object domainObject = new Object(); + when(pe.hasPermission(bob, domainObject, "READ")).thenReturn(true); + + tag.setDomainObject(domainObject); + tag.setHasPermission("READ"); + tag.setVar("allowed"); + assertSame(domainObject, tag.getDomainObject()); + assertEquals("READ", tag.getHasPermission()); + + assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag()); + assertTrue((Boolean) pageContext.getAttribute("allowed")); } // SEC-2022 @@ -95,7 +119,7 @@ public class AccessControlListTagTests { assertEquals("READ,WRITE", tag.getHasPermission()); assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag()); - assertTrue((Boolean)pageContext.getAttribute("allowed")); + assertTrue((Boolean) pageContext.getAttribute("allowed")); verify(pe).hasPermission(bob, domainObject, "READ"); verify(pe).hasPermission(bob, domainObject, "WRITE"); verifyNoMoreInteractions(pe); @@ -115,7 +139,7 @@ public class AccessControlListTagTests { assertEquals("1,2", tag.getHasPermission()); assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag()); - assertTrue((Boolean)pageContext.getAttribute("allowed")); + assertTrue((Boolean) pageContext.getAttribute("allowed")); verify(pe).hasPermission(bob, domainObject, 1); verify(pe).hasPermission(bob, domainObject, 2); verifyNoMoreInteractions(pe); @@ -134,7 +158,7 @@ public class AccessControlListTagTests { assertEquals("1,WRITE", tag.getHasPermission()); assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag()); - assertTrue((Boolean)pageContext.getAttribute("allowed")); + assertTrue((Boolean) pageContext.getAttribute("allowed")); verify(pe).hasPermission(bob, domainObject, 1); verify(pe).hasPermission(bob, domainObject, "WRITE"); verifyNoMoreInteractions(pe); @@ -150,6 +174,6 @@ public class AccessControlListTagTests { tag.setVar("allowed"); assertEquals(Tag.SKIP_BODY, tag.doStartTag()); - assertFalse((Boolean)pageContext.getAttribute("allowed")); + assertFalse((Boolean) pageContext.getAttribute("allowed")); } } diff --git a/web/src/main/java/org/springframework/security/web/context/support/SecurityWebApplicationContextUtils.java b/web/src/main/java/org/springframework/security/web/context/support/SecurityWebApplicationContextUtils.java new file mode 100644 index 0000000000..6b01258805 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/context/support/SecurityWebApplicationContextUtils.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.context.support; + +import java.util.Enumeration; + +import javax.servlet.ServletContext; + +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.WebApplicationContextUtils; + +/** + * Spring Security extension to Spring's {@link WebApplicationContextUtils}. + * + * @author Rob Winch + */ +public abstract class SecurityWebApplicationContextUtils extends WebApplicationContextUtils { + + /** + * Find a unique {@code WebApplicationContext} for this web app: either the + * root web app context (preferred) or a unique {@code WebApplicationContext} + * among the registered {@code ServletContext} attributes (typically coming + * from a single {@code DispatcherServlet} in the current web application). + *

Note that {@code DispatcherServlet}'s exposure of its context can be + * controlled through its {@code publishContext} property, which is {@code true} + * by default but can be selectively switched to only publish a single context + * despite multiple {@code DispatcherServlet} registrations in the web app. + * @param sc ServletContext to find the web application context for + * @return the desired WebApplicationContext for this web app + * @see #getWebApplicationContext(ServletContext) + * @see ServletContext#getAttributeNames() + * @throws IllegalStateException if no WebApplicationContext can be found + */ + public static WebApplicationContext findRequiredWebApplicationContext(ServletContext servletContext) { + WebApplicationContext wac = findWebApplicationContext(servletContext); + if (wac == null) { + throw new IllegalStateException("No WebApplicationContext found: no ContextLoaderListener registered?"); + } + return wac; + } + + /** + * Find a unique {@code WebApplicationContext} for this web app: either the + * root web app context (preferred) or a unique {@code WebApplicationContext} + * among the registered {@code ServletContext} attributes (typically coming + * from a single {@code DispatcherServlet} in the current web application). + *

Note that {@code DispatcherServlet}'s exposure of its context can be + * controlled through its {@code publishContext} property, which is {@code true} + * by default but can be selectively switched to only publish a single context + * despite multiple {@code DispatcherServlet} registrations in the web app. + * @param sc ServletContext to find the web application context for + * @return the desired WebApplicationContext for this web app, or {@code null} if none + * @see #getWebApplicationContext(ServletContext) + * @see ServletContext#getAttributeNames() + */ + private static WebApplicationContext findWebApplicationContext(ServletContext sc) { + WebApplicationContext wac = getWebApplicationContext(sc); + if (wac == null) { + Enumeration attrNames = sc.getAttributeNames(); + while (attrNames.hasMoreElements()) { + String attrName = attrNames.nextElement(); + Object attrValue = sc.getAttribute(attrName); + if (attrValue instanceof WebApplicationContext) { + if (wac != null) { + throw new IllegalStateException("No unique WebApplicationContext found: more than one " + + "DispatcherServlet registered with publishContext=true?"); + } + wac = (WebApplicationContext) attrValue; + } + } + } + return wac; + } + +} diff --git a/web/src/main/java/org/springframework/security/web/session/HttpSessionEventPublisher.java b/web/src/main/java/org/springframework/security/web/session/HttpSessionEventPublisher.java index fddd54f808..d73286e26d 100644 --- a/web/src/main/java/org/springframework/security/web/session/HttpSessionEventPublisher.java +++ b/web/src/main/java/org/springframework/security/web/session/HttpSessionEventPublisher.java @@ -19,6 +19,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.context.ApplicationContext; +import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils; import org.springframework.web.context.support.WebApplicationContextUtils; @@ -49,7 +50,7 @@ public class HttpSessionEventPublisher implements HttpSessionListener { //~ Methods ======================================================================================================== ApplicationContext getContext(ServletContext servletContext) { - return WebApplicationContextUtils.getRequiredWebApplicationContext(servletContext); + return SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(servletContext); } /** diff --git a/web/src/test/java/org/springframework/security/web/session/HttpSessionEventPublisherTests.java b/web/src/test/java/org/springframework/security/web/session/HttpSessionEventPublisherTests.java index 2509473096..cf2f5b0393 100644 --- a/web/src/test/java/org/springframework/security/web/session/HttpSessionEventPublisherTests.java +++ b/web/src/test/java/org/springframework/security/web/session/HttpSessionEventPublisherTests.java @@ -44,7 +44,45 @@ public class HttpSessionEventPublisherTests { StaticWebApplicationContext context = new StaticWebApplicationContext(); MockServletContext servletContext = new MockServletContext(); - servletContext.setAttribute(StaticWebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, context); + servletContext.setAttribute( + StaticWebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, + context); + + context.setServletContext(servletContext); + context.registerSingleton("listener", MockApplicationListener.class, null); + context.refresh(); + + MockHttpSession session = new MockHttpSession(servletContext); + MockApplicationListener listener = (MockApplicationListener) context + .getBean("listener"); + + HttpSessionEvent event = new HttpSessionEvent(session); + + publisher.sessionCreated(event); + + assertNotNull(listener.getCreatedEvent()); + assertNull(listener.getDestroyedEvent()); + assertEquals(session, listener.getCreatedEvent().getSession()); + + listener.setCreatedEvent(null); + listener.setDestroyedEvent(null); + + publisher.sessionDestroyed(event); + assertNotNull(listener.getDestroyedEvent()); + assertNull(listener.getCreatedEvent()); + assertEquals(session, listener.getDestroyedEvent().getSession()); + } + + @Test + public void publishedEventIsReceivedbyListenerChildContext() { + HttpSessionEventPublisher publisher = new HttpSessionEventPublisher(); + + StaticWebApplicationContext context = new StaticWebApplicationContext(); + + MockServletContext servletContext = new MockServletContext(); + servletContext.setAttribute( + "org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", + context); context.setServletContext(servletContext); context.registerSingleton("listener", MockApplicationListener.class, null);