diff --git a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java new file mode 100644 index 0000000000..f771a7d03a --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java @@ -0,0 +1,209 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.handler.invocation.reactive; + +import org.reactivestreams.Publisher; +import org.springframework.core.MethodParameter; +import org.springframework.core.ReactiveAdapter; +import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.expression.BeanResolver; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.StandardEvaluationContext; +import org.springframework.messaging.Message; +import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.annotation.CurrentSecurityContext; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.stereotype.Controller; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import reactor.core.publisher.Mono; + +import java.lang.annotation.Annotation; + +/** + * Allows resolving the {@link Authentication#getPrincipal()} using the + * {@link CurrentSecurityContext} annotation. For example, the following + * {@link Controller}: + * + *
+ * @Controller
+ * public class MyController {
+ *     @MessageMapping("/im")
+ *     public void im(@CurrentSecurityContext SecurityContext context) {
+ *         // do something with context
+ *     }
+ * }
+ * 
+ * + *

+ * Will resolve the SecurityContext argument using the {@link ReactiveSecurityContextHolder}. + * If the {@link SecurityContext} is empty, it will return null. If the types do not + * match, null will be returned unless + * {@link CurrentSecurityContext#errorOnInvalidType()} is true in which case a + * {@link ClassCastException} will be thrown. + * + *

+ * Alternatively, users can create a custom meta annotation as shown below: + * + *

+ * @Target({ ElementType.PARAMETER })
+ * @Retention(RetentionPolicy.RUNTIME)
+ * @CurrentSecurityContext(expression = "authentication?.principal")
+ * public @interface CurrentUser {
+ * }
+ * 
+ * + *

+ * The custom annotation can then be used instead. For example: + * + *

+ * @Controller
+ * public class MyController {
+ *     @MessageMapping("/im")
+ *     public void im(@CurrentUser CustomUser customUser) {
+ *         // do something with CustomUser
+ *     }
+ * }
+ * 
+ * @author Rob Winch + * @since 5.2 + */ +public class CurrentSecurityContextArgumentResolver + implements HandlerMethodArgumentResolver { + + private ExpressionParser parser = new SpelExpressionParser(); + + private BeanResolver beanResolver; + + private ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry + .getSharedInstance(); + + /** + * Sets the {@link BeanResolver} to be used on the expressions + * @param beanResolver the {@link BeanResolver} to use + */ + public void setBeanResolver(BeanResolver beanResolver) { + this.beanResolver = beanResolver; + } + + /** + * Sets the {@link ReactiveAdapterRegistry} to be used. + * @param adapterRegistry the {@link ReactiveAdapterRegistry} to use. Cannot be null. Default is + * {@link ReactiveAdapterRegistry#getSharedInstance()} + */ + public void setAdapterRegistry(ReactiveAdapterRegistry adapterRegistry) { + Assert.notNull(adapterRegistry, "adapterRegistry cannot be null"); + this.adapterRegistry = adapterRegistry; + } + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + } + + public Mono resolveArgument(MethodParameter parameter, Message message) { + ReactiveAdapter adapter = this.adapterRegistry + .getAdapter(parameter.getParameterType()); + return ReactiveSecurityContextHolder.getContext() + .flatMap(securityContext -> { + Object sc = resolveSecurityContext(parameter, securityContext); + Mono result = Mono.justOrEmpty(sc); + return adapter == null ? + result : + Mono.just(adapter.fromPublisher(result)); + }); + } + + private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) { + CurrentSecurityContext contextAnno = findMethodAnnotation( + CurrentSecurityContext.class, parameter); + + String expressionToParse = contextAnno.expression(); + if (StringUtils.hasLength(expressionToParse)) { + StandardEvaluationContext context = new StandardEvaluationContext(); + context.setRootObject(securityContext); + context.setVariable("this", securityContext); + context.setBeanResolver(this.beanResolver); + + Expression expression = this.parser.parseExpression(expressionToParse); + securityContext = expression.getValue(context); + } + + if (isInvalidType(parameter, securityContext)) { + + if (contextAnno.errorOnInvalidType()) { + throw new ClassCastException( + securityContext + " is not assignable to " + parameter + .getParameterType()); + } + else { + return null; + } + } + + return securityContext; + } + + private boolean isInvalidType(MethodParameter parameter, Object value) { + if (value == null) { + return false; + } + Class typeToCheck = parameter.getParameterType(); + boolean isParameterPublisher = Publisher.class + .isAssignableFrom(parameter.getParameterType()); + if (isParameterPublisher) { + ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); + Class genericType = resolvableType.resolveGeneric(0); + if (genericType == null) { + return false; + } + typeToCheck = genericType; + } + return !typeToCheck.isAssignableFrom(value.getClass()); + } + + /** + * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. + * + * @param annotationClass the class of the {@link Annotation} to find on the + * {@link MethodParameter} + * @param parameter the {@link MethodParameter} to search for an {@link Annotation} + * @return the {@link Annotation} that was found or null. + */ + private T findMethodAnnotation(Class annotationClass, + MethodParameter parameter) { + T annotation = parameter.getParameterAnnotation(annotationClass); + if (annotation != null) { + return annotation; + } + Annotation[] annotationsToSearch = parameter.getParameterAnnotations(); + for (Annotation toSearch : annotationsToSearch) { + annotation = AnnotationUtils + .findAnnotation(toSearch.annotationType(), annotationClass); + if (annotation != null) { + return annotation; + } + } + return null; + } +} diff --git a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java new file mode 100644 index 0000000000..dba93aa552 --- /dev/null +++ b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java @@ -0,0 +1,114 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.handler.invocation.reactive; + +import org.junit.Test; +import org.springframework.core.MethodParameter; +import org.springframework.core.annotation.SynthesizingMethodParameter; +import org.springframework.security.authentication.TestAuthentication; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.annotation.CurrentSecurityContext; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.messaging.handler.invocation.ResolvableMethod; +import reactor.core.publisher.Mono; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Rob Winch + */ +public class CurrentSecurityContextArgumentResolverTests { + private CurrentSecurityContextArgumentResolver resolver = new CurrentSecurityContextArgumentResolver(); + + @Test + public void supportsParameterWhenAuthenticationPrincipalThenTrue() { + assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoSecurityContext"))).isTrue(); + } + + @Test + public void resolveArgumentWhenAuthenticationPrincipalAndEmptyContextThenNull() { + Object result = this.resolver.resolveArgument(arg0("currentSecurityContextOnMonoSecurityContext"), null).block(); + assertThat(result).isNull(); + } + + @Test + public void resolveArgumentWhenAuthenticationPrincipalThenFound() { + Authentication authentication = TestAuthentication.authenticatedUser(); + Mono result = (Mono) this.resolver.resolveArgument(arg0("currentSecurityContextOnMonoSecurityContext"), null) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + assertThat(result.block().getAuthentication()).isEqualTo(authentication); + } + + @SuppressWarnings("unused") + private void currentSecurityContextOnMonoSecurityContext(@CurrentSecurityContext Mono context) { + } + + @Test + public void supportsParameterWhenCurrentUserThenTrue() { + assertThat(this.resolver.supportsParameter(arg0("currentUserOnMonoUserDetails"))).isTrue(); + } + + @Test + public void resolveArgumentWhenMonoAndAuthenticationPrincipalThenFound() { + Authentication authentication = TestAuthentication.authenticatedUser(); + Mono result = (Mono) this.resolver.resolveArgument(arg0("currentUserOnMonoUserDetails"), null) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + assertThat(result.block()).isEqualTo(authentication.getPrincipal()); + } + + @SuppressWarnings("unused") + private void currentUserOnMonoUserDetails(@CurrentUser Mono user) { + } + + @Test + public void resolveArgumentWhenExpressionThenFound() { + Authentication authentication = TestAuthentication.authenticatedUser(); + Mono result = (Mono) this.resolver.resolveArgument(arg0("authenticationPrincipalExpression"), null) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + assertThat(result.block()).isEqualTo(authentication.getName()); + } + + @SuppressWarnings("unused") + private void authenticationPrincipalExpression(@CurrentSecurityContext(expression = "authentication?.principal?.username") Mono username) { + } + + @Test + public void supportsParameterWhenNotAnnotatedThenFalse() { + assertThat(this.resolver.supportsParameter(arg0("monoUserDetails"))).isFalse(); + } + + @SuppressWarnings("unused") + private void monoUserDetails(Mono user) { + } + + private MethodParameter arg0(String methodName) { + ResolvableMethod method = ResolvableMethod.on(getClass()).named(methodName).method(); + return new SynthesizingMethodParameter(method.method(), 0); + } + + @CurrentSecurityContext(expression = "authentication?.principal") + @Retention(RetentionPolicy.RUNTIME) + @interface CurrentUser {} +}