invoker) {
+ MethodInvocationInterceptor interceptor = new MethodInvocationInterceptor();
+ T proxy = initProxy(this.objectClass, interceptor);
+ invoker.accept(proxy);
+ Method method = interceptor.getInvokedMethod();
+ return new ResolvableMethod(method);
+ }
+
+
+ // Build & resolve shortcuts...
+
+ /**
+ * Resolve and return the {@code Method} equivalent to:
+ * {@code build().method()}
+ */
+ public final Method resolveMethod() {
+ return method().method();
+ }
+
+ /**
+ * Resolve and return the {@code Method} equivalent to:
+ *
{@code named(methodName).build().method()}
+ */
+ public Method resolveMethod(String methodName) {
+ return named(methodName).method().method();
+ }
+
+ /**
+ * Resolve and return the declared return type equivalent to:
+ *
{@code build().returnType()}
+ */
+ public final MethodParameter resolveReturnType() {
+ return method().returnType();
+ }
+
+ /**
+ * Shortcut to the unique return type equivalent to:
+ *
{@code returning(returnType).build().returnType()}
+ * @param returnType the return type
+ * @param generics optional array of generic types
+ */
+ public MethodParameter resolveReturnType(Class> returnType, Class>... generics) {
+ return returning(returnType, generics).method().returnType();
+ }
+
+ /**
+ * Shortcut to the unique return type equivalent to:
+ *
{@code returning(returnType).build().returnType()}
+ * @param returnType the return type
+ * @param generic at least one generic type
+ * @param generics optional extra generic types
+ */
+ public MethodParameter resolveReturnType(Class> returnType, ResolvableType generic,
+ ResolvableType... generics) {
+
+ return returning(returnType, generic, generics).method().returnType();
+ }
+
+ public MethodParameter resolveReturnType(ResolvableType returnType) {
+ return returning(returnType).method().returnType();
+ }
+
+
+ @Override
+ public String toString() {
+ return "ResolvableMethod.Builder[\n" +
+ "\tobjectClass = " + this.objectClass.getName() + ",\n" +
+ "\tfilters = " + formatFilters() + "\n]";
+ }
+
+ private String formatFilters() {
+ return this.filters.stream().map(Object::toString)
+ .collect(joining(",\n\t\t", "[\n\t\t", "\n\t]"));
+ }
+ }
+
+
+ /**
+ * Predicate with a descriptive label.
+ */
+ private static class LabeledPredicate implements Predicate {
+
+ private final String label;
+
+ private final Predicate delegate;
+
+
+ private LabeledPredicate(String label, Predicate delegate) {
+ this.label = label;
+ this.delegate = delegate;
+ }
+
+
+ @Override
+ public boolean test(T method) {
+ return this.delegate.test(method);
+ }
+
+ @Override
+ public Predicate and(Predicate super T> other) {
+ return this.delegate.and(other);
+ }
+
+ @Override
+ public Predicate negate() {
+ return this.delegate.negate();
+ }
+
+ @Override
+ public Predicate or(Predicate super T> other) {
+ return this.delegate.or(other);
+ }
+
+ @Override
+ public String toString() {
+ return this.label;
+ }
+ }
+
+
+ /**
+ * Resolver for method arguments.
+ */
+ public class ArgResolver {
+
+ private final List> filters = new ArrayList<>(4);
+
+
+ @SafeVarargs
+ private ArgResolver(Predicate... filter) {
+ this.filters.addAll(Arrays.asList(filter));
+ }
+
+ /**
+ * Filter on method arguments with annotations.
+ * See {@link org.springframework.web.method.MvcAnnotationPredicates}.
+ */
+ @SafeVarargs
+ public final ArgResolver annot(Predicate... filters) {
+ this.filters.addAll(Arrays.asList(filters));
+ return this;
+ }
+
+ /**
+ * Filter on method arguments that have the given annotations.
+ * @param annotationTypes the annotation types
+ * @see #annot(Predicate[])
+ * See {@link org.springframework.web.method.MvcAnnotationPredicates}.
+ */
+ @SafeVarargs
+ public final ArgResolver annotPresent(Class extends Annotation>... annotationTypes) {
+ this.filters.add(param -> Arrays.stream(annotationTypes).allMatch(param::hasParameterAnnotation));
+ return this;
+ }
+
+ /**
+ * Filter on method arguments that don't have the given annotations.
+ * @param annotationTypes the annotation types
+ */
+ @SafeVarargs
+ public final ArgResolver annotNotPresent(Class extends Annotation>... annotationTypes) {
+ this.filters.add(param ->
+ (annotationTypes.length > 0 ?
+ Arrays.stream(annotationTypes).noneMatch(param::hasParameterAnnotation) :
+ param.getParameterAnnotations().length == 0));
+ return this;
+ }
+
+ /**
+ * Resolve the argument also matching to the given type.
+ * @param type the expected type
+ */
+ public MethodParameter arg(Class> type, Class>... generics) {
+ return arg(toResolvableType(type, generics));
+ }
+
+ /**
+ * Resolve the argument also matching to the given type.
+ * @param type the expected type
+ */
+ public MethodParameter arg(Class> type, ResolvableType generic, ResolvableType... generics) {
+ return arg(toResolvableType(type, generic, generics));
+ }
+
+ /**
+ * Resolve the argument also matching to the given type.
+ * @param type the expected type
+ */
+ public MethodParameter arg(ResolvableType type) {
+ this.filters.add(p -> type.toString().equals(ResolvableType.forMethodParameter(p).toString()));
+ return arg();
+ }
+
+ /**
+ * Resolve the argument.
+ */
+ public final MethodParameter arg() {
+ List matches = applyFilters();
+ Assert.state(!matches.isEmpty(), () ->
+ "No matching arg in method\n" + formatMethod());
+ Assert.state(matches.size() == 1, () ->
+ "Multiple matching args in method\n" + formatMethod() + "\nMatches:\n\t" + matches);
+ return matches.get(0);
+ }
+
+
+ private List applyFilters() {
+ List matches = new ArrayList<>();
+ for (int i = 0; i < method.getParameterCount(); i++) {
+ MethodParameter param = new SynthesizingMethodParameter(method, i);
+ param.initParameterNameDiscovery(nameDiscoverer);
+ if (this.filters.stream().allMatch(p -> p.test(param))) {
+ matches.add(param);
+ }
+ }
+ return matches;
+ }
+ }
+
+
+ private static class MethodInvocationInterceptor
+ implements org.springframework.cglib.proxy.MethodInterceptor, MethodInterceptor {
+
+ private Method invokedMethod;
+
+
+ Method getInvokedMethod() {
+ return this.invokedMethod;
+ }
+
+ @Override
+ @Nullable
+ public Object intercept(Object object, Method method, Object[] args, MethodProxy proxy) {
+ if (ReflectionUtils.isObjectMethod(method)) {
+ return ReflectionUtils.invokeMethod(method, object, args);
+ }
+ else {
+ this.invokedMethod = method;
+ return null;
+ }
+ }
+
+ @Override
+ @Nullable
+ public Object invoke(org.aopalliance.intercept.MethodInvocation inv) throws Throwable {
+ return intercept(inv.getThis(), inv.getMethod(), inv.getArguments(), null);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private static T initProxy(Class> type, MethodInvocationInterceptor interceptor) {
+ Assert.notNull(type, "'type' must not be null");
+ if (type.isInterface()) {
+ ProxyFactory factory = new ProxyFactory(EmptyTargetSource.INSTANCE);
+ factory.addInterface(type);
+ factory.addInterface(Supplier.class);
+ factory.addAdvice(interceptor);
+ return (T) factory.getProxy();
+ }
+
+ else {
+ Enhancer enhancer = new Enhancer();
+ enhancer.setSuperclass(type);
+ enhancer.setInterfaces(new Class>[] {Supplier.class});
+ enhancer.setNamingPolicy(SpringNamingPolicy.INSTANCE);
+ enhancer.setCallbackType(org.springframework.cglib.proxy.MethodInterceptor.class);
+
+ Class> proxyClass = enhancer.createClass();
+ Object proxy = null;
+
+ if (objenesis.isWorthTrying()) {
+ try {
+ proxy = objenesis.newInstance(proxyClass, enhancer.getUseCache());
+ }
+ catch (ObjenesisException ex) {
+ logger.debug("Objenesis failed, falling back to default constructor", ex);
+ }
+ }
+
+ if (proxy == null) {
+ try {
+ proxy = ReflectionUtils.accessibleConstructor(proxyClass).newInstance();
+ }
+ catch (Throwable ex) {
+ throw new IllegalStateException("Unable to instantiate proxy " +
+ "via both Objenesis and default constructor fails as well", ex);
+ }
+ }
+
+ ((Factory) proxy).setCallbacks(new Callback[] {interceptor});
+ return (T) proxy;
+ }
+ }
+
+}
diff --git a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolverTests.java b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolverTests.java
new file mode 100644
index 0000000000..f33d50d008
--- /dev/null
+++ b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolverTests.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.messaging.handler.invocation.reactive;
+
+import org.junit.Test;
+import org.springframework.core.MethodParameter;
+import org.springframework.core.annotation.SynthesizingMethodParameter;
+import org.springframework.security.authentication.TestAuthentication;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.annotation.AuthenticationPrincipal;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.userdetails.UserDetails;
+import org.springframework.security.messaging.handler.invocation.ResolvableMethod;
+import reactor.core.publisher.Mono;
+
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+
+import static org.assertj.core.api.Assertions.*;
+
+/**
+ * @author Rob Winch
+ */
+public class AuthenticationPrincipalArgumentResolverTests {
+ private AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver();
+
+ @Test
+ public void supportsParameterWhenAuthenticationPrincipalThenTrue() {
+ assertThat(this.resolver.supportsParameter(arg0("authenticationPrincipalOnMonoUserDetails"))).isTrue();
+ }
+
+ @Test
+ public void resolveArgumentWhenAuthenticationPrincipalAndEmptyContextThenNull() {
+ Object result = this.resolver.resolveArgument(arg0("authenticationPrincipalOnMonoUserDetails"), null).block();
+ assertThat(result).isNull();
+ }
+
+ @Test
+ public void resolveArgumentWhenAuthenticationPrincipalThenFound() {
+ Authentication authentication = TestAuthentication.authenticatedUser();
+ Mono result = (Mono) this.resolver.resolveArgument(arg0("authenticationPrincipalOnMonoUserDetails"), null)
+ .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
+ .block();
+ assertThat(result.block()).isEqualTo(authentication.getPrincipal());
+ }
+
+ @SuppressWarnings("unused")
+ private void authenticationPrincipalOnMonoUserDetails(@AuthenticationPrincipal Mono user) {
+ }
+
+ @Test
+ public void supportsParameterWhenCurrentUserThenTrue() {
+ assertThat(this.resolver.supportsParameter(arg0("currentUserOnMonoUserDetails"))).isTrue();
+ }
+
+ @Test
+ public void resolveArgumentWhenMonoAndAuthenticationPrincipalThenFound() {
+ Authentication authentication = TestAuthentication.authenticatedUser();
+ Mono result = (Mono) this.resolver.resolveArgument(arg0("currentUserOnMonoUserDetails"), null)
+ .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
+ .block();
+ assertThat(result.block()).isEqualTo(authentication.getPrincipal());
+ }
+
+ @SuppressWarnings("unused")
+ private void currentUserOnMonoUserDetails(@CurrentUser Mono user) {
+ }
+
+ @Test
+ public void resolveArgumentWhenExpressionThenFound() {
+ Authentication authentication = TestAuthentication.authenticatedUser();
+ Mono result = (Mono) this.resolver.resolveArgument(arg0("authenticationPrincipalExpression"), null)
+ .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
+ .block();
+ assertThat(result.block()).isEqualTo(authentication.getName());
+ }
+
+ @SuppressWarnings("unused")
+ private void authenticationPrincipalExpression(@AuthenticationPrincipal(expression = "username") Mono username) {
+ }
+
+ @Test
+ public void supportsParameterWhenNotAnnotatedThenFalse() {
+ assertThat(this.resolver.supportsParameter(arg0("monoUserDetails"))).isFalse();
+ }
+
+ @SuppressWarnings("unused")
+ private void monoUserDetails(Mono user) {
+ }
+
+ private MethodParameter arg0(String methodName) {
+ ResolvableMethod method = ResolvableMethod.on(getClass()).named(methodName).method();
+ return new SynthesizingMethodParameter(method.method(), 0);
+ }
+
+ @AuthenticationPrincipal
+ @Retention(RetentionPolicy.RUNTIME)
+ @interface CurrentUser {}
+}