diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java index 97dddbdfe52..e6ed475a7be 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java @@ -28,11 +28,14 @@ import org.elasticsearch.painless.spi.Whitelist; import org.objectweb.asm.util.Printer; import java.lang.reflect.Constructor; +import java.lang.reflect.Method; import java.net.MalformedURLException; import java.net.URL; import java.security.CodeSource; import java.security.SecureClassLoader; import java.security.cert.Certificate; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; @@ -89,16 +92,11 @@ final class Compiler { */ @Override public Class findClass(String name) throws ClassNotFoundException { - if (scriptClass.getName().equals(name)) { - return scriptClass; + Class found = additionalClasses.get(name); + if (found != null) { + return found; } - if (factoryClass != null && factoryClass.getName().equals(name)) { - return factoryClass; - } - if (statefulFactoryClass != null && statefulFactoryClass.getName().equals(name)) { - return statefulFactoryClass; - } - Class found = painlessLookup.canonicalTypeNameToType(name.replace('$', '.')); + found = painlessLookup.canonicalTypeNameToType(name.replace('$', '.')); return found != null ? found : super.findClass(name); } @@ -155,21 +153,16 @@ final class Compiler { */ private final Class scriptClass; - /** - * The class/interface to create the {@code scriptClass} instance. - */ - private final Class factoryClass; - - /** - * An optional class/interface to create the {@code factoryClass} instance. - */ - private final Class statefulFactoryClass; - /** * The whitelist the script will use. */ private final PainlessLookup painlessLookup; + /** + * Classes that do not exist in the lookup, but are needed by the script factories. + */ + private final Map> additionalClasses; + /** * Standard constructor. * @param scriptClass The class/interface the script will implement. @@ -179,9 +172,36 @@ final class Compiler { */ Compiler(Class scriptClass, Class factoryClass, Class statefulFactoryClass, PainlessLookup painlessLookup) { this.scriptClass = scriptClass; - this.factoryClass = factoryClass; - this.statefulFactoryClass = statefulFactoryClass; this.painlessLookup = painlessLookup; + Map> additionalClasses = new HashMap<>(); + additionalClasses.put(scriptClass.getName(), scriptClass); + addFactoryMethod(additionalClasses, factoryClass, "newInstance"); + addFactoryMethod(additionalClasses, statefulFactoryClass, "newFactory"); + addFactoryMethod(additionalClasses, statefulFactoryClass, "newInstance"); + this.additionalClasses = Collections.unmodifiableMap(additionalClasses); + } + + private static void addFactoryMethod(Map> additionalClasses, Class factoryClass, String methodName) { + if (factoryClass == null) { + return; + } + + Method factoryMethod = null; + for (Method method : factoryClass.getMethods()) { + if (methodName.equals(method.getName())) { + factoryMethod = method; + break; + } + } + if (factoryMethod == null) { + return; + } + + additionalClasses.put(factoryClass.getName(), factoryClass); + for (int i = 0; i < factoryMethod.getParameterTypes().length; ++i) { + Class parameterClazz = factoryMethod.getParameterTypes()[i]; + additionalClasses.put(parameterClazz.getName(), parameterClazz); + } } /**