diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java index 5d296bb29e..707e8775f7 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java @@ -15,14 +15,13 @@ */ package org.springframework.security.config.annotation.web.configuration; -import java.util.List; - import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.context.annotation.Bean; import org.springframework.context.expression.BeanFactoryResolver; import org.springframework.expression.BeanResolver; +import org.springframework.security.web.bind.support.CurrentSecurityContextArgumentResolver; import org.springframework.security.web.method.annotation.AuthenticationPrincipalArgumentResolver; import org.springframework.security.web.method.annotation.CsrfTokenArgumentResolver; import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor; @@ -31,6 +30,8 @@ import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; import org.springframework.web.servlet.support.RequestDataValueProcessor; +import java.util.List; + /** * Used to add a {@link RequestDataValueProcessor} for Spring MVC and Spring Security CSRF * integration. This configuration is added whenever {@link EnableWebMvc} is added by @@ -52,6 +53,10 @@ class WebMvcSecurityConfiguration implements WebMvcConfigurer, ApplicationContex argumentResolvers.add(authenticationPrincipalResolver); argumentResolvers .add(new org.springframework.security.web.bind.support.AuthenticationPrincipalArgumentResolver()); + + CurrentSecurityContextArgumentResolver currentSecurityContextArgumentResolver = new CurrentSecurityContextArgumentResolver(); + currentSecurityContextArgumentResolver.setBeanResolver(beanResolver); + argumentResolvers.add(currentSecurityContextArgumentResolver); argumentResolvers.add(new CsrfTokenArgumentResolver()); } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java index 1acdcbc3dd..ae7ed2ede0 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java @@ -33,6 +33,7 @@ import org.springframework.security.core.userdetails.ReactiveUserDetailsPassword import org.springframework.security.core.userdetails.ReactiveUserDetailsService; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.web.reactive.result.method.annotation.AuthenticationPrincipalArgumentResolver; +import org.springframework.security.web.reactive.result.method.annotation.CurrentSecurityContextArgumentResolver; import org.springframework.web.reactive.config.WebFluxConfigurer; import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer; @@ -78,6 +79,17 @@ class ServerHttpSecurityConfiguration implements WebFluxConfigurer { return resolver; } + @Bean + public CurrentSecurityContextArgumentResolver reactiveCurrentSecurityContextArgumentResolver() { + CurrentSecurityContextArgumentResolver resolver = new CurrentSecurityContextArgumentResolver( + this.adapterRegistry); + if (this.beanFactory != null) { + resolver.setBeanResolver(new BeanFactoryResolver(this.beanFactory)); + } + return resolver; + } + + @Bean(HTTPSECURITY_BEAN_NAME) @Scope("prototype") public ServerHttpSecurity httpSecurity() { diff --git a/core/src/main/java/org/springframework/security/core/annotation/CurrentSecurityContext.java b/core/src/main/java/org/springframework/security/core/annotation/CurrentSecurityContext.java new file mode 100644 index 0000000000..fbbe2c95ed --- /dev/null +++ b/core/src/main/java/org/springframework/security/core/annotation/CurrentSecurityContext.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.core.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation that is used to resolve {@link SecurityContext#getAuthentication()} to a method + * argument. + * + * @author Dan Zheng + * @since 5.2.x + * + * See: CurrentSecurityContextArgumentResolver + */ +@Target({ ElementType.PARAMETER }) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface CurrentSecurityContext { + /** + * True if a {@link ClassCastException} should be thrown when the current + * {@link SecurityContext} is the incorrect type. Default is false. + * + * @return + */ + boolean errorOnInvalidType() default false; + /** + * If specified will use the provided SpEL expression to resolve the security context. This + * is convenient if users need to transform the result. + * + *

+ * For example, perhaps the user wants to resolve a CustomUser object that is final + * and is leveraging a UserDetailsService. This can be handled by returning an object + * that looks like: + *

+ * + *
+	 * public class CustomUserUserDetails extends User {
+	 *     // ...
+	 *     public CustomUser getCustomUser() {
+	 *         return customUser;
+	 *     }
+	 * }
+	 * 
+ * + * Then the user can specify an annotation that looks like: + * + *
+	 * @CurrentSecurityContext(expression = "authentication")
+	 * 
+ * + * @return the expression to use. + */ + String expression() default ""; +} diff --git a/web/src/main/java/org/springframework/security/web/bind/support/CurrentSecurityContextArgumentResolver.java b/web/src/main/java/org/springframework/security/web/bind/support/CurrentSecurityContextArgumentResolver.java new file mode 100644 index 0000000000..634801f6f4 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/bind/support/CurrentSecurityContextArgumentResolver.java @@ -0,0 +1,172 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.bind.support; + +import org.springframework.core.MethodParameter; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.expression.BeanResolver; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.StandardEvaluationContext; +import org.springframework.security.core.annotation.CurrentSecurityContext; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.stereotype.Controller; +import org.springframework.util.StringUtils; +import org.springframework.web.bind.support.WebDataBinderFactory; +import org.springframework.web.context.request.NativeWebRequest; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.method.support.ModelAndViewContainer; + +import java.lang.annotation.Annotation; + +/** + * Allows resolving the {@link SecurityContext} using the + * {@link CurrentSecurityContext} annotation. For example, the following + * {@link Controller}: + * + *
+ * @Controller
+ * public class MyController {
+ *     @RequestMapping("/im")
+ *     public void security(@CurrentSecurityContext SecurityContext context) {
+ *         // do something with context
+ *     }
+ * }
+ * 
+ * + * it can also support the spring SPEL expression to get the value from SecurityContext + *
+ * @Controller
+ * public class MyController {
+ *     @RequestMapping("/im")
+ *     public void security(@CurrentSecurityContext(expression="authentication") Authentication authentication) {
+ *         // do something with context
+ *     }
+ * }
+ * 
+ * + *

+ * Will resolve the SecurityContext argument using {@link SecurityContextHolder#getContext()} from + * the {@link SecurityContextHolder}. If the {@link SecurityContext} is null, it will return null. + * If the types do not match, null will be returned unless + * {@link CurrentSecurityContext#errorOnInvalidType()} is true in which case a + * {@link ClassCastException} will be thrown. + *

+ * + * @author Dan Zheng + * @since 5.2.x + */ +public final class CurrentSecurityContextArgumentResolver + implements HandlerMethodArgumentResolver { + + private ExpressionParser parser = new SpelExpressionParser(); + + private BeanResolver beanResolver; + /** + * check if this argument resolve can support the parameter. + * @param parameter the method parameter. + * @return true = it can support parameter. + * + * @see + * org.springframework.web.method.support.HandlerMethodArgumentResolver# + * supportsParameter(org.springframework.core.MethodParameter) + */ + public boolean supportsParameter(MethodParameter parameter) { + return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + } + + /** + * resolve the argument to inject into the controller parameter. + * @param parameter the method parameter. + * @param mavContainer the model and view container. + * @param webRequest the web request. + * @param binderFactory the web data binder factory. + * + * @see org.springframework.web.method.support.HandlerMethodArgumentResolver# + * resolveArgument (org.springframework.core.MethodParameter, + * org.springframework.web.method.support.ModelAndViewContainer, + * org.springframework.web.context.request.NativeWebRequest, + * org.springframework.web.bind.support.WebDataBinderFactory) + */ + public Object resolveArgument(MethodParameter parameter, + ModelAndViewContainer mavContainer, NativeWebRequest webRequest, + WebDataBinderFactory binderFactory) throws Exception { + SecurityContext securityContext = SecurityContextHolder.getContext(); + if (securityContext == null) { + return null; + } + Object securityContextResult = securityContext; + + CurrentSecurityContext securityContextAnnotation = findMethodAnnotation( + CurrentSecurityContext.class, parameter); + + String expressionToParse = securityContextAnnotation.expression(); + if (StringUtils.hasLength(expressionToParse)) { + StandardEvaluationContext context = new StandardEvaluationContext(); + context.setRootObject(securityContext); + context.setVariable("this", securityContext); + + Expression expression = this.parser.parseExpression(expressionToParse); + securityContextResult = expression.getValue(context); + } + + if (securityContextResult != null + && !parameter.getParameterType().isAssignableFrom(securityContextResult.getClass())) { + if (securityContextAnnotation.errorOnInvalidType()) { + throw new ClassCastException(securityContextResult + " is not assignable to " + + parameter.getParameterType()); + } + else { + return null; + } + } + return securityContextResult; + } + /** + * Sets the {@link BeanResolver} to be used on the expressions + * @param beanResolver the {@link BeanResolver} to use + */ + public void setBeanResolver(BeanResolver beanResolver) { + this.beanResolver = beanResolver; + } + + /** + * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. + * + * @param annotationClass the class of the {@link Annotation} to find on the + * {@link MethodParameter} + * @param parameter the {@link MethodParameter} to search for an {@link Annotation} + * @return the {@link Annotation} that was found or null. + */ + private T findMethodAnnotation(Class annotationClass, + MethodParameter parameter) { + T annotation = parameter.getParameterAnnotation(annotationClass); + if (annotation != null) { + return annotation; + } + Annotation[] annotationsToSearch = parameter.getParameterAnnotations(); + for (Annotation toSearch : annotationsToSearch) { + annotation = AnnotationUtils.findAnnotation(toSearch.annotationType(), + annotationClass); + if (annotation != null) { + return annotation; + } + } + return null; + } +} diff --git a/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java b/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java new file mode 100644 index 0000000000..20a75e32bd --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java @@ -0,0 +1,184 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.reactive.result.method.annotation; + +import org.reactivestreams.Publisher; +import org.springframework.core.MethodParameter; +import org.springframework.core.ReactiveAdapter; +import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.expression.BeanResolver; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.StandardEvaluationContext; +import org.springframework.security.core.annotation.CurrentSecurityContext; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.util.StringUtils; +import org.springframework.web.reactive.BindingContext; +import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolverSupport; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import java.lang.annotation.Annotation; + +/** + * Resolves the SecurityContext + * @author Dan Zheng + * @since 5.2.x + */ +public class CurrentSecurityContextArgumentResolver extends HandlerMethodArgumentResolverSupport { + + private ExpressionParser parser = new SpelExpressionParser(); + + private BeanResolver beanResolver; + + public CurrentSecurityContextArgumentResolver(ReactiveAdapterRegistry adapterRegistry) { + super(adapterRegistry); + } + + /** + * Sets the {@link BeanResolver} to be used on the expressions + * @param beanResolver the {@link BeanResolver} to use + */ + public void setBeanResolver(BeanResolver beanResolver) { + this.beanResolver = beanResolver; + } + + /** + * check if this argument resolve can support the parameter. + * @param parameter the method parameter. + * @return true = it can support parameter. + * + * @see + * org.springframework.web.reactive.result.method.HandlerMethodArgumentResolver# + * supportsParameter(org.springframework.core.MethodParameter) + */ + @Override + public boolean supportsParameter(MethodParameter parameter) { + return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + } + + /** + * resolve the argument to inject into the controller parameter. + * @param parameter the method parameter. + * @param bindingContext the binding context. + * @param exchange the server web exchange. + * @return the reactive mono object result. + */ + @Override + public Mono resolveArgument(MethodParameter parameter, BindingContext bindingContext, + ServerWebExchange exchange) { + ReactiveAdapter adapter = getAdapterRegistry().getAdapter(parameter.getParameterType()); + Mono reactiveSecurityContext = ReactiveSecurityContextHolder.getContext(); + if (reactiveSecurityContext == null) { + return null; + } + return reactiveSecurityContext.flatMap( a -> { + Object p = resolveSecurityContext(parameter, a); + Mono o = Mono.justOrEmpty(p); + return adapter == null ? o : Mono.just(adapter.fromPublisher(o)); + }); + + } + + /** + * resolve the expression from {@link CurrentSecurityContext} annotation to get the value. + * @param parameter the method parameter. + * @param securityContext the security context. + * @return the resolved object from expression. + */ + private Object resolveSecurityContext(MethodParameter parameter, SecurityContext securityContext) { + CurrentSecurityContext securityContextAnnotation = findMethodAnnotation( + CurrentSecurityContext.class, parameter); + + Object securityContextResult = securityContext; + + String expressionToParse = securityContextAnnotation.expression(); + if (StringUtils.hasLength(expressionToParse)) { + StandardEvaluationContext context = new StandardEvaluationContext(); + context.setRootObject(securityContext); + context.setVariable("this", securityContext); + context.setBeanResolver(beanResolver); + + Expression expression = this.parser.parseExpression(expressionToParse); + securityContextResult = expression.getValue(context); + } + + if (isInvalidType(parameter, securityContextResult)) { + if (securityContextAnnotation.errorOnInvalidType()) { + throw new ClassCastException(securityContextResult + " is not assignable to " + + parameter.getParameterType()); + } + else { + return null; + } + } + + return securityContextResult; + } + + /** + * check if the retrieved value match with the parameter type. + * @param parameter the method parameter. + * @param reactiveSecurityContext the security context. + * @return true = is not invalid type. + */ + private boolean isInvalidType(MethodParameter parameter, Object reactiveSecurityContext) { + if (reactiveSecurityContext == null) { + return false; + } + Class typeToCheck = parameter.getParameterType(); + boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType()); + if (isParameterPublisher) { + ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); + Class genericType = resolvableType.resolveGeneric(0); + if (genericType == null) { + return false; + } + typeToCheck = genericType; + } + return !typeToCheck.isAssignableFrom(reactiveSecurityContext.getClass()); + } + + /** + * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. + * + * @param annotationClass the class of the {@link Annotation} to find on the + * {@link MethodParameter} + * @param parameter the {@link MethodParameter} to search for an {@link Annotation} + * @return the {@link Annotation} that was found or null. + */ + private T findMethodAnnotation(Class annotationClass, + MethodParameter parameter) { + T annotation = parameter.getParameterAnnotation(annotationClass); + if (annotation != null) { + return annotation; + } + Annotation[] annotationsToSearch = parameter.getParameterAnnotations(); + for (Annotation toSearch : annotationsToSearch) { + annotation = AnnotationUtils.findAnnotation(toSearch.annotationType(), + annotationClass); + if (annotation != null) { + return annotation; + } + } + return null; + } + +} diff --git a/web/src/test/java/org/springframework/security/web/bind/support/CurrentSecurityContextArgumentResolverTests.java b/web/src/test/java/org/springframework/security/web/bind/support/CurrentSecurityContextArgumentResolverTests.java new file mode 100644 index 0000000000..e0256ea8fa --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/bind/support/CurrentSecurityContextArgumentResolverTests.java @@ -0,0 +1,239 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.bind.support; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.core.MethodParameter; +import org.springframework.expression.spel.SpelEvaluationException; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.annotation.CurrentSecurityContext; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.userdetails.User; +import org.springframework.util.ReflectionUtils; + +import java.lang.reflect.Method; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.junit.Assert.fail; + +/** + * @author Dan Zheng + * @since 5.2.x + * + */ +public class CurrentSecurityContextArgumentResolverTests { + private Object expectedPrincipal; + private CurrentSecurityContextArgumentResolver resolver; + + @Before + public void setup() { + resolver = new CurrentSecurityContextArgumentResolver(); + } + + @After + public void cleanup() { + SecurityContextHolder.clearContext(); + } + + @Test + public void supportsParameterNoAnnotation() throws Exception { + assertThat(resolver.supportsParameter(showSecurityContextNoAnnotation())).isFalse(); + } + + @Test + public void supportsParameterAnnotation() throws Exception { + assertThat(resolver.supportsParameter(showSecurityContextAnnotation())).isTrue(); + } + + @Test + public void resolveArgumentNullAuthentication() throws Exception { + SecurityContext context = SecurityContextHolder.getContext(); + Authentication authentication = context.getAuthentication(); + context.setAuthentication(null); + assertThat(resolver.resolveArgument(showSecurityContextAuthenticationAnnotation(), null, null, null)) + .isNull(); + context.setAuthentication(authentication); + } + + @Test + public void resolveArgumentWithAuthentication() throws Exception { + String principal = "john"; + setAuthenticationPrincipal(principal); + Authentication auth1 = (Authentication) resolver.resolveArgument(showSecurityContextAuthenticationAnnotation(), null, null, null); + assertThat(auth1.getPrincipal()).isEqualTo(principal); + } + + @Test + public void resolveArgumentWithNullAuthentication() throws Exception { + SecurityContext context = SecurityContextHolder.getContext(); + Authentication authentication = context.getAuthentication(); + context.setAuthentication(null); + assertThatExceptionOfType(SpelEvaluationException.class) + .isThrownBy(() -> { + resolver.resolveArgument(showSecurityContextAuthenticationWithPrincipal(), null, null, null); + }); + context.setAuthentication(authentication); + } + + @Test + public void resolveArgumentWithOptionalPrincipal() throws Exception { + SecurityContext context = SecurityContextHolder.getContext(); + Authentication authentication = context.getAuthentication(); + context.setAuthentication(null); + Object principalResult = resolver.resolveArgument(showSecurityContextAuthenticationWithOptionalPrincipal(), null, null, null); + assertThat(principalResult).isNull(); + context.setAuthentication(authentication); + } + + @Test + public void resolveArgumentWithPrincipal() throws Exception { + String principal = "smith"; + setAuthenticationPrincipal(principal); + String principalResult = (String) resolver.resolveArgument(showSecurityContextAuthenticationWithPrincipal(), null, null, null); + assertThat(principalResult).isEqualTo(principal); + } + + @Test + public void resolveArgumentUserDetails() throws Exception { + setAuthenticationDetail(new User("my_user", "my_password", + AuthorityUtils.createAuthorityList("ROLE_USER"))); + + User u = (User) resolver.resolveArgument(showSecurityContextWithUserDetail(), null, null, + null); + assertThat(u.getUsername()).isEqualTo("my_user"); + } + + @Test + public void resolveArgumentSecurityContextErrorOnInvalidTypeImplicit() throws Exception { + String principal = "invalid_type_implicit"; + setAuthenticationPrincipal(principal); + assertThat(resolver.resolveArgument(showSecurityContextErrorOnInvalidTypeImplicit(), null, null, null)) + .isNull(); + } + + @Test + public void resolveArgumentSecurityContextErrorOnInvalidTypeFalse() throws Exception { + String principal = "invalid_type_false"; + setAuthenticationPrincipal(principal); + assertThat(resolver.resolveArgument(showSecurityContextErrorOnInvalidTypeFalse(), null, null, null)) + .isNull(); + } + + @Test + public void resolveArgumentSecurityContextErrorOnInvalidTypeTrue() throws Exception { + String principal = "invalid_type_true"; + setAuthenticationPrincipal(principal); + try { + resolver.resolveArgument(showSecurityContextErrorOnInvalidTypeTrue(), null, null, null); + fail("should not reach here"); + } catch(ClassCastException ex) {} + } + + private MethodParameter showSecurityContextNoAnnotation() { + return getMethodParameter("showSecurityContextNoAnnotation", String.class); + } + + private MethodParameter showSecurityContextAnnotation() { + return getMethodParameter("showSecurityContextAnnotation", SecurityContext.class); + } + + private MethodParameter showSecurityContextAuthenticationAnnotation() { + return getMethodParameter("showSecurityContextAuthenticationAnnotation", Authentication.class); + } + + private MethodParameter showSecurityContextAuthenticationWithOptionalPrincipal() { + return getMethodParameter("showSecurityContextAuthenticationWithOptionalPrincipal", Object.class); + } + + private MethodParameter showSecurityContextAuthenticationWithPrincipal() { + return getMethodParameter("showSecurityContextAuthenticationWithPrincipal", Object.class); + } + + private MethodParameter showSecurityContextWithUserDetail() { + return getMethodParameter("showSecurityContextWithUserDetail", Object.class); + } + + private MethodParameter showSecurityContextErrorOnInvalidTypeImplicit() { + return getMethodParameter("showSecurityContextErrorOnInvalidTypeImplicit", String.class); + } + + private MethodParameter showSecurityContextErrorOnInvalidTypeFalse() { + return getMethodParameter("showSecurityContextErrorOnInvalidTypeFalse", String.class); + } + + private MethodParameter showSecurityContextErrorOnInvalidTypeTrue() { + return getMethodParameter("showSecurityContextErrorOnInvalidTypeTrue", String.class); + } + + private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { + Method method = ReflectionUtils.findMethod(TestController.class, methodName, + paramTypes); + return new MethodParameter(method, 0); + } + + public static class TestController { + public void showSecurityContextNoAnnotation(String user) { + } + + public void showSecurityContextAnnotation(@CurrentSecurityContext SecurityContext context) { + } + + public void showSecurityContextAuthenticationAnnotation(@CurrentSecurityContext(expression = "authentication") Authentication authentication) { + } + + public void showSecurityContextAuthenticationWithOptionalPrincipal(@CurrentSecurityContext(expression = "authentication?.principal") Object principal) { + } + + public void showSecurityContextAuthenticationWithPrincipal(@CurrentSecurityContext(expression = "authentication.principal") Object principal) { + } + + public void showSecurityContextWithUserDetail(@CurrentSecurityContext(expression = "authentication.details") Object detail) { + } + + public void showSecurityContextErrorOnInvalidTypeImplicit( + @CurrentSecurityContext String implicit) { + } + + public void showSecurityContextErrorOnInvalidTypeFalse( + @CurrentSecurityContext(errorOnInvalidType = false) String implicit) { + } + + public void showSecurityContextErrorOnInvalidTypeTrue( + @CurrentSecurityContext(errorOnInvalidType = true) String implicit) { + } + } + + private void setAuthenticationPrincipal(Object principal) { + SecurityContextHolder.getContext() + .setAuthentication( + new TestingAuthenticationToken(principal, "password", + "ROLE_USER")); + } + + private void setAuthenticationDetail(Object detail) { + TestingAuthenticationToken tat = new TestingAuthenticationToken("user", "password", + "ROLE_USER"); + tat.setDetails(detail); + SecurityContextHolder.getContext() + .setAuthentication(tat); + } +} diff --git a/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java b/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java new file mode 100644 index 0000000000..a34560d926 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java @@ -0,0 +1,238 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.reactive.result.method.annotation; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.core.MethodParameter; +import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.expression.BeanResolver; +import org.springframework.expression.spel.SpelEvaluationException; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.annotation.CurrentSecurityContext; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.web.method.ResolvableMethod; +import org.springframework.web.reactive.BindingContext; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.fail; + + +/** + * @author Dan Zheng + * @since 5.2.x + */ +@RunWith(MockitoJUnitRunner.class) +public class CurrentSecurityContextArgumentResolverTests { + @Mock + ServerWebExchange exchange; + @Mock + BindingContext bindingContext; + @Mock + Authentication authentication; + @Mock + BeanResolver beanResolver; + @Mock + SecurityContext securityContext; + + ResolvableMethod securityContextMethod = ResolvableMethod.on(getClass()).named("securityContext").build(); + ResolvableMethod securityContextWithAuthentication = ResolvableMethod.on(getClass()).named("securityContextWithAuthentication").build(); + + CurrentSecurityContextArgumentResolver resolver; + + @Before + public void setup() { + resolver = new CurrentSecurityContextArgumentResolver(new ReactiveAdapterRegistry()); + this.resolver.setBeanResolver(this.beanResolver); + } + + @Test + public void supportsParameterCurrentSecurityContext() throws Exception { + assertThat(resolver.supportsParameter(this.securityContextMethod.arg(Mono.class, SecurityContext.class))).isTrue(); + } + + @Test + public void supportsParameterWithAuthentication() throws Exception { + assertThat(resolver.supportsParameter(this.securityContextWithAuthentication.arg(Mono.class, Authentication.class))).isTrue(); + } + + @Test + public void resolveArgumentWithNullSecurityContext() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("securityContext").build().arg(Mono.class, SecurityContext.class); + Context context = ReactiveSecurityContextHolder.withSecurityContext(Mono.empty()); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + Object obj = argument.subscriberContext(context).block(); + assertThat(obj).isNull(); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWithSecurityContext() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("securityContext").build().arg(Mono.class, SecurityContext.class); + Authentication auth = buildAuthenticationWithPrincipal("hello"); + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + SecurityContext securityContext = (SecurityContext) argument.subscriberContext(context).cast(Mono.class).block().block(); + assertThat(securityContext.getAuthentication()).isSameAs(auth); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWithNullAuthentication1() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("securityContext").build().arg(Mono.class, SecurityContext.class); + Authentication auth = null; + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + SecurityContext securityContext = (SecurityContext) argument.subscriberContext(context).cast(Mono.class).block().block(); + assertThat(securityContext.getAuthentication()).isNull(); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWithNullAuthentication2() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("securityContextWithAuthentication").build().arg(Mono.class, Authentication.class); + Authentication auth = null; + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + Mono r = (Mono) argument.subscriberContext(context).block(); + assertThat(r.block()).isNull(); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWithAuthentication1() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("securityContextWithAuthentication").build().arg(Mono.class, Authentication.class); + Authentication auth = buildAuthenticationWithPrincipal("authentication1"); + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + Mono auth1 = (Mono) argument.subscriberContext(context).block(); + assertThat(auth1.block()).isSameAs(auth); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWithNullAuthenticationOptional1() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("securityContextWithDepthPropOptional").build().arg(Mono.class, Object.class); + Authentication auth = null; + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + Mono obj = (Mono) argument.subscriberContext(context).block(); + assertThat(obj.block()).isNull(); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWithAuthenticationOptional1() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("securityContextWithDepthPropOptional").build().arg(Mono.class, Object.class); + Authentication auth = buildAuthenticationWithPrincipal("auth_optional"); + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + Mono obj = (Mono) argument.subscriberContext(context).block(); + assertThat(obj.block()).isEqualTo("auth_optional"); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWithNullDepthProp1() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("securityContextWithDepthProp").build().arg(Mono.class, Object.class); + Authentication auth = null; + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + try { + Mono obj = (Mono) argument.subscriberContext(context).block(); + fail("should not reach here"); + } catch(SpelEvaluationException e) { + } + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWithStringDepthProp() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("securityContextWithDepthStringProp").build().arg(Mono.class, String.class); + Authentication auth = buildAuthenticationWithPrincipal("auth_string"); + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + Mono obj = (Mono) argument.subscriberContext(context).block(); + assertThat(obj.block()).isEqualTo("auth_string"); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWhenErrorOnInvalidTypeImplicit() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("errorOnInvalidTypeWhenImplicit").build().arg(Mono.class, String.class); + Authentication auth = buildAuthenticationWithPrincipal("invalid_type_implicit"); + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + Mono obj = (Mono) argument.subscriberContext(context).block(); + assertThat(obj.block()).isNull(); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentErrorOnInvalidTypeWhenExplicitFalse() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("errorOnInvalidTypeWhenExplicitFalse").build().arg(Mono.class, String.class); + Authentication auth = buildAuthenticationWithPrincipal("error_on_invalid_type_explicit_false"); + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + Mono obj = (Mono) argument.subscriberContext(context).block(); + assertThat(obj.block()).isNull(); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentErrorOnInvalidTypeWhenExplicitTrue() throws Exception { + MethodParameter parameter = ResolvableMethod.on(getClass()).named("errorOnInvalidTypeWhenExplicitTrue").build().arg(Mono.class, String.class); + Authentication auth = buildAuthenticationWithPrincipal("error_on_invalid_type_explicit_true"); + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = resolver.resolveArgument(parameter, bindingContext, exchange); + try { + Mono obj = (Mono) argument.subscriberContext(context).block(); + fail("should not reach here"); + } catch(ClassCastException ex) { + } + ReactiveSecurityContextHolder.clearContext(); + } + + void securityContext(@CurrentSecurityContext Mono monoSecurityContext) {} + + void securityContextWithAuthentication(@CurrentSecurityContext(expression = "authentication") Mono authentication) {} + + void securityContextWithDepthPropOptional(@CurrentSecurityContext(expression = "authentication?.principal") Mono principal) {} + + void securityContextWithDepthProp(@CurrentSecurityContext(expression = "authentication.principal") Mono principal) {} + + void securityContextWithDepthStringProp(@CurrentSecurityContext(expression = "authentication.principal") Mono principal) {} + + void errorOnInvalidTypeWhenImplicit(@CurrentSecurityContext Mono implicit) {} + + void errorOnInvalidTypeWhenExplicitFalse(@CurrentSecurityContext(errorOnInvalidType = false) Mono implicit) {} + + void errorOnInvalidTypeWhenExplicitTrue(@CurrentSecurityContext(errorOnInvalidType = true) Mono implicit) {} + + + private Authentication buildAuthenticationWithPrincipal(Object principal) { + return new TestingAuthenticationToken(principal, "password", + "ROLE_USER"); + } +}