[LANG-1544] MethodUtils.invokeMethod NullPointerException in case of null in args list (#680)

* LANG-1544:
- Null guards in place to handle one or more nulls specified as one of the parameters of the method to invoke.
- Check for an exact match of the actual parameter types against all of the methods on the class. This prevents picking an "upcasted" method (i.e. int specified but a method with a double is chosen).
- Throw an IllegalStateException with a helpful message if multiple candidate methods were found. This happens when multiple Methods had the same "distance" from the desired parameter types. Before this change the algorithm would just chose the first one.
- Tests for the above.

Co-authored-by: mike.buck@pb.com <mike.buck@pb.com>
This commit is contained in:
Michael Buck 2020-12-22 15:12:14 -05:00 committed by GitHub
parent f22dbf0399
commit fb6a7e7788
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 120 additions and 27 deletions

View File

@ -30,14 +30,17 @@ import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.ClassUtils;
import org.apache.commons.lang3.ClassUtils.Interfaces;
import org.apache.commons.lang3.Validate;
import static java.util.stream.Collectors.toList;
/**
* <p>Utility reflection methods focused on {@link Method}s, originally from Commons BeanUtils.
* Differences from the BeanUtils version may be noted, especially where similar functionality
@ -742,49 +745,70 @@ public class MethodUtils {
Validate.notNull(cls, "cls");
Validate.notEmpty(methodName, "methodName");
// Address methods in superclasses
Method[] methodArray = cls.getDeclaredMethods();
final List<Class<?>> superclassList = ClassUtils.getAllSuperclasses(cls);
for (final Class<?> klass : superclassList) {
methodArray = ArrayUtils.addAll(methodArray, klass.getDeclaredMethods());
}
final List<Method> methods = Arrays.stream(cls.getDeclaredMethods())
.filter(method -> method.getName().equals(methodName))
.collect(toList());
Method inexactMatch = null;
for (final Method method : methodArray) {
if (methodName.equals(method.getName()) &&
Objects.deepEquals(parameterTypes, method.getParameterTypes())) {
ClassUtils.getAllSuperclasses(cls).stream()
.map(Class::getDeclaredMethods)
.flatMap(Arrays::stream)
.filter(method -> method.getName().equals(methodName))
.forEach(methods::add);
for (Method method : methods) {
if (Arrays.deepEquals(method.getParameterTypes(), parameterTypes)) {
return method;
} else if (methodName.equals(method.getName()) &&
ClassUtils.isAssignable(parameterTypes, method.getParameterTypes(), true)) {
if ((inexactMatch == null) || (distance(parameterTypes, method.getParameterTypes())
< distance(parameterTypes, inexactMatch.getParameterTypes()))) {
inexactMatch = method;
}
}
}
return inexactMatch;
final TreeMap<Integer, List<Method>> candidates = new TreeMap<>();
methods.stream()
.filter(method -> ClassUtils.isAssignable(parameterTypes, method.getParameterTypes(), true))
.forEach(method -> {
final int distance = distance(parameterTypes, method.getParameterTypes());
final List<Method> candidatesAtDistance = candidates.computeIfAbsent(distance, k -> new ArrayList<>());
candidatesAtDistance.add(method);
});
if (candidates.isEmpty()) {
return null;
}
final List<Method> bestCandidates = candidates.values().iterator().next();
if (bestCandidates.size() == 1) {
return bestCandidates.get(0);
}
throw new IllegalStateException(
String.format("Found multiple candidates for method %s on class %s : %s",
methodName + Arrays.stream(parameterTypes).map(String::valueOf).collect(Collectors.joining(",", "(", ")")),
cls.getName(),
bestCandidates.stream().map(Method::toString).collect(Collectors.joining(",", "[", "]")))
);
}
/**
* <p>Returns the aggregate number of inheritance hops between assignable argument class types. Returns -1
* if the arguments aren't assignable. Fills a specific purpose for getMatchingMethod and is not generalized.</p>
* @param classArray
* @param toClassArray
* @param fromClassArray the Class array to calculate the distance from.
* @param toClassArray the Class array to calculate the distance to.
* @return the aggregate number of inheritance hops between assignable argument class types.
*/
private static int distance(final Class<?>[] classArray, final Class<?>[] toClassArray) {
private static int distance(final Class<?>[] fromClassArray, final Class<?>[] toClassArray) {
int answer = 0;
if (!ClassUtils.isAssignable(classArray, toClassArray, true)) {
if (!ClassUtils.isAssignable(fromClassArray, toClassArray, true)) {
return -1;
}
for (int offset = 0; offset < classArray.length; offset++) {
for (int offset = 0; offset < fromClassArray.length; offset++) {
// Note InheritanceUtils.distance() uses different scoring system.
if (classArray[offset].equals(toClassArray[offset])) {
final Class<?> aClass = fromClassArray[offset];
final Class<?> toClass = toClassArray[offset];
if (aClass == null || aClass.equals(toClass)) {
continue;
} else if (ClassUtils.isAssignable(classArray[offset], toClassArray[offset], true)
&& !ClassUtils.isAssignable(classArray[offset], toClassArray[offset], false)) {
} else if (ClassUtils.isAssignable(aClass, toClass, true)
&& !ClassUtils.isAssignable(aClass, toClass, false)) {
answer++;
} else {
answer = answer + 2;

View File

@ -29,6 +29,7 @@ import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.awt.Color;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Arrays;
@ -1018,4 +1019,72 @@ public class MethodUtilsTest {
distanceMethod.setAccessible(false);
}
@Test
public void testGetMatchingMethod() throws NoSuchMethodException {
assertEquals(MethodUtils.getMatchingMethod(GetMatchingMethodClass.class, "testMethod"),
GetMatchingMethodClass.class.getMethod("testMethod"));
assertEquals(MethodUtils.getMatchingMethod(GetMatchingMethodClass.class, "testMethod", Long.TYPE),
GetMatchingMethodClass.class.getMethod("testMethod", Long.TYPE));
assertEquals(MethodUtils.getMatchingMethod(GetMatchingMethodClass.class, "testMethod", Long.class),
GetMatchingMethodClass.class.getMethod("testMethod", Long.class));
assertEquals(MethodUtils.getMatchingMethod(GetMatchingMethodClass.class, "testMethod", (Class<?>) null),
GetMatchingMethodClass.class.getMethod("testMethod", Long.class));
assertThrows(IllegalStateException.class,
() -> MethodUtils.getMatchingMethod(GetMatchingMethodClass.class, "testMethod2", (Class<?>) null));
assertEquals(MethodUtils.getMatchingMethod(GetMatchingMethodClass.class, "testMethod3", Long.TYPE, Long.class),
GetMatchingMethodClass.class.getMethod("testMethod3", Long.TYPE, Long.class));
assertEquals(MethodUtils.getMatchingMethod(GetMatchingMethodClass.class, "testMethod3", Long.class, Long.TYPE),
GetMatchingMethodClass.class.getMethod("testMethod3", Long.class, Long.TYPE));
assertEquals(MethodUtils.getMatchingMethod(GetMatchingMethodClass.class, "testMethod3", null, Long.TYPE),
GetMatchingMethodClass.class.getMethod("testMethod3", Long.class, Long.TYPE));
assertEquals(MethodUtils.getMatchingMethod(GetMatchingMethodClass.class, "testMethod3", Long.TYPE, null),
GetMatchingMethodClass.class.getMethod("testMethod3", Long.TYPE, Long.class));
assertThrows(IllegalStateException.class,
() -> MethodUtils.getMatchingMethod(GetMatchingMethodClass.class, "testMethod4", null, null));
}
private static final class GetMatchingMethodClass {
public void testMethod() {
}
public void testMethod(final Long aLong) {
}
public void testMethod(final long aLong) {
}
public void testMethod2(final Long aLong) {
}
public void testMethod2(final Color aColor) {
}
public void testMethod2(final long aLong) {
}
public void testMethod3(final long aLong, final Long anotherLong) {
}
public void testMethod3(final Long aLong, final long anotherLong) {
}
public void testMethod3(final Long aLong, final Long anotherLong) {
}
public void testMethod4(final Long aLong, final Long anotherLong) {
}
public void testMethod4(final Color aColor1, final Color aColor2) {
}
}
}