diff --git a/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java b/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java index e05fc77bea..0ea3f858e8 100644 --- a/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java +++ b/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java @@ -108,7 +108,7 @@ public final class CurrentSecurityContextArgumentResolver StandardEvaluationContext context = new StandardEvaluationContext(); context.setRootObject(securityContext); context.setVariable("this", securityContext); - + context.setBeanResolver(this.beanResolver); Expression expression = this.parser.parseExpression(expressionToParse); securityContextResult = expression.getValue(context); } diff --git a/web/src/test/java/org/springframework/security/web/method/annotation/AuthenticationPrincipalArgumentResolverTests.java b/web/src/test/java/org/springframework/security/web/method/annotation/AuthenticationPrincipalArgumentResolverTests.java index d7be109aec..36da3d5e61 100644 --- a/web/src/test/java/org/springframework/security/web/method/annotation/AuthenticationPrincipalArgumentResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/method/annotation/AuthenticationPrincipalArgumentResolverTests.java @@ -22,11 +22,14 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.lang.reflect.Method; +import java.util.function.Function; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.springframework.core.MethodParameter; +import org.springframework.expression.AccessException; +import org.springframework.expression.BeanResolver; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.annotation.AuthenticationPrincipal; import org.springframework.security.core.authority.AuthorityUtils; @@ -40,12 +43,21 @@ import org.springframework.util.ReflectionUtils; * */ public class AuthenticationPrincipalArgumentResolverTests { + + private final BeanResolver beanResolver = ((context, beanName) -> { + if (!"test".equals(beanName)) { + throw new AccessException("Could not resolve bean reference against BeanFactory"); + } + return (Function) (principal) -> principal.property; + }); + private Object expectedPrincipal; private AuthenticationPrincipalArgumentResolver resolver; @Before public void setup() { resolver = new AuthenticationPrincipalArgumentResolver(); + resolver.setBeanResolver(this.beanResolver); } @After @@ -128,6 +140,14 @@ public class AuthenticationPrincipalArgumentResolverTests { .isEqualTo(this.expectedPrincipal); } + @Test + public void resolveArgumentSpelBean() throws Exception { + CustomUserPrincipal principal = new CustomUserPrincipal(); + setAuthenticationPrincipal(principal); + this.expectedPrincipal = principal.property; + assertThat(this.resolver.resolveArgument(showUserSpelBean(), null, null, null)).isEqualTo(this.expectedPrincipal); + } + @Test public void resolveArgumentSpelCopy() throws Exception { CopyUserPrincipal principal = new CopyUserPrincipal("property"); @@ -198,6 +218,10 @@ public class AuthenticationPrincipalArgumentResolverTests { return getMethodParameter("showUserSpel", String.class); } + private MethodParameter showUserSpelBean() { + return getMethodParameter("showUserSpelBean", String.class); + } + private MethodParameter showUserSpelCopy() { return getMethodParameter("showUserSpelCopy", CopyUserPrincipal.class); } @@ -255,6 +279,10 @@ public class AuthenticationPrincipalArgumentResolverTests { @AuthenticationPrincipal(expression = "property") String user) { } + public void showUserSpelBean(@AuthenticationPrincipal( + expression = "@test.apply(#this)") String user) { + } + public void showUserSpelCopy( @AuthenticationPrincipal(expression = "new org.springframework.security.web.method.annotation.AuthenticationPrincipalArgumentResolverTests$CopyUserPrincipal(#this)") CopyUserPrincipal user) { } diff --git a/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java b/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java index 84dc2b2ce1..f7cc7b20d8 100644 --- a/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java @@ -20,12 +20,15 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.lang.reflect.Method; +import java.util.function.Function; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.springframework.core.MethodParameter; +import org.springframework.expression.AccessException; +import org.springframework.expression.BeanResolver; import org.springframework.expression.spel.SpelEvaluationException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; @@ -45,11 +48,20 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; * */ public class CurrentSecurityContextArgumentResolverTests { + + private final BeanResolver beanResolver = ((context, beanName) -> { + if (!"test".equals(beanName)) { + throw new AccessException("Could not resolve bean reference against BeanFactory"); + } + return (Function) SecurityContext::getAuthentication; + }); + private CurrentSecurityContextArgumentResolver resolver; @Before public void setup() { this.resolver = new CurrentSecurityContextArgumentResolver(); + this.resolver.setBeanResolver(this.beanResolver); } @After @@ -104,6 +116,15 @@ public class CurrentSecurityContextArgumentResolverTests { assertThat(auth1.getPrincipal()).isEqualTo(principal); } + @Test + public void resolveArgumentWithAuthenticationWithBean() { + String principal = "john"; + setAuthenticationPrincipal(principal); + Authentication auth1 = (Authentication) this.resolver + .resolveArgument(showSecurityContextAuthenticationWithBean(), null, null, null); + assertThat(auth1.getPrincipal()).isEqualTo(principal); + } + @Test public void resolveArgumentWithNullAuthentication() { SecurityContext context = SecurityContextHolder.getContext(); @@ -217,6 +238,10 @@ public class CurrentSecurityContextArgumentResolverTests { return getMethodParameter("showSecurityContextAuthenticationAnnotation", Authentication.class); } + public MethodParameter showSecurityContextAuthenticationWithBean() { + return getMethodParameter("showSecurityContextAuthenticationWithBean", Authentication.class); + } + private MethodParameter showSecurityContextAuthenticationWithOptionalPrincipal() { return getMethodParameter("showSecurityContextAuthenticationWithOptionalPrincipal", Object.class); } @@ -279,6 +304,10 @@ public class CurrentSecurityContextArgumentResolverTests { public void showSecurityContextAuthenticationAnnotation(@CurrentSecurityContext(expression = "authentication") Authentication authentication) { } + public void showSecurityContextAuthenticationWithBean( + @CurrentSecurityContext(expression = "@test.apply(#this)") Authentication authentication) { + } + public void showSecurityContextAuthenticationWithOptionalPrincipal(@CurrentSecurityContext(expression = "authentication?.principal") Object principal) { }