diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java index 9216d6f677..edccad1bf0 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java @@ -1,5 +1,5 @@ /* - * Copyright 2004-2010 the original author or authors. + * Copyright 2004-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,7 +32,9 @@ import org.springframework.expression.ParseException; import org.springframework.security.access.expression.ExpressionUtils; import org.springframework.security.access.expression.SecurityExpressionHandler; import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.web.FilterInvocation; import org.springframework.security.web.WebAttributes; import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator; @@ -110,7 +112,7 @@ public abstract class AbstractAuthorizeTag { * @throws IOException */ public boolean authorizeUsingAccessExpression() throws IOException { - if (SecurityContextHolder.getContext().getAuthentication() == null) { + if (getContext().getAuthentication() == null) { return false; } SecurityExpressionHandler handler = getExpressionHandler(); @@ -131,7 +133,7 @@ public abstract class AbstractAuthorizeTag { FilterInvocation f = new FilterInvocation(getRequest(), getResponse(), (request, response) -> { throw new UnsupportedOperationException(); }); - return handler.createEvaluationContext(SecurityContextHolder.getContext().getAuthentication(), f); + return handler.createEvaluationContext(getContext().getAuthentication(), f); } /** @@ -142,7 +144,7 @@ public abstract class AbstractAuthorizeTag { */ public boolean authorizeUsingUrlCheck() throws IOException { String contextPath = ((HttpServletRequest) getRequest()).getContextPath(); - Authentication currentUser = SecurityContextHolder.getContext().getAuthentication(); + Authentication currentUser = getContext().getAuthentication(); return getPrivilegeEvaluator().isAllowed(contextPath, getUrl(), getMethod(), currentUser); } @@ -170,6 +172,17 @@ public abstract class AbstractAuthorizeTag { this.method = (method != null) ? method.toUpperCase() : null; } + private SecurityContext getContext() { + ApplicationContext appContext = SecurityWebApplicationContextUtils + .findRequiredWebApplicationContext(getServletContext()); + String[] names = appContext.getBeanNamesForType(SecurityContextHolderStrategy.class); + if (names.length == 1) { + SecurityContextHolderStrategy strategy = appContext.getBean(SecurityContextHolderStrategy.class); + return strategy.getContext(); + } + return SecurityContextHolder.getContext(); + } + @SuppressWarnings({ "unchecked", "rawtypes" }) private SecurityExpressionHandler getExpressionHandler() throws IOException { ApplicationContext appContext = SecurityWebApplicationContextUtils diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java index 32e4fad6be..fb690918cf 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java @@ -32,6 +32,7 @@ import org.springframework.context.ApplicationContext; import org.springframework.security.access.PermissionEvaluator; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.taglibs.TagLibConfig; import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils; @@ -56,6 +57,9 @@ public class AccessControlListTag extends TagSupport { protected static final Log logger = LogFactory.getLog(AccessControlListTag.class); + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private ApplicationContext applicationContext; private Object domainObject; @@ -77,7 +81,7 @@ public class AccessControlListTag extends TagSupport { // Of course they have access to a null object! return evalBody(); } - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); if (authentication == null) { logger.debug("SecurityContextHolder did not return a non-null Authentication object, so skipping tag body"); return skipBody(); @@ -145,6 +149,12 @@ public class AccessControlListTag extends TagSupport { } this.applicationContext = getContext(this.pageContext); this.permissionEvaluator = getBeanOfType(PermissionEvaluator.class); + String[] names = this.applicationContext.getBeanNamesForType(SecurityContextHolderStrategy.class); + if (names.length == 1) { + SecurityContextHolderStrategy strategy = this.applicationContext + .getBean(SecurityContextHolderStrategy.class); + this.securityContextHolderStrategy = strategy; + } } private T getBeanOfType(Class type) throws JspException { diff --git a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AuthenticationTag.java b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AuthenticationTag.java index 3d7cd31da1..adc9a663a8 100644 --- a/taglibs/src/main/java/org/springframework/security/taglibs/authz/AuthenticationTag.java +++ b/taglibs/src/main/java/org/springframework/security/taglibs/authz/AuthenticationTag.java @@ -18,6 +18,7 @@ package org.springframework.security.taglibs.authz; import java.io.IOException; +import jakarta.servlet.ServletContext; import jakarta.servlet.jsp.JspException; import jakarta.servlet.jsp.PageContext; import jakarta.servlet.jsp.tagext.Tag; @@ -25,9 +26,12 @@ import jakarta.servlet.jsp.tagext.TagSupport; import org.springframework.beans.BeanWrapperImpl; import org.springframework.beans.BeansException; +import org.springframework.context.ApplicationContext; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils; import org.springframework.security.web.util.TextEscapeUtils; import org.springframework.web.util.TagUtils; @@ -42,6 +46,9 @@ import org.springframework.web.util.TagUtils; */ public class AuthenticationTag extends TagSupport { + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + private String var; private String property; @@ -76,6 +83,18 @@ public class AuthenticationTag extends TagSupport { this.scopeSpecified = true; } + public void setPageContext(PageContext pageContext) { + super.setPageContext(pageContext); + ServletContext servletContext = pageContext.getServletContext(); + ApplicationContext context = SecurityWebApplicationContextUtils + .findRequiredWebApplicationContext(servletContext); + String[] names = context.getBeanNamesForType(SecurityContextHolderStrategy.class); + if (names.length == 1) { + SecurityContextHolderStrategy strategy = context.getBean(SecurityContextHolderStrategy.class); + this.securityContextHolderStrategy = strategy; + } + } + @Override public int doStartTag() throws JspException { return super.doStartTag(); @@ -86,12 +105,11 @@ public class AuthenticationTag extends TagSupport { Object result = null; // determine the value by... if (this.property != null) { - if ((SecurityContextHolder.getContext() == null) - || !(SecurityContextHolder.getContext() instanceof SecurityContext) - || (SecurityContextHolder.getContext().getAuthentication() == null)) { + SecurityContext context = this.securityContextHolderStrategy.getContext(); + if ((context == null) || !(context instanceof SecurityContext) || (context.getAuthentication() == null)) { return Tag.EVAL_PAGE; } - Authentication auth = SecurityContextHolder.getContext().getAuthentication(); + Authentication auth = context.getAuthentication(); if (auth.getPrincipal() == null) { return Tag.EVAL_PAGE; } diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java index 6880f7c94e..00500f6e60 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,11 +31,15 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockServletContext; import org.springframework.security.access.expression.SecurityExpressionHandler; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.web.WebAttributes; import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator; import org.springframework.security.web.access.expression.DefaultWebSecurityExpressionHandler; import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.GenericWebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -73,6 +77,9 @@ public class AbstractAuthorizeTagTests { @Test public void privilegeEvaluatorFromRequest() throws IOException { + WebApplicationContext wac = mock(WebApplicationContext.class); + this.servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + given(wac.getBeanNamesForType(SecurityContextHolderStrategy.class)).willReturn(new String[0]); String uri = "/something"; WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class); this.tag.setUrl(uri); @@ -81,6 +88,24 @@ public class AbstractAuthorizeTagTests { verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any()); } + @Test + public void privilegeEvaluatorFromRequestUsesSecurityContextHolderStrategy() throws IOException { + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()).willReturn(new SecurityContextImpl( + new TestingAuthenticationToken("user", "password", AuthorityUtils.NO_AUTHORITIES))); + GenericWebApplicationContext wac = new GenericWebApplicationContext(); + wac.registerBean(SecurityContextHolderStrategy.class, () -> strategy); + wac.refresh(); + this.servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac); + String uri = "/something"; + WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class); + this.tag.setUrl(uri); + this.request.setAttribute(WebAttributes.WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE, expected); + this.tag.authorizeUsingUrlCheck(); + verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any()); + verify(strategy).getContext(); + } + @Test public void privilegeEvaluatorFromChildContext() throws IOException { String uri = "/something"; @@ -89,6 +114,7 @@ public class AbstractAuthorizeTagTests { WebApplicationContext wac = mock(WebApplicationContext.class); given(wac.getBeansOfType(WebInvocationPrivilegeEvaluator.class)) .willReturn(Collections.singletonMap("wipe", expected)); + given(wac.getBeanNamesForType(SecurityContextHolderStrategy.class)).willReturn(new String[0]); this.servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); this.tag.authorizeUsingUrlCheck(); verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any()); @@ -103,6 +129,7 @@ public class AbstractAuthorizeTagTests { WebApplicationContext wac = mock(WebApplicationContext.class); given(wac.getBeansOfType(SecurityExpressionHandler.class)) .willReturn(Collections.singletonMap("wipe", expected)); + given(wac.getBeanNamesForType(SecurityContextHolderStrategy.class)).willReturn(new String[0]); this.servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac); assertThat(this.tag.authorize()).isTrue(); } diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java index aa3987748b..c4333433c5 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,10 @@ import org.springframework.security.access.PermissionEvaluator; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.GenericWebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.BDDMockito.given; @@ -67,6 +70,7 @@ public class AccessControlListTagTests { Map beanMap = new HashMap(); beanMap.put("pe", this.pe); given(ctx.getBeansOfType(PermissionEvaluator.class)).willReturn(beanMap); + given(ctx.getBeanNamesForType(SecurityContextHolderStrategy.class)).willReturn(new String[0]); MockServletContext servletCtx = new MockServletContext(); servletCtx.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ctx); this.pageContext = new MockPageContext(servletCtx, new MockHttpServletRequest(), new MockHttpServletResponse()); @@ -91,6 +95,30 @@ public class AccessControlListTagTests { assertThat((Boolean) this.pageContext.getAttribute("allowed")).isTrue(); } + @Test + public void securityContextHolderStrategyIsUsedIfConfigured() throws Exception { + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()).willReturn(new SecurityContextImpl(this.bob)); + GenericWebApplicationContext context = new GenericWebApplicationContext(); + context.registerBean(SecurityContextHolderStrategy.class, () -> strategy); + context.registerBean(PermissionEvaluator.class, () -> this.pe); + context.refresh(); + MockServletContext servletCtx = new MockServletContext(); + servletCtx.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, context); + this.pageContext = new MockPageContext(servletCtx, new MockHttpServletRequest(), new MockHttpServletResponse()); + this.tag.setPageContext(this.pageContext); + Object domainObject = new Object(); + given(this.pe.hasPermission(this.bob, domainObject, "READ")).willReturn(true); + this.tag.setDomainObject(domainObject); + this.tag.setHasPermission("READ"); + this.tag.setVar("allowed"); + assertThat(this.tag.getDomainObject()).isSameAs(domainObject); + assertThat(this.tag.getHasPermission()).isEqualTo("READ"); + assertThat(this.tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE); + assertThat((Boolean) this.pageContext.getAttribute("allowed")).isTrue(); + verify(strategy).getContext(); + } + @Test public void childContext() throws Exception { ServletContext servletContext = this.pageContext.getServletContext(); diff --git a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthenticationTagTests.java b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthenticationTagTests.java index 9292d1729d..82dc302067 100644 --- a/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthenticationTagTests.java +++ b/taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthenticationTagTests.java @@ -21,14 +21,23 @@ import jakarta.servlet.jsp.tagext.Tag; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import org.springframework.mock.web.MockPageContext; +import org.springframework.mock.web.MockServletContext; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.userdetails.User; +import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.GenericWebApplicationContext; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests {@link AuthenticationTag}. @@ -130,6 +139,24 @@ public class AuthenticationTagTests { assertThat(this.authenticationTag.getLastMessage()).isEqualTo("<>& "); } + @Test + public void setSecurityContextHolderStrategyThenUses() throws Exception { + SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class); + given(strategy.getContext()).willReturn(new SecurityContextImpl( + new TestingAuthenticationToken("rodAsString", "koala", AuthorityUtils.NO_AUTHORITIES))); + MockServletContext servletContext = new MockServletContext(); + GenericWebApplicationContext applicationContext = new GenericWebApplicationContext(); + applicationContext.registerBean(SecurityContextHolderStrategy.class, () -> strategy); + applicationContext.refresh(); + servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, applicationContext); + this.authenticationTag.setPageContext(new MockPageContext(servletContext)); + this.authenticationTag.setProperty("principal"); + assertThat(this.authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY); + assertThat(this.authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE); + assertThat(this.authenticationTag.getLastMessage()).isEqualTo("rodAsString"); + verify(strategy).getContext(); + } + private class MyAuthenticationTag extends AuthenticationTag { String lastMessage = null;