diff --git a/messaging/spring-security-messaging.gradle b/messaging/spring-security-messaging.gradle index 0cb01c1553..a12a540d74 100644 --- a/messaging/spring-security-messaging.gradle +++ b/messaging/spring-security-messaging.gradle @@ -10,6 +10,7 @@ dependencies { optional project(':spring-security-web') optional 'org.springframework:spring-websocket' + optional 'io.projectreactor:reactor-core' optional 'javax.servlet:javax.servlet-api' testCompile project(path: ':spring-security-core', configuration: 'tests') diff --git a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java new file mode 100644 index 0000000000..c34fb576ac --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java @@ -0,0 +1,210 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.handler.invocation.reactive; + +import org.reactivestreams.Publisher; +import org.springframework.core.MethodParameter; +import org.springframework.core.ReactiveAdapter; +import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.expression.BeanResolver; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.StandardEvaluationContext; +import org.springframework.messaging.Message; +import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.annotation.AuthenticationPrincipal; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.stereotype.Controller; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import reactor.core.publisher.Mono; + +import java.lang.annotation.Annotation; + +/** + * Allows resolving the {@link Authentication#getPrincipal()} using the + * {@link AuthenticationPrincipal} annotation. For example, the following + * {@link Controller}: + * + *
+ * @Controller
+ * public class MyController {
+ *     @MessageMapping("/im")
+ *     public void im(@AuthenticationPrincipal CustomUser customUser) {
+ *         // do something with CustomUser
+ *     }
+ * }
+ * 
+ * + *

+ * Will resolve the CustomUser argument using {@link Authentication#getPrincipal()} from + * the {@link ReactiveSecurityContextHolder}. If the {@link Authentication} or + * {@link Authentication#getPrincipal()} is null, it will return null. If the types do not + * match, null will be returned unless + * {@link AuthenticationPrincipal#errorOnInvalidType()} is true in which case a + * {@link ClassCastException} will be thrown. + * + *

+ * Alternatively, users can create a custom meta annotation as shown below: + * + *

+ * @Target({ ElementType.PARAMETER })
+ * @Retention(RetentionPolicy.RUNTIME)
+ * @AuthenticationPrincipal
+ * public @interface CurrentUser {
+ * }
+ * 
+ * + *

+ * The custom annotation can then be used instead. For example: + * + *

+ * @Controller
+ * public class MyController {
+ *     @MessageMapping("/im")
+ *     public void im(@CurrentUser CustomUser customUser) {
+ *         // do something with CustomUser
+ *     }
+ * }
+ * 
+ * @author Rob Winch + * @since 5.2 + */ +public class AuthenticationPrincipalArgumentResolver + implements HandlerMethodArgumentResolver { + + private ExpressionParser parser = new SpelExpressionParser(); + + private BeanResolver beanResolver; + + private ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry + .getSharedInstance(); + + /** + * Sets the {@link BeanResolver} to be used on the expressions + * @param beanResolver the {@link BeanResolver} to use + */ + public void setBeanResolver(BeanResolver beanResolver) { + this.beanResolver = beanResolver; + } + + /** + * Sets the {@link ReactiveAdapterRegistry} to be used. + * @param adapterRegistry the {@link ReactiveAdapterRegistry} to use. Cannot be null. Default is + * {@link ReactiveAdapterRegistry#getSharedInstance()} + */ + public void setAdapterRegistry(ReactiveAdapterRegistry adapterRegistry) { + Assert.notNull(adapterRegistry, "adapterRegistry cannot be null"); + this.adapterRegistry = adapterRegistry; + } + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return findMethodAnnotation(AuthenticationPrincipal.class, parameter) != null; + } + + public Mono resolveArgument(MethodParameter parameter, Message message) { + ReactiveAdapter adapter = this.adapterRegistry + .getAdapter(parameter.getParameterType()); + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication).flatMap(a -> { + Object p = resolvePrincipal(parameter, a.getPrincipal()); + Mono principal = Mono.justOrEmpty(p); + return adapter == null ? + principal : + Mono.just(adapter.fromPublisher(principal)); + }); + } + + private Object resolvePrincipal(MethodParameter parameter, Object principal) { + AuthenticationPrincipal authPrincipal = findMethodAnnotation( + AuthenticationPrincipal.class, parameter); + + String expressionToParse = authPrincipal.expression(); + if (StringUtils.hasLength(expressionToParse)) { + StandardEvaluationContext context = new StandardEvaluationContext(); + context.setRootObject(principal); + context.setVariable("this", principal); + context.setBeanResolver(this.beanResolver); + + Expression expression = this.parser.parseExpression(expressionToParse); + principal = expression.getValue(context); + } + + if (isInvalidType(parameter, principal)) { + + if (authPrincipal.errorOnInvalidType()) { + throw new ClassCastException( + principal + " is not assignable to " + parameter + .getParameterType()); + } + else { + return null; + } + } + + return principal; + } + + private boolean isInvalidType(MethodParameter parameter, Object principal) { + if (principal == null) { + return false; + } + Class typeToCheck = parameter.getParameterType(); + boolean isParameterPublisher = Publisher.class + .isAssignableFrom(parameter.getParameterType()); + if (isParameterPublisher) { + ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); + Class genericType = resolvableType.resolveGeneric(0); + if (genericType == null) { + return false; + } + typeToCheck = genericType; + } + return !typeToCheck.isAssignableFrom(principal.getClass()); + } + + /** + * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}. + * + * @param annotationClass the class of the {@link Annotation} to find on the + * {@link MethodParameter} + * @param parameter the {@link MethodParameter} to search for an {@link Annotation} + * @return the {@link Annotation} that was found or null. + */ + private T findMethodAnnotation(Class annotationClass, + MethodParameter parameter) { + T annotation = parameter.getParameterAnnotation(annotationClass); + if (annotation != null) { + return annotation; + } + Annotation[] annotationsToSearch = parameter.getParameterAnnotations(); + for (Annotation toSearch : annotationsToSearch) { + annotation = AnnotationUtils + .findAnnotation(toSearch.annotationType(), annotationClass); + if (annotation != null) { + return annotation; + } + } + return null; + } +} diff --git a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/ResolvableMethod.java b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/ResolvableMethod.java new file mode 100644 index 0000000000..7651e2eb2b --- /dev/null +++ b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/ResolvableMethod.java @@ -0,0 +1,692 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.handler.invocation; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import org.aopalliance.intercept.MethodInterceptor; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.aop.target.EmptyTargetSource; +import org.springframework.cglib.core.SpringNamingPolicy; +import org.springframework.cglib.proxy.Callback; +import org.springframework.cglib.proxy.Enhancer; +import org.springframework.cglib.proxy.Factory; +import org.springframework.cglib.proxy.MethodProxy; +import org.springframework.core.LocalVariableTableParameterNameDiscoverer; +import org.springframework.core.MethodIntrospector; +import org.springframework.core.MethodParameter; +import org.springframework.core.ParameterNameDiscoverer; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.core.annotation.SynthesizingMethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.objenesis.ObjenesisException; +import org.springframework.objenesis.SpringObjenesis; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; +import org.springframework.util.ReflectionUtils; + +import static java.util.stream.Collectors.joining; + +/** + * NOTE: This class is a replica of the same class in spring-web so it can + * be used for tests in spring-messaging. + * + *

Convenience class to resolve method parameters from hints. + * + *

Background

+ * + *

When testing annotated methods we create test classes such as + * "TestController" with a diverse range of method signatures representing + * supported annotations and argument types. It becomes challenging to use + * naming strategies to keep track of methods and arguments especially in + * combination with variables for reflection metadata. + * + *

The idea with {@link ResolvableMethod} is NOT to rely on naming techniques + * but to use hints to zero in on method parameters. Such hints can be strongly + * typed and explicit about what is being tested. + * + *

1. Declared Return Type

+ * + * When testing return types it's likely to have many methods with a unique + * return type, possibly with or without an annotation. + * + *
+ * import static org.springframework.web.method.ResolvableMethod.on;
+ * import static org.springframework.web.method.MvcAnnotationPredicates.requestMapping;
+ *
+ * // Return type
+ * on(TestController.class).resolveReturnType(Foo.class);
+ * on(TestController.class).resolveReturnType(List.class, Foo.class);
+ * on(TestController.class).resolveReturnType(Mono.class, responseEntity(Foo.class));
+ *
+ * // Annotation + return type
+ * on(TestController.class).annotPresent(RequestMapping.class).resolveReturnType(Bar.class);
+ *
+ * // Annotation not present
+ * on(TestController.class).annotNotPresent(RequestMapping.class).resolveReturnType();
+ *
+ * // Annotation with attributes
+ * on(TestController.class).annot(requestMapping("/foo").params("p")).resolveReturnType();
+ * 
+ * + *

2. Method Arguments

+ * + * When testing method arguments it's more likely to have one or a small number + * of methods with a wide array of argument types and parameter annotations. + * + *
+ * import static org.springframework.web.method.MvcAnnotationPredicates.requestParam;
+ *
+ * ResolvableMethod testMethod = ResolvableMethod.on(getClass()).named("handle").build();
+ *
+ * testMethod.arg(Foo.class);
+ * testMethod.annotPresent(RequestParam.class).arg(Integer.class);
+ * testMethod.annotNotPresent(RequestParam.class)).arg(Integer.class);
+ * testMethod.annot(requestParam().name("c").notRequired()).arg(Integer.class);
+ * 
+ * + *

3. Mock Handler Method Invocation

+ * + * Locate a method by invoking it through a proxy of the target handler: + * + *
+ * ResolvableMethod.on(TestController.class).mockCall(o -> o.handle(null)).method();
+ * 
+ * + * @author Rossen Stoyanchev + * @since 5.2 + */ +public class ResolvableMethod { + + private static final Log logger = LogFactory.getLog(ResolvableMethod.class); + + private static final SpringObjenesis objenesis = new SpringObjenesis(); + + private static final ParameterNameDiscoverer nameDiscoverer = new LocalVariableTableParameterNameDiscoverer(); + + // Matches ValueConstants.DEFAULT_NONE (spring-web and spring-messaging) + private static final String DEFAULT_VALUE_NONE = "\n\t\t\n\t\t\n\uE000\uE001\uE002\n\t\t\t\t\n"; + + + private final Method method; + + + private ResolvableMethod(Method method) { + Assert.notNull(method, "'method' is required"); + this.method = method; + } + + + /** + * Return the resolved method. + */ + public Method method() { + return this.method; + } + + /** + * Return the declared return type of the resolved method. + */ + public MethodParameter returnType() { + return new SynthesizingMethodParameter(this.method, -1); + } + + /** + * Find a unique argument matching the given type. + * @param type the expected type + * @param generics optional array of generic types + */ + public MethodParameter arg(Class type, Class... generics) { + return new ArgResolver().arg(type, generics); + } + + /** + * Find a unique argument matching the given type. + * @param type the expected type + * @param generic at least one generic type + * @param generics optional array of generic types + */ + public MethodParameter arg(Class type, ResolvableType generic, ResolvableType... generics) { + return new ArgResolver().arg(type, generic, generics); + } + + /** + * Find a unique argument matching the given type. + * @param type the expected type + */ + public MethodParameter arg(ResolvableType type) { + return new ArgResolver().arg(type); + } + + /** + * Filter on method arguments with annotation. + * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + */ + @SafeVarargs + public final ArgResolver annot(Predicate... filter) { + return new ArgResolver(filter); + } + + @SafeVarargs + public final ArgResolver annotPresent(Class... annotationTypes) { + return new ArgResolver().annotPresent(annotationTypes); + } + + /** + * Filter on method arguments that don't have the given annotation type(s). + * @param annotationTypes the annotation types + */ + @SafeVarargs + public final ArgResolver annotNotPresent(Class... annotationTypes) { + return new ArgResolver().annotNotPresent(annotationTypes); + } + + + @Override + public String toString() { + return "ResolvableMethod=" + formatMethod(); + } + + + private String formatMethod() { + return (method().getName() + + Arrays.stream(this.method.getParameters()) + .map(this::formatParameter) + .collect(joining(",\n\t", "(\n\t", "\n)"))); + } + + private String formatParameter(Parameter param) { + Annotation[] anns = param.getAnnotations(); + return (anns.length > 0 ? + Arrays.stream(anns).map(this::formatAnnotation).collect(joining(",", "[", "]")) + " " + param : + param.toString()); + } + + private String formatAnnotation(Annotation annotation) { + Map map = AnnotationUtils.getAnnotationAttributes(annotation); + map.forEach((key, value) -> { + if (value.equals(DEFAULT_VALUE_NONE)) { + map.put(key, "NONE"); + } + }); + return annotation.annotationType().getName() + map; + } + + private static ResolvableType toResolvableType(Class type, Class... generics) { + return (ObjectUtils.isEmpty(generics) ? ResolvableType.forClass(type) : + ResolvableType.forClassWithGenerics(type, generics)); + } + + private static ResolvableType toResolvableType(Class type, ResolvableType generic, ResolvableType... generics) { + ResolvableType[] genericTypes = new ResolvableType[generics.length + 1]; + genericTypes[0] = generic; + System.arraycopy(generics, 0, genericTypes, 1, generics.length); + return ResolvableType.forClassWithGenerics(type, genericTypes); + } + + + /** + * Create a {@code ResolvableMethod} builder for the given handler class. + */ + public static Builder on(Class objectClass) { + return new Builder<>(objectClass); + } + + + /** + * Builder for {@code ResolvableMethod}. + */ + public static class Builder { + + private final Class objectClass; + + private final List> filters = new ArrayList<>(4); + + + private Builder(Class objectClass) { + Assert.notNull(objectClass, "Class must not be null"); + this.objectClass = objectClass; + } + + + private void addFilter(String message, Predicate filter) { + this.filters.add(new LabeledPredicate<>(message, filter)); + } + + /** + * Filter on methods with the given name. + */ + public Builder named(String methodName) { + addFilter("methodName=" + methodName, method -> method.getName().equals(methodName)); + return this; + } + + /** + * Filter on methods with the given parameter types. + */ + public Builder argTypes(Class... argTypes) { + addFilter("argTypes=" + Arrays.toString(argTypes), method -> + ObjectUtils.isEmpty(argTypes) ? method.getParameterCount() == 0 : + Arrays.equals(method.getParameterTypes(), argTypes)); + return this; + } + + /** + * Filter on annotated methods. + * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + */ + @SafeVarargs + public final Builder annot(Predicate... filters) { + this.filters.addAll(Arrays.asList(filters)); + return this; + } + + /** + * Filter on methods annotated with the given annotation type. + * @see #annot(Predicate[]) + * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + */ + @SafeVarargs + public final Builder annotPresent(Class... annotationTypes) { + String message = "annotationPresent=" + Arrays.toString(annotationTypes); + addFilter(message, method -> + Arrays.stream(annotationTypes).allMatch(annotType -> + AnnotatedElementUtils.findMergedAnnotation(method, annotType) != null)); + return this; + } + + /** + * Filter on methods not annotated with the given annotation type. + */ + @SafeVarargs + public final Builder annotNotPresent(Class... annotationTypes) { + String message = "annotationNotPresent=" + Arrays.toString(annotationTypes); + addFilter(message, method -> { + if (annotationTypes.length != 0) { + return Arrays.stream(annotationTypes).noneMatch(annotType -> + AnnotatedElementUtils.findMergedAnnotation(method, annotType) != null); + } + else { + return method.getAnnotations().length == 0; + } + }); + return this; + } + + /** + * Filter on methods returning the given type. + * @param returnType the return type + * @param generics optional array of generic types + */ + public Builder returning(Class returnType, Class... generics) { + return returning(toResolvableType(returnType, generics)); + } + + /** + * Filter on methods returning the given type with generics. + * @param returnType the return type + * @param generic at least one generic type + * @param generics optional extra generic types + */ + public Builder returning(Class returnType, ResolvableType generic, ResolvableType... generics) { + return returning(toResolvableType(returnType, generic, generics)); + } + + /** + * Filter on methods returning the given type. + * @param returnType the return type + */ + public Builder returning(ResolvableType returnType) { + String expected = returnType.toString(); + String message = "returnType=" + expected; + addFilter(message, m -> expected.equals(ResolvableType.forMethodReturnType(m).toString())); + return this; + } + + /** + * Build a {@code ResolvableMethod} from the provided filters which must + * resolve to a unique, single method. + *

See additional resolveXxx shortcut methods going directly to + * {@link Method} or return type parameter. + * @throws IllegalStateException for no match or multiple matches + */ + public ResolvableMethod method() { + Set methods = MethodIntrospector.selectMethods(this.objectClass, this::isMatch); + Assert.state(!methods.isEmpty(), () -> "No matching method: " + this); + Assert.state(methods.size() == 1, () -> "Multiple matching methods: " + this + formatMethods(methods)); + return new ResolvableMethod(methods.iterator().next()); + } + + private boolean isMatch(Method method) { + return this.filters.stream().allMatch(p -> p.test(method)); + } + + private String formatMethods(Set methods) { + return "\nMatched:\n" + methods.stream() + .map(Method::toGenericString).collect(joining(",\n\t", "[\n\t", "\n]")); + } + + public ResolvableMethod mockCall(Consumer invoker) { + MethodInvocationInterceptor interceptor = new MethodInvocationInterceptor(); + T proxy = initProxy(this.objectClass, interceptor); + invoker.accept(proxy); + Method method = interceptor.getInvokedMethod(); + return new ResolvableMethod(method); + } + + + // Build & resolve shortcuts... + + /** + * Resolve and return the {@code Method} equivalent to: + *

{@code build().method()} + */ + public final Method resolveMethod() { + return method().method(); + } + + /** + * Resolve and return the {@code Method} equivalent to: + *

{@code named(methodName).build().method()} + */ + public Method resolveMethod(String methodName) { + return named(methodName).method().method(); + } + + /** + * Resolve and return the declared return type equivalent to: + *

{@code build().returnType()} + */ + public final MethodParameter resolveReturnType() { + return method().returnType(); + } + + /** + * Shortcut to the unique return type equivalent to: + *

{@code returning(returnType).build().returnType()} + * @param returnType the return type + * @param generics optional array of generic types + */ + public MethodParameter resolveReturnType(Class returnType, Class... generics) { + return returning(returnType, generics).method().returnType(); + } + + /** + * Shortcut to the unique return type equivalent to: + *

{@code returning(returnType).build().returnType()} + * @param returnType the return type + * @param generic at least one generic type + * @param generics optional extra generic types + */ + public MethodParameter resolveReturnType(Class returnType, ResolvableType generic, + ResolvableType... generics) { + + return returning(returnType, generic, generics).method().returnType(); + } + + public MethodParameter resolveReturnType(ResolvableType returnType) { + return returning(returnType).method().returnType(); + } + + + @Override + public String toString() { + return "ResolvableMethod.Builder[\n" + + "\tobjectClass = " + this.objectClass.getName() + ",\n" + + "\tfilters = " + formatFilters() + "\n]"; + } + + private String formatFilters() { + return this.filters.stream().map(Object::toString) + .collect(joining(",\n\t\t", "[\n\t\t", "\n\t]")); + } + } + + + /** + * Predicate with a descriptive label. + */ + private static class LabeledPredicate implements Predicate { + + private final String label; + + private final Predicate delegate; + + + private LabeledPredicate(String label, Predicate delegate) { + this.label = label; + this.delegate = delegate; + } + + + @Override + public boolean test(T method) { + return this.delegate.test(method); + } + + @Override + public Predicate and(Predicate other) { + return this.delegate.and(other); + } + + @Override + public Predicate negate() { + return this.delegate.negate(); + } + + @Override + public Predicate or(Predicate other) { + return this.delegate.or(other); + } + + @Override + public String toString() { + return this.label; + } + } + + + /** + * Resolver for method arguments. + */ + public class ArgResolver { + + private final List> filters = new ArrayList<>(4); + + + @SafeVarargs + private ArgResolver(Predicate... filter) { + this.filters.addAll(Arrays.asList(filter)); + } + + /** + * Filter on method arguments with annotations. + * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + */ + @SafeVarargs + public final ArgResolver annot(Predicate... filters) { + this.filters.addAll(Arrays.asList(filters)); + return this; + } + + /** + * Filter on method arguments that have the given annotations. + * @param annotationTypes the annotation types + * @see #annot(Predicate[]) + * See {@link org.springframework.web.method.MvcAnnotationPredicates}. + */ + @SafeVarargs + public final ArgResolver annotPresent(Class... annotationTypes) { + this.filters.add(param -> Arrays.stream(annotationTypes).allMatch(param::hasParameterAnnotation)); + return this; + } + + /** + * Filter on method arguments that don't have the given annotations. + * @param annotationTypes the annotation types + */ + @SafeVarargs + public final ArgResolver annotNotPresent(Class... annotationTypes) { + this.filters.add(param -> + (annotationTypes.length > 0 ? + Arrays.stream(annotationTypes).noneMatch(param::hasParameterAnnotation) : + param.getParameterAnnotations().length == 0)); + return this; + } + + /** + * Resolve the argument also matching to the given type. + * @param type the expected type + */ + public MethodParameter arg(Class type, Class... generics) { + return arg(toResolvableType(type, generics)); + } + + /** + * Resolve the argument also matching to the given type. + * @param type the expected type + */ + public MethodParameter arg(Class type, ResolvableType generic, ResolvableType... generics) { + return arg(toResolvableType(type, generic, generics)); + } + + /** + * Resolve the argument also matching to the given type. + * @param type the expected type + */ + public MethodParameter arg(ResolvableType type) { + this.filters.add(p -> type.toString().equals(ResolvableType.forMethodParameter(p).toString())); + return arg(); + } + + /** + * Resolve the argument. + */ + public final MethodParameter arg() { + List matches = applyFilters(); + Assert.state(!matches.isEmpty(), () -> + "No matching arg in method\n" + formatMethod()); + Assert.state(matches.size() == 1, () -> + "Multiple matching args in method\n" + formatMethod() + "\nMatches:\n\t" + matches); + return matches.get(0); + } + + + private List applyFilters() { + List matches = new ArrayList<>(); + for (int i = 0; i < method.getParameterCount(); i++) { + MethodParameter param = new SynthesizingMethodParameter(method, i); + param.initParameterNameDiscovery(nameDiscoverer); + if (this.filters.stream().allMatch(p -> p.test(param))) { + matches.add(param); + } + } + return matches; + } + } + + + private static class MethodInvocationInterceptor + implements org.springframework.cglib.proxy.MethodInterceptor, MethodInterceptor { + + private Method invokedMethod; + + + Method getInvokedMethod() { + return this.invokedMethod; + } + + @Override + @Nullable + public Object intercept(Object object, Method method, Object[] args, MethodProxy proxy) { + if (ReflectionUtils.isObjectMethod(method)) { + return ReflectionUtils.invokeMethod(method, object, args); + } + else { + this.invokedMethod = method; + return null; + } + } + + @Override + @Nullable + public Object invoke(org.aopalliance.intercept.MethodInvocation inv) throws Throwable { + return intercept(inv.getThis(), inv.getMethod(), inv.getArguments(), null); + } + } + + @SuppressWarnings("unchecked") + private static T initProxy(Class type, MethodInvocationInterceptor interceptor) { + Assert.notNull(type, "'type' must not be null"); + if (type.isInterface()) { + ProxyFactory factory = new ProxyFactory(EmptyTargetSource.INSTANCE); + factory.addInterface(type); + factory.addInterface(Supplier.class); + factory.addAdvice(interceptor); + return (T) factory.getProxy(); + } + + else { + Enhancer enhancer = new Enhancer(); + enhancer.setSuperclass(type); + enhancer.setInterfaces(new Class[] {Supplier.class}); + enhancer.setNamingPolicy(SpringNamingPolicy.INSTANCE); + enhancer.setCallbackType(org.springframework.cglib.proxy.MethodInterceptor.class); + + Class proxyClass = enhancer.createClass(); + Object proxy = null; + + if (objenesis.isWorthTrying()) { + try { + proxy = objenesis.newInstance(proxyClass, enhancer.getUseCache()); + } + catch (ObjenesisException ex) { + logger.debug("Objenesis failed, falling back to default constructor", ex); + } + } + + if (proxy == null) { + try { + proxy = ReflectionUtils.accessibleConstructor(proxyClass).newInstance(); + } + catch (Throwable ex) { + throw new IllegalStateException("Unable to instantiate proxy " + + "via both Objenesis and default constructor fails as well", ex); + } + } + + ((Factory) proxy).setCallbacks(new Callback[] {interceptor}); + return (T) proxy; + } + } + +} diff --git a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolverTests.java b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolverTests.java new file mode 100644 index 0000000000..f33d50d008 --- /dev/null +++ b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolverTests.java @@ -0,0 +1,113 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.handler.invocation.reactive; + +import org.junit.Test; +import org.springframework.core.MethodParameter; +import org.springframework.core.annotation.SynthesizingMethodParameter; +import org.springframework.security.authentication.TestAuthentication; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.annotation.AuthenticationPrincipal; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.messaging.handler.invocation.ResolvableMethod; +import reactor.core.publisher.Mono; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import static org.assertj.core.api.Assertions.*; + +/** + * @author Rob Winch + */ +public class AuthenticationPrincipalArgumentResolverTests { + private AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver(); + + @Test + public void supportsParameterWhenAuthenticationPrincipalThenTrue() { + assertThat(this.resolver.supportsParameter(arg0("authenticationPrincipalOnMonoUserDetails"))).isTrue(); + } + + @Test + public void resolveArgumentWhenAuthenticationPrincipalAndEmptyContextThenNull() { + Object result = this.resolver.resolveArgument(arg0("authenticationPrincipalOnMonoUserDetails"), null).block(); + assertThat(result).isNull(); + } + + @Test + public void resolveArgumentWhenAuthenticationPrincipalThenFound() { + Authentication authentication = TestAuthentication.authenticatedUser(); + Mono result = (Mono) this.resolver.resolveArgument(arg0("authenticationPrincipalOnMonoUserDetails"), null) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + assertThat(result.block()).isEqualTo(authentication.getPrincipal()); + } + + @SuppressWarnings("unused") + private void authenticationPrincipalOnMonoUserDetails(@AuthenticationPrincipal Mono user) { + } + + @Test + public void supportsParameterWhenCurrentUserThenTrue() { + assertThat(this.resolver.supportsParameter(arg0("currentUserOnMonoUserDetails"))).isTrue(); + } + + @Test + public void resolveArgumentWhenMonoAndAuthenticationPrincipalThenFound() { + Authentication authentication = TestAuthentication.authenticatedUser(); + Mono result = (Mono) this.resolver.resolveArgument(arg0("currentUserOnMonoUserDetails"), null) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + assertThat(result.block()).isEqualTo(authentication.getPrincipal()); + } + + @SuppressWarnings("unused") + private void currentUserOnMonoUserDetails(@CurrentUser Mono user) { + } + + @Test + public void resolveArgumentWhenExpressionThenFound() { + Authentication authentication = TestAuthentication.authenticatedUser(); + Mono result = (Mono) this.resolver.resolveArgument(arg0("authenticationPrincipalExpression"), null) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + assertThat(result.block()).isEqualTo(authentication.getName()); + } + + @SuppressWarnings("unused") + private void authenticationPrincipalExpression(@AuthenticationPrincipal(expression = "username") Mono username) { + } + + @Test + public void supportsParameterWhenNotAnnotatedThenFalse() { + assertThat(this.resolver.supportsParameter(arg0("monoUserDetails"))).isFalse(); + } + + @SuppressWarnings("unused") + private void monoUserDetails(Mono user) { + } + + private MethodParameter arg0(String methodName) { + ResolvableMethod method = ResolvableMethod.on(getClass()).named(methodName).method(); + return new SynthesizingMethodParameter(method.method(), 0); + } + + @AuthenticationPrincipal + @Retention(RetentionPolicy.RUNTIME) + @interface CurrentUser {} +}