Fix Painless Lambdas for Java 9 (#24070)

Replaces LambdaMetaFactory with LambdaBootstrap, a custom solution for lambdas in Painless using a design similar to LambdaMetaFactory, but allows for custom adaptation of types which recent changes to LambdaMetaFactory no longer allowed.
This commit is contained in:
Jack Conradson 2017-04-24 09:58:02 -07:00 committed by GitHub
parent 026bf2e3ee
commit 30cc33e2e5
16 changed files with 888 additions and 459 deletions

View File

@ -79,9 +79,19 @@ final class Compiler {
* @param bytes The generated byte code.
* @return A Class object extending {@link PainlessScript}.
*/
Class<? extends PainlessScript> define(String name, byte[] bytes) {
Class<? extends PainlessScript> defineScript(String name, byte[] bytes) {
return defineClass(name, bytes, 0, bytes.length, CODESOURCE).asSubclass(PainlessScript.class);
}
/**
* Generates a Class object for a lambda method.
* @param name The name of the class.
* @param bytes The generated byte code.
* @return A Class object.
*/
Class<?> defineLambda(String name, byte[] bytes) {
return defineClass(name, bytes, 0, bytes.length);
}
}
/**
@ -110,7 +120,7 @@ final class Compiler {
root.write();
try {
Class<? extends PainlessScript> clazz = loader.define(CLASS_NAME, root.getBytes());
Class<? extends PainlessScript> clazz = loader.defineScript(CLASS_NAME, root.getBytes());
clazz.getField("$DEFINITION").set(null, definition);
java.lang.reflect.Constructor<? extends PainlessScript> constructor =
clazz.getConstructor(String.class, String.class, BitSet.class);

View File

@ -23,7 +23,6 @@ import org.elasticsearch.painless.Definition.Method;
import org.elasticsearch.painless.Definition.RuntimeClass;
import java.lang.invoke.CallSite;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodHandles.Lookup;
@ -132,7 +131,7 @@ public final class Def {
} catch (final ReflectiveOperationException roe) {
throw new AssertionError(roe);
}
// lookup up the factory for arraylength MethodHandle (intrinsic) from Java 9:
// https://bugs.openjdk.java.net/browse/JDK-8156915
MethodHandle arrayLengthMHFactory;
@ -150,7 +149,7 @@ public final class Def {
static <T extends Throwable> void rethrow(Throwable t) throws T {
throw (T) t;
}
/** Returns an array length getter MethodHandle for the given array type */
static MethodHandle arrayLengthGetter(Class<?> arrayType) {
if (JAVA9_ARRAY_LENGTH_MH_FACTORY != null) {
@ -206,7 +205,7 @@ public final class Def {
}
}
}
throw new IllegalArgumentException("Unable to find dynamic method [" + name + "] with [" + arity + "] arguments " +
"for class [" + receiverClass.getCanonicalName() + "].");
}
@ -239,7 +238,7 @@ public final class Def {
if (recipeString.isEmpty()) {
return lookupMethodInternal(definition, receiverClass, name, numArguments - 1).handle;
}
// convert recipe string to a bitset for convenience (the code below should be refactored...)
BitSet lambdaArgs = new BitSet();
for (int i = 0; i < recipeString.length(); i++) {
@ -247,7 +246,7 @@ public final class Def {
}
// otherwise: first we have to compute the "real" arity. This is because we have extra arguments:
// e.g. f(a, g(x), b, h(y), i()) looks like f(a, g, x, b, h, y, i).
// e.g. f(a, g(x), b, h(y), i()) looks like f(a, g, x, b, h, y, i).
int arity = callSiteType.parameterCount() - 1;
int upTo = 1;
for (int i = 1; i < numArguments; i++) {
@ -257,7 +256,7 @@ public final class Def {
arity -= numCaptures;
}
}
// lookup the method with the proper arity, then we know everything (e.g. interface types of parameters).
// based on these we can finally link any remaining lambdas that were deferred.
Method method = lookupMethodInternal(definition, receiverClass, name, arity);
@ -268,7 +267,7 @@ public final class Def {
for (int i = 1; i < numArguments; i++) {
// its a functional reference, replace the argument with an impl
if (lambdaArgs.get(i - 1)) {
// decode signature of form 'type.call,2'
// decode signature of form 'type.call,2'
String signature = (String) args[upTo++];
int separator = signature.lastIndexOf('.');
int separator2 = signature.indexOf(',');
@ -313,10 +312,10 @@ public final class Def {
replaced += numCaptures;
}
}
return handle;
}
/**
* Returns an implementation of interfaceClass that calls receiverClass.name
* <p>
@ -335,7 +334,7 @@ public final class Def {
return lookupReferenceInternal(definition, lookup, interfaceType, implMethod.owner.name,
implMethod.name, receiverClass);
}
/** Returns a method handle to an implementation of clazz, given method reference signature. */
private static MethodHandle lookupReferenceInternal(Definition definition, Lookup lookup,
Definition.Type clazz, String type, String call, Class<?>... captures)
@ -351,47 +350,37 @@ public final class Def {
int arity = interfaceMethod.arguments.size() + captures.length;
final MethodHandle handle;
try {
MethodHandle accessor = lookup.findStaticGetter(lookup.lookupClass(),
getUserFunctionHandleFieldName(call, arity),
MethodHandle accessor = lookup.findStaticGetter(lookup.lookupClass(),
getUserFunctionHandleFieldName(call, arity),
MethodHandle.class);
handle = (MethodHandle) accessor.invokeExact();
handle = (MethodHandle)accessor.invokeExact();
} catch (NoSuchFieldException | IllegalAccessException e) {
// is it a synthetic method? If we generated the method ourselves, be more helpful. It can only fail
// because the arity does not match the expected interface type.
if (call.contains("$")) {
throw new IllegalArgumentException("Incorrect number of parameters for [" + interfaceMethod.name +
throw new IllegalArgumentException("Incorrect number of parameters for [" + interfaceMethod.name +
"] in [" + clazz.clazz + "]");
}
throw new IllegalArgumentException("Unknown call [" + call + "] with [" + arity + "] arguments.");
}
ref = new FunctionRef(clazz, interfaceMethod, handle, captures.length);
ref = new FunctionRef(clazz, interfaceMethod, call, handle.type(), captures.length);
} else {
// whitelist lookup
ref = new FunctionRef(definition, clazz, type, call, captures.length);
}
final CallSite callSite;
if (ref.needsBridges()) {
callSite = LambdaMetafactory.altMetafactory(lookup,
ref.invokedName,
ref.invokedType,
ref.samMethodType,
ref.implMethod,
ref.samMethodType,
LambdaMetafactory.FLAG_BRIDGES,
1,
ref.interfaceMethodType);
} else {
callSite = LambdaMetafactory.altMetafactory(lookup,
ref.invokedName,
ref.invokedType,
ref.samMethodType,
ref.implMethod,
ref.samMethodType,
0);
}
final CallSite callSite = LambdaBootstrap.lambdaBootstrap(
lookup,
ref.interfaceMethodName,
ref.factoryMethodType,
ref.interfaceMethodType,
ref.delegateClassName,
ref.delegateInvokeType,
ref.delegateMethodName,
ref.delegateMethodType
);
return callSite.dynamicInvoker().asType(MethodType.methodType(clazz.clazz, captures));
}
/** gets the field name used to lookup up the MethodHandle for a function. */
public static String getUserFunctionHandleFieldName(String name, int arity) {
return "handle$" + name + "$" + arity;
@ -595,7 +584,7 @@ public final class Def {
throw new IllegalArgumentException("Attempting to address a non-array type " +
"[" + receiverClass.getCanonicalName() + "] as an array.");
}
/** Helper class for isolating MethodHandles and methods to get iterators over arrays
* (to emulate "enhanced for loop" using MethodHandles). These cause boxing, and are not as efficient
* as they could be, but works.

View File

@ -1,5 +1,6 @@
package org.elasticsearch.painless;
import java.util.List;
import java.util.function.Function;
/*
@ -25,11 +26,11 @@ import java.util.function.Function;
public class FeatureTest {
private int x;
private int y;
/** empty ctor */
public FeatureTest() {
}
/** ctor with params */
public FeatureTest(int x, int y) {
this.x = x;
@ -60,14 +61,18 @@ public class FeatureTest {
public static boolean overloadedStatic() {
return true;
}
/** static method that returns what you ask it */
public static boolean overloadedStatic(boolean whatToReturn) {
return whatToReturn;
}
/** method taking two functions! */
public Object twoFunctionsOfX(Function<Object,Object> f, Function<Object,Object> g) {
return f.apply(g.apply(x));
}
public void listInput(List<Object> list) {
}
}

View File

@ -20,125 +20,138 @@
package org.elasticsearch.painless;
import org.elasticsearch.painless.Definition.Method;
import org.objectweb.asm.Handle;
import org.objectweb.asm.Opcodes;
import org.elasticsearch.painless.Definition.Type;
import org.elasticsearch.painless.api.Augmentation;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.lang.reflect.Modifier;
/**
* Reference to a function or lambda.
import static org.elasticsearch.painless.WriterConstants.CLASS_NAME;
import static org.objectweb.asm.Opcodes.H_INVOKEINTERFACE;
import static org.objectweb.asm.Opcodes.H_INVOKESTATIC;
import static org.objectweb.asm.Opcodes.H_INVOKEVIRTUAL;
import static org.objectweb.asm.Opcodes.H_NEWINVOKESPECIAL;
/**
* Reference to a function or lambda.
* <p>
* Once you have created one of these, you have "everything you need" to call LambdaMetaFactory
* either statically from bytecode with invokedynamic, or at runtime from Java.
* Once you have created one of these, you have "everything you need" to call {@link LambdaBootstrap}
* either statically from bytecode with invokedynamic, or at runtime from Java.
*/
public class FunctionRef {
/** Function Object's method name */
public final String invokedName;
/** CallSite signature */
public final MethodType invokedType;
/** Implementation method */
public final MethodHandle implMethod;
/** Function Object's method signature */
public final MethodType samMethodType;
/** When bridging is required, request this bridge interface */
/** functional interface method name */
public final String interfaceMethodName;
/** factory (CallSite) method signature */
public final MethodType factoryMethodType;
/** functional interface method signature */
public final MethodType interfaceMethodType;
/** ASM "Handle" to the method, for the constant pool */
public final Handle implMethodASM;
/** class of the delegate method to be called */
public final String delegateClassName;
/** the invocation type of the delegate method */
public final int delegateInvokeType;
/** the name of the delegate method */
public final String delegateMethodName;
/** delegate method signature */
public final MethodType delegateMethodType;
/** interface method */
public final Method interfaceMethod;
/** delegate method */
public final Method delegateMethod;
/** factory method type descriptor */
public final String factoryDescriptor;
/** functional interface method as type */
public final org.objectweb.asm.Type interfaceType;
/** delegate method type method as type */
public final org.objectweb.asm.Type delegateType;
/**
* Creates a new FunctionRef, which will resolve {@code type::call} from the whitelist.
* @param definition the whitelist against which this script is being compiled
* @param expected interface type to implement.
* @param expected functional interface type to implement.
* @param type the left hand side of a method reference expression
* @param call the right hand side of a method reference expression
* @param numCaptures number of captured arguments
*/
public FunctionRef(Definition definition, Definition.Type expected, String type, String call,
int numCaptures) {
this(expected, expected.struct.getFunctionalMethod(),
lookup(definition, expected, type, call, numCaptures > 0), numCaptures);
*/
public FunctionRef(Definition definition, Type expected, String type, String call, int numCaptures) {
this(expected, expected.struct.getFunctionalMethod(), lookup(definition, expected, type, call, numCaptures > 0), numCaptures);
}
/**
* Creates a new FunctionRef (already resolved)
* @param expected interface type to implement
* @param method functional interface method
* @param impl implementation method
* @param expected functional interface type to implement
* @param interfaceMethod functional interface method
* @param delegateMethod implementation method
* @param numCaptures number of captured arguments
*/
public FunctionRef(Definition.Type expected, Definition.Method method, Definition.Method impl, int numCaptures) {
// e.g. compareTo
invokedName = method.name;
// e.g. (Object)Comparator
MethodType implType = impl.getMethodType();
// only include captured parameters as arguments
invokedType = MethodType.methodType(expected.clazz,
implType.dropParameterTypes(numCaptures, implType.parameterCount()));
// e.g. (Object,Object)int
interfaceMethodType = method.getMethodType().dropParameterTypes(0, 1);
*/
public FunctionRef(Type expected, Method interfaceMethod, Method delegateMethod, int numCaptures) {
MethodType delegateMethodType = delegateMethod.getMethodType();
final int tag;
if ("<init>".equals(impl.name)) {
tag = Opcodes.H_NEWINVOKESPECIAL;
} else if (Modifier.isStatic(impl.modifiers)) {
tag = Opcodes.H_INVOKESTATIC;
} else if (impl.owner.clazz.isInterface()) {
tag = Opcodes.H_INVOKEINTERFACE;
interfaceMethodName = interfaceMethod.name;
factoryMethodType = MethodType.methodType(expected.clazz,
delegateMethodType.dropParameterTypes(numCaptures, delegateMethodType.parameterCount()));
interfaceMethodType = interfaceMethod.getMethodType().dropParameterTypes(0, 1);
// the Painless$Script class can be inferred if owner is null
if (delegateMethod.owner == null) {
delegateClassName = CLASS_NAME;
} else if (delegateMethod.augmentation) {
delegateClassName = Augmentation.class.getName();
} else {
tag = Opcodes.H_INVOKEVIRTUAL;
delegateClassName = delegateMethod.owner.clazz.getName();
}
final String owner;
final boolean ownerIsInterface;
if (impl.owner == null) {
// owner == null: script class itself
ownerIsInterface = false;
owner = WriterConstants.CLASS_TYPE.getInternalName();
} else if (impl.augmentation) {
ownerIsInterface = false;
owner = WriterConstants.AUGMENTATION_TYPE.getInternalName();
if ("<init>".equals(delegateMethod.name)) {
delegateInvokeType = H_NEWINVOKESPECIAL;
} else if (Modifier.isStatic(delegateMethod.modifiers)) {
delegateInvokeType = H_INVOKESTATIC;
} else if (delegateMethod.owner.clazz.isInterface()) {
delegateInvokeType = H_INVOKEINTERFACE;
} else {
ownerIsInterface = impl.owner.clazz.isInterface();
owner = impl.owner.type.getInternalName();
delegateInvokeType = H_INVOKEVIRTUAL;
}
implMethodASM = new Handle(tag, owner, impl.name, impl.method.getDescriptor(), ownerIsInterface);
implMethod = impl.handle;
// remove any prepended captured arguments for the 'natural' signature.
samMethodType = adapt(interfaceMethodType, impl.getMethodType().dropParameterTypes(0, numCaptures));
delegateMethodName = delegateMethod.name;
this.delegateMethodType = delegateMethodType.dropParameterTypes(0, numCaptures);
this.interfaceMethod = interfaceMethod;
this.delegateMethod = delegateMethod;
factoryDescriptor = factoryMethodType.toMethodDescriptorString();
interfaceType = org.objectweb.asm.Type.getMethodType(interfaceMethodType.toMethodDescriptorString());
delegateType = org.objectweb.asm.Type.getMethodType(this.delegateMethodType.toMethodDescriptorString());
}
/**
* Creates a new FunctionRef (low level).
* <p>
* This will <b>not</b> set implMethodASM. It is for runtime use only.
* Creates a new FunctionRef (low level).
* It is for runtime use only.
*/
public FunctionRef(Definition.Type expected, Definition.Method method, MethodHandle impl, int numCaptures) {
// e.g. compareTo
invokedName = method.name;
// e.g. (Object)Comparator
MethodType implType = impl.type();
// only include captured parameters as arguments
invokedType = MethodType.methodType(expected.clazz,
implType.dropParameterTypes(numCaptures, implType.parameterCount()));
// e.g. (Object,Object)int
interfaceMethodType = method.getMethodType().dropParameterTypes(0, 1);
public FunctionRef(Type expected, Method interfaceMethod, String delegateMethodName, MethodType delegateMethodType, int numCaptures) {
interfaceMethodName = interfaceMethod.name;
factoryMethodType = MethodType.methodType(expected.clazz,
delegateMethodType.dropParameterTypes(numCaptures, delegateMethodType.parameterCount()));
interfaceMethodType = interfaceMethod.getMethodType().dropParameterTypes(0, 1);
implMethod = impl;
implMethodASM = null;
// remove any prepended captured arguments for the 'natural' signature.
samMethodType = adapt(interfaceMethodType, impl.type().dropParameterTypes(0, numCaptures));
delegateClassName = CLASS_NAME;
delegateInvokeType = H_INVOKESTATIC;
this.delegateMethodName = delegateMethodName;
this.delegateMethodType = delegateMethodType.dropParameterTypes(0, numCaptures);
this.interfaceMethod = null;
delegateMethod = null;
factoryDescriptor = null;
interfaceType = null;
delegateType = null;
}
/**
/**
* Looks up {@code type::call} from the whitelist, and returns a matching method.
*/
private static Definition.Method lookup(Definition definition, Definition.Type expected,
String type, String call, boolean receiverCaptured) {
String type, String call, boolean receiverCaptured) {
// check its really a functional interface
// for e.g. Comparable
Method method = expected.struct.getFunctionalMethod();
@ -177,27 +190,4 @@ public class FunctionRef {
}
return impl;
}
/** Returns true if you should ask LambdaMetaFactory to construct a bridge for the interface signature */
public boolean needsBridges() {
// currently if the interface differs, we ask for a bridge, but maybe we should do smarter checking?
// either way, stuff will fail if its wrong :)
return interfaceMethodType.equals(samMethodType) == false;
}
/**
* If the interface expects a primitive type to be returned, we can't return Object,
* But we can set SAM to the wrapper version, and a cast will take place
*/
private MethodType adapt(MethodType expected, MethodType actual) {
// add some checks, now that we've set everything up, to deliver exceptions as early as possible.
if (expected.parameterCount() != actual.parameterCount()) {
throw new IllegalArgumentException("Incorrect number of parameters for [" + invokedName +
"] in [" + invokedType.returnType() + "]");
}
if (expected.returnType().isPrimitive() && actual.returnType() == Object.class) {
actual = actual.changeReturnType(MethodType.methodType(expected.returnType()).wrap().returnType());
}
return actual;
}
}

View File

@ -0,0 +1,530 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.painless;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.Handle;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.GeneratorAdapter;
import org.objectweb.asm.commons.Method;
import java.lang.invoke.CallSite;
import java.lang.invoke.ConstantCallSite;
import java.lang.invoke.LambdaConversionException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.concurrent.atomic.AtomicLong;
import static java.lang.invoke.MethodHandles.Lookup;
import static org.elasticsearch.painless.Compiler.Loader;
import static org.elasticsearch.painless.WriterConstants.CLASS_VERSION;
import static org.elasticsearch.painless.WriterConstants.DELEGATE_BOOTSTRAP_HANDLE;
import static org.objectweb.asm.Opcodes.ACC_FINAL;
import static org.objectweb.asm.Opcodes.ACC_PRIVATE;
import static org.objectweb.asm.Opcodes.ACC_PUBLIC;
import static org.objectweb.asm.Opcodes.ACC_STATIC;
import static org.objectweb.asm.Opcodes.ACC_SUPER;
import static org.objectweb.asm.Opcodes.ACC_SYNTHETIC;
import static org.objectweb.asm.Opcodes.H_INVOKEINTERFACE;
import static org.objectweb.asm.Opcodes.H_INVOKESTATIC;
import static org.objectweb.asm.Opcodes.H_INVOKEVIRTUAL;
import static org.objectweb.asm.Opcodes.H_NEWINVOKESPECIAL;
/**
* LambdaBootstrap is used to generate all the code necessary to execute
* lambda functions and method references within Painless. The code generation
* used here is based upon the following article:
* http://cr.openjdk.java.net/~briangoetz/lambda/lambda-translation.html
* However, it is a simplified version as Painless has no concept of generics
* or serialization. LambdaBootstrap is being used as a replacement for
* {@link java.lang.invoke.LambdaMetafactory} since the Painless casting model
* cannot be fully supported through this class.
*
* For each lambda function/method reference used within a Painless script
* a class will be generated at link-time using the
* {@link LambdaBootstrap#lambdaBootstrap} method that contains the following:
* 1. member fields for any captured variables
* 2. a constructor that will take in captured variables and assign them to
* their respective member fields
* 3. if there are captures, a factory method that will take in captured
* variables and delegate them to the constructor
* 4. a method that will load the member fields representing captured variables
* and take in any other necessary values based on the arguments passed into the
* lambda function/reference method; it will then make a delegated call to the
* actual lambda function/reference method
*
* Take for example the following Painless script:
*
* {@code
* List list1 = new ArrayList(); "
* list1.add(2); "
* List list2 = new ArrayList(); "
* list1.forEach(x -> list2.add(x));"
* return list[0]"
* }
*
* The script contains a lambda function with a captured variable.
* The following Lambda class would be generated:
*
* {@code
* public static final class $$Lambda0 implements Consumer {
* private List arg$0;
*
* public $$Lambda0(List arg$0) {
* this.arg$0 = arg$0;
* }
*
* public static $$Lambda0 get$Lambda(List arg$0) {
* return $$Lambda0(arg$0);
* }
*
* public void accept(Object val$0) {
* Painless$Script.lambda$0(this.arg$0, val$0);
* }
* }
*
* public class Painless$Script implements ... {
* ...
* public static lambda$0(List list2, Object x) {
* list2.add(x);
* }
* ...
* }
* }
*
* Note if the above didn't have a captured variable then
* the factory method get$Lambda would not have been generated.
* Also the accept method actually uses an invokedynamic
* instruction to call the lambda$0 method so that
* {@link MethodHandle#asType} can be used to do the necessary
* conversions between argument types without having to hard
* code them.
*
* When the {@link CallSite} is linked the linked method depends
* on whether or not there are captures. If there are no captures
* the same instance of the generated lambda class will be
* returned each time by the factory method as there are no
* changing values other than the arguments. If there are
* captures a new instance of the generated lambda class will
* be returned each time with the captures passed into the
* factory method to be stored in the member fields.
*/
public final class LambdaBootstrap {
/**
* Metadata for a captured variable used during code generation.
*/
private static final class Capture {
private final String name;
private final Type type;
private final String desc;
/**
* Converts incoming parameters into the name, type, and
* descriptor for the captured argument.
* @param count The captured argument count
* @param type The class type of the captured argument
*/
private Capture(int count, Class<?> type) {
this.name = "arg$" + count;
this.type = Type.getType(type);
this.desc = this.type.getDescriptor();
}
}
/**
* A counter used to generate a unique name
* for each lambda function/reference class.
*/
private static final AtomicLong COUNTER = new AtomicLong(0);
/**
* Generates a lambda class for a lambda function/method reference
* within a Painless script. Variables with the prefix interface are considered
* to represent values for code generated for the lambda class. Variables with
* the prefix delegate are considered to represent values for code generated
* within the Painless script. The interface method delegates (calls) to the
* delegate method.
* @param lookup Standard {@link MethodHandles#lookup}
* @param interfaceMethodName Name of functional interface method that is called
* @param factoryMethodType The type of method to be linked to this CallSite; note that
* captured types are based on the parameters for this method
* @param interfaceMethodType The type of method representing the functional interface method
* @param delegateClassName The name of the Painless script class
* @param delegateInvokeType The type of method call to be made
* (static, virtual, interface, or constructor)
* @param delegateMethodName The name of the method to be called in the Painless script class
* @param delegateMethodType The type of method call in the Painless script class without
* the captured types
* @return A {@link CallSite} linked to a factory method for creating a lambda class
* that implements the expected functional interface
* @throws LambdaConversionException Thrown when an illegal type conversion occurs at link time
*/
public static CallSite lambdaBootstrap(
Lookup lookup,
String interfaceMethodName,
MethodType factoryMethodType,
MethodType interfaceMethodType,
String delegateClassName,
int delegateInvokeType,
String delegateMethodName,
MethodType delegateMethodType)
throws LambdaConversionException {
String factoryMethodName = "get$lambda";
String lambdaClassName = lookup.lookupClass().getName().replace('.', '/') +
"$$Lambda" + COUNTER.getAndIncrement();
Type lambdaClassType = Type.getType("L" + lambdaClassName + ";");
validateTypes(interfaceMethodType, delegateMethodType);
ClassWriter cw =
beginLambdaClass(lambdaClassName, factoryMethodType.returnType().getName());
Capture[] captures = generateCaptureFields(cw, factoryMethodType);
Method constructorMethod =
generateLambdaConstructor(cw, lambdaClassType, factoryMethodType, captures);
if (captures.length > 0) {
generateFactoryMethod(
cw, factoryMethodName, factoryMethodType, lambdaClassType, constructorMethod);
}
generateInterfaceMethod(cw, factoryMethodType, lambdaClassName, lambdaClassType,
interfaceMethodName, interfaceMethodType, delegateClassName, delegateInvokeType,
delegateMethodName, delegateMethodType, captures);
endLambdaClass(cw);
Class<?> lambdaClass =
createLambdaClass((Loader)lookup.lookupClass().getClassLoader(), cw, lambdaClassName);
if (captures.length > 0) {
return createCaptureCallSite(lookup, factoryMethodName, factoryMethodType, lambdaClass);
} else {
return createNoCaptureCallSite(factoryMethodType, lambdaClass);
}
}
/**
* Validates some conversions at link time. Currently, only ensures that the lambda method
* with a return value cannot delegate to a delegate method with no return type.
*/
private static void validateTypes(MethodType interfaceMethodType, MethodType delegateMethodType)
throws LambdaConversionException {
if (interfaceMethodType.returnType() != void.class &&
delegateMethodType.returnType() == void.class) {
throw new LambdaConversionException("lambda expects return type ["
+ interfaceMethodType.returnType() + "], but found return type [void]");
}
}
/**
* Creates the {@link ClassWriter} to be used for the lambda class generation.
*/
private static ClassWriter beginLambdaClass(String lambdaClassName, String lambdaInterface) {
String baseClass = Object.class.getName().replace('.', '/');
lambdaInterface = lambdaInterface.replace('.', '/');
int modifiers = ACC_PUBLIC | ACC_STATIC | ACC_SUPER | ACC_FINAL | ACC_SYNTHETIC;
ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS);
cw.visit(CLASS_VERSION,
modifiers, lambdaClassName, null, baseClass, new String[] {lambdaInterface});
return cw;
}
/**
* Generates member fields for captured variables
* based on the parameters for the factory method.
* @return An array of captured variable metadata
* for generating method arguments later on
*/
private static Capture[] generateCaptureFields(ClassWriter cw, MethodType factoryMethodType) {
int captureTotal = factoryMethodType.parameterCount();
Capture[] captures = new Capture[captureTotal];
for (int captureCount = 0; captureCount < captureTotal; ++captureCount) {
captures[captureCount] =
new Capture(captureCount, factoryMethodType.parameterType(captureCount));
int modifiers = ACC_PRIVATE + ACC_FINAL;
FieldVisitor fv = cw.visitField(
modifiers, captures[captureCount].name, captures[captureCount].desc, null, null);
fv.visitEnd();
}
return captures;
}
/**
* Generates a constructor that will take in captured
* arguments if any and store them in their respective
* member fields.
* @return The constructor {@link Method} used to
* call this method from a potential factory method
* if there are captured arguments
*/
private static Method generateLambdaConstructor(
ClassWriter cw,
Type lambdaClassType,
MethodType factoryMethodType,
Capture[] captures) {
String conName = "<init>";
String conDesc = factoryMethodType.changeReturnType(void.class).toMethodDescriptorString();
Method conMeth = new Method(conName, conDesc);
Type baseConType = Type.getType(Object.class);
Method baseConMeth = new Method(conName,
MethodType.methodType(void.class).toMethodDescriptorString());
int modifiers = ACC_PUBLIC;
GeneratorAdapter constructor = new GeneratorAdapter(modifiers, conMeth,
cw.visitMethod(modifiers, conName, conDesc, null, null));
constructor.visitCode();
constructor.loadThis();
constructor.invokeConstructor(baseConType, baseConMeth);
for (int captureCount = 0; captureCount < captures.length; ++captureCount) {
constructor.loadThis();
constructor.loadArg(captureCount);
constructor.putField(
lambdaClassType, captures[captureCount].name, captures[captureCount].type);
}
constructor.returnValue();
constructor.endMethod();
return conMeth;
}
/**
* Generates a factory method that can be used to create the lambda class
* if there are captured variables.
*/
private static void generateFactoryMethod(
ClassWriter cw,
String factoryMethodName,
MethodType factoryMethodType,
Type lambdaClassType,
Method constructorMethod) {
String facDesc = factoryMethodType.toMethodDescriptorString();
Method facMeth = new Method(factoryMethodName, facDesc);
int modifiers = ACC_PUBLIC | ACC_STATIC;
GeneratorAdapter factory = new GeneratorAdapter(modifiers, facMeth,
cw.visitMethod(modifiers, factoryMethodName, facDesc, null, null));
factory.visitCode();
factory.newInstance(lambdaClassType);
factory.dup();
factory.loadArgs();
factory.invokeConstructor(lambdaClassType, constructorMethod);
factory.returnValue();
factory.endMethod();
}
/**
* Generates the interface method that will delegate (call) to the delegate method.
*/
private static void generateInterfaceMethod(
ClassWriter cw,
MethodType factoryMethodType,
String lambdaClassName,
Type lambdaClassType,
String interfaceMethodName,
MethodType interfaceMethodType,
String delegateClassName,
int delegateInvokeType,
String delegateMethodName,
MethodType delegateMethodType,
Capture[] captures)
throws LambdaConversionException {
String lamDesc = interfaceMethodType.toMethodDescriptorString();
Method lamMeth = new Method(lambdaClassName, lamDesc);
int modifiers = ACC_PUBLIC;
GeneratorAdapter iface = new GeneratorAdapter(modifiers, lamMeth,
cw.visitMethod(modifiers, interfaceMethodName, lamDesc, null, null));
iface.visitCode();
// Handles the case where a reference method refers to a constructor.
// A new instance of the requested type will be created and the
// constructor with no parameters will be called.
// Example: String::new
if (delegateInvokeType == H_NEWINVOKESPECIAL) {
String conName = "<init>";
String conDesc = MethodType.methodType(void.class).toMethodDescriptorString();
Method conMeth = new Method(conName, conDesc);
Type conType = Type.getType(delegateMethodType.returnType());
iface.newInstance(conType);
iface.dup();
iface.invokeConstructor(conType, conMeth);
} else {
// Loads any captured variables onto the stack.
for (int captureCount = 0; captureCount < captures.length; ++captureCount) {
iface.loadThis();
iface.getField(
lambdaClassType, captures[captureCount].name, captures[captureCount].type);
}
// Loads any passed in arguments onto the stack.
iface.loadArgs();
// Handles the case for a lambda function or a static reference method.
// interfaceMethodType and delegateMethodType both have the captured types
// inserted into their type signatures. This later allows the delegate
// method to be invoked dynamically and have the interface method types
// appropriately converted to the delegate method types.
// Example: Integer::parseInt
// Example: something.each(x -> x + 1)
if (delegateInvokeType == H_INVOKESTATIC) {
interfaceMethodType =
interfaceMethodType.insertParameterTypes(0, factoryMethodType.parameterArray());
delegateMethodType =
delegateMethodType.insertParameterTypes(0, factoryMethodType.parameterArray());
} else if (delegateInvokeType == H_INVOKEVIRTUAL ||
delegateInvokeType == H_INVOKEINTERFACE) {
// Handles the case for a virtual or interface reference method with no captures.
// delegateMethodType drops the 'this' parameter because it will be re-inserted
// when the method handle for the dynamically invoked delegate method is created.
// Example: Object::toString
if (captures.length == 0) {
Class<?> clazz = delegateMethodType.parameterType(0);
delegateClassName = clazz.getName();
delegateMethodType = delegateMethodType.dropParameterTypes(0, 1);
// Handles the case for a virtual or interface reference method with 'this'
// captured. interfaceMethodType inserts the 'this' type into its
// method signature. This later allows the delegate
// method to be invoked dynamically and have the interface method types
// appropriately converted to the delegate method types.
// Example: something::toString
} else if (captures.length == 1) {
Class<?> clazz = factoryMethodType.parameterType(0);
delegateClassName = clazz.getName();
interfaceMethodType = interfaceMethodType.insertParameterTypes(0, clazz);
} else {
throw new LambdaConversionException(
"unexpected number of captures [ " + captures.length + "]");
}
} else {
throw new IllegalStateException(
"unexpected invocation type [" + delegateInvokeType + "]");
}
Handle delegateHandle =
new Handle(delegateInvokeType, delegateClassName.replace('.', '/'),
delegateMethodName, delegateMethodType.toMethodDescriptorString(),
delegateInvokeType == H_INVOKEINTERFACE);
iface.invokeDynamic(delegateMethodName, Type.getMethodType(interfaceMethodType
.toMethodDescriptorString()).getDescriptor(), DELEGATE_BOOTSTRAP_HANDLE,
delegateHandle);
}
iface.returnValue();
iface.endMethod();
}
/**
* Closes the {@link ClassWriter}.
*/
private static void endLambdaClass(ClassWriter cw) {
cw.visitEnd();
}
/**
* Defines the {@link Class} for the lambda class using the same {@link Loader}
* that originally defined the class for the Painless script.
*/
private static Class<?> createLambdaClass(
Loader loader,
ClassWriter cw,
String lambdaClassName) {
byte[] classBytes = cw.toByteArray();
return AccessController.doPrivileged((PrivilegedAction<Class<?>>)() ->
loader.defineLambda(lambdaClassName.replace('/', '.'), classBytes));
}
/**
* Creates an {@link ConstantCallSite} that will return the same instance
* of the generated lambda class every time this linked factory method is called.
*/
private static CallSite createNoCaptureCallSite(
MethodType factoryMethodType,
Class<?> lambdaClass) {
Constructor<?> constructor = AccessController.doPrivileged(
(PrivilegedAction<Constructor<?>>)() -> {
try {
return lambdaClass.getConstructor();
} catch (NoSuchMethodException nsme) {
throw new IllegalStateException("unable to create lambda class", nsme);
}
});
try {
return new ConstantCallSite(MethodHandles.constant(
factoryMethodType.returnType(), constructor.newInstance()));
} catch (InstantiationException |
IllegalAccessException |
InvocationTargetException exception) {
throw new IllegalStateException("unable to create lambda class", exception);
}
}
/**
* Creates an {@link ConstantCallSite}
*/
private static CallSite createCaptureCallSite(
Lookup lookup,
String factoryMethodName,
MethodType factoryMethodType,
Class<?> lambdaClass) {
try {
return new ConstantCallSite(
lookup.findStatic(lambdaClass, factoryMethodName, factoryMethodType));
} catch (NoSuchMethodException | IllegalAccessException exception) {
throw new IllegalStateException("unable to create lambda factory class", exception);
}
}
/**
* Links the delegate method to the returned {@link CallSite}. The linked
* delegate method will use converted types from the interface method. Using
* invokedynamic to make the delegate method call allows
* {@link MethodHandle#asType} to be used to do the type conversion instead
* of either a lot more code or requiring many {@link Definition.Type}s to be looked
* up at link-time.
*/
public static CallSite delegateBootstrap(Lookup lookup,
String delegateMethodName,
MethodType interfaceMethodType,
MethodHandle delegateMethodHandle) {
return new ConstantCallSite(delegateMethodHandle.asType(interfaceMethodType));
}
}

View File

@ -84,11 +84,10 @@ public final class WriterConstants {
public static final Type UTILITY_TYPE = Type.getType(Utility.class);
public static final Method STRING_TO_CHAR = getAsmMethod(char.class, "StringTochar", String.class);
public static final Method CHAR_TO_STRING = getAsmMethod(String.class, "charToString", char.class);
public static final Type OBJECT_ARRAY_TYPE = Type.getType("[Ljava/lang/Object;");
public static final Type METHOD_HANDLE_TYPE = Type.getType(MethodHandle.class);
public static final Type AUGMENTATION_TYPE = Type.getType(Augmentation.class);
/**
@ -110,7 +109,6 @@ public final class WriterConstants {
public static final Method DEF_BOOTSTRAP_DELEGATE_METHOD = getAsmMethod(CallSite.class, "bootstrap", Definition.class,
MethodHandles.Lookup.class, String.class, MethodType.class, int.class, int.class, Object[].class);
public static final Type DEF_UTIL_TYPE = Type.getType(Def.class);
public static final Method DEF_TO_BOOLEAN = getAsmMethod(boolean.class, "DefToboolean" , Object.class);
public static final Method DEF_TO_BYTE_IMPLICIT = getAsmMethod(byte.class , "DefTobyteImplicit" , Object.class);
@ -132,10 +130,15 @@ public final class WriterConstants {
/** invokedynamic bootstrap for lambda expression/method references */
public static final MethodType LAMBDA_BOOTSTRAP_TYPE =
MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class,
MethodType.class, Object[].class);
MethodType.class, MethodType.class, String.class, int.class, String.class, MethodType.class);
public static final Handle LAMBDA_BOOTSTRAP_HANDLE =
new Handle(Opcodes.H_INVOKESTATIC, Type.getInternalName(LambdaMetafactory.class),
"altMetafactory", LAMBDA_BOOTSTRAP_TYPE.toMethodDescriptorString(), false);
new Handle(Opcodes.H_INVOKESTATIC, Type.getInternalName(LambdaBootstrap.class),
"lambdaBootstrap", LAMBDA_BOOTSTRAP_TYPE.toMethodDescriptorString(), false);
public static final MethodType DELEGATE_BOOTSTRAP_TYPE =
MethodType.methodType(CallSite.class, MethodHandles.Lookup.class, String.class, MethodType.class, MethodHandle.class);
public static final Handle DELEGATE_BOOTSTRAP_HANDLE =
new Handle(Opcodes.H_INVOKESTATIC, Type.getInternalName(LambdaBootstrap.class),
"delegateBootstrap", DELEGATE_BOOTSTRAP_TYPE.toMethodDescriptorString(), false);
/** dynamic invokedynamic bootstrap for indy string concats (Java 9+) */
public static final Handle INDY_STRING_CONCAT_BOOTSTRAP_HANDLE;

View File

@ -1060,7 +1060,7 @@ public final class Walker extends PainlessParserBaseVisitor<ANode> {
for (LamtypeContext lamtype : ctx.lamtype()) {
if (lamtype.decltype() == null) {
paramTypes.add("def");
paramTypes.add(null);
} else {
paramTypes.add(lamtype.decltype().getText());
}

View File

@ -19,6 +19,7 @@
package org.elasticsearch.painless.node;
import org.elasticsearch.painless.AnalyzerCaster;
import org.elasticsearch.painless.DefBootstrap;
import org.elasticsearch.painless.Definition;
import org.elasticsearch.painless.FunctionRef;
@ -30,10 +31,10 @@ import org.elasticsearch.painless.MethodWriter;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import java.lang.invoke.LambdaMetafactory;
import java.util.Objects;
import java.util.Set;
import static org.elasticsearch.painless.Definition.VOID_TYPE;
import static org.elasticsearch.painless.WriterConstants.LAMBDA_BOOTSTRAP_HANDLE;
/**
@ -77,6 +78,17 @@ public final class ECapturingFunctionRef extends AExpression implements ILambda
if (captured.type.sort != Definition.Sort.DEF) {
try {
ref = new FunctionRef(locals.getDefinition(), expected, captured.type.name, call, 1);
// check casts between the interface method and the delegate method are legal
for (int i = 0; i < ref.interfaceMethod.arguments.size(); ++i) {
Definition.Type from = ref.interfaceMethod.arguments.get(i);
Definition.Type to = ref.delegateMethod.arguments.get(i);
AnalyzerCaster.getLegalCast(location, from, to, false, true);
}
if (ref.interfaceMethod.rtn != VOID_TYPE) {
AnalyzerCaster.getLegalCast(location, ref.delegateMethod.rtn, ref.interfaceMethod.rtn, false, true);
}
} catch (IllegalArgumentException e) {
throw createError(e);
}
@ -101,29 +113,16 @@ public final class ECapturingFunctionRef extends AExpression implements ILambda
} else {
// typed interface, typed implementation
writer.visitVarInsn(captured.type.type.getOpcode(Opcodes.ILOAD), captured.getSlot());
// convert MethodTypes to asm Type for the constant pool.
String invokedType = ref.invokedType.toMethodDescriptorString();
Type samMethodType = Type.getMethodType(ref.samMethodType.toMethodDescriptorString());
Type interfaceType = Type.getMethodType(ref.interfaceMethodType.toMethodDescriptorString());
if (ref.needsBridges()) {
writer.invokeDynamic(ref.invokedName,
invokedType,
LAMBDA_BOOTSTRAP_HANDLE,
samMethodType,
ref.implMethodASM,
samMethodType,
LambdaMetafactory.FLAG_BRIDGES,
1,
interfaceType);
} else {
writer.invokeDynamic(ref.invokedName,
invokedType,
LAMBDA_BOOTSTRAP_HANDLE,
samMethodType,
ref.implMethodASM,
samMethodType,
0);
}
writer.invokeDynamic(
ref.interfaceMethodName,
ref.factoryDescriptor,
LAMBDA_BOOTSTRAP_HANDLE,
ref.interfaceType,
ref.delegateClassName,
ref.delegateInvokeType,
ref.delegateMethodName,
ref.delegateType
);
}
}

View File

@ -19,6 +19,7 @@
package org.elasticsearch.painless.node;
import org.elasticsearch.painless.AnalyzerCaster;
import org.elasticsearch.painless.Definition;
import org.elasticsearch.painless.Definition.Method;
import org.elasticsearch.painless.Definition.MethodKey;
@ -29,10 +30,10 @@ import org.elasticsearch.painless.Location;
import org.elasticsearch.painless.MethodWriter;
import org.objectweb.asm.Type;
import java.lang.invoke.LambdaMetafactory;
import java.util.Objects;
import java.util.Set;
import static org.elasticsearch.painless.Definition.VOID_TYPE;
import static org.elasticsearch.painless.WriterConstants.LAMBDA_BOOTSTRAP_HANDLE;
/**
@ -71,16 +72,28 @@ public final class EFunctionRef extends AExpression implements ILambda {
throw new IllegalArgumentException("Cannot convert function reference [" + type + "::" + call + "] " +
"to [" + expected.name + "], not a functional interface");
}
Method implMethod = locals.getMethod(new MethodKey(call, interfaceMethod.arguments.size()));
if (implMethod == null) {
Method delegateMethod = locals.getMethod(new MethodKey(call, interfaceMethod.arguments.size()));
if (delegateMethod == null) {
throw new IllegalArgumentException("Cannot convert function reference [" + type + "::" + call + "] " +
"to [" + expected.name + "], function not found");
}
ref = new FunctionRef(expected, interfaceMethod, implMethod, 0);
ref = new FunctionRef(expected, interfaceMethod, delegateMethod, 0);
// check casts between the interface method and the delegate method are legal
for (int i = 0; i < interfaceMethod.arguments.size(); ++i) {
Definition.Type from = interfaceMethod.arguments.get(i);
Definition.Type to = delegateMethod.arguments.get(i);
AnalyzerCaster.getLegalCast(location, from, to, false, true);
}
if (interfaceMethod.rtn != VOID_TYPE) {
AnalyzerCaster.getLegalCast(location, delegateMethod.rtn, interfaceMethod.rtn, false, true);
}
} else {
// whitelist lookup
ref = new FunctionRef(locals.getDefinition(), expected, type, call, 0);
}
} catch (IllegalArgumentException e) {
throw createError(e);
}
@ -92,29 +105,16 @@ public final class EFunctionRef extends AExpression implements ILambda {
void write(MethodWriter writer, Globals globals) {
if (ref != null) {
writer.writeDebugInfo(location);
// convert MethodTypes to asm Type for the constant pool.
String invokedType = ref.invokedType.toMethodDescriptorString();
Type samMethodType = Type.getMethodType(ref.samMethodType.toMethodDescriptorString());
Type interfaceType = Type.getMethodType(ref.interfaceMethodType.toMethodDescriptorString());
if (ref.needsBridges()) {
writer.invokeDynamic(ref.invokedName,
invokedType,
LAMBDA_BOOTSTRAP_HANDLE,
samMethodType,
ref.implMethodASM,
samMethodType,
LambdaMetafactory.FLAG_BRIDGES,
1,
interfaceType);
} else {
writer.invokeDynamic(ref.invokedName,
invokedType,
LAMBDA_BOOTSTRAP_HANDLE,
samMethodType,
ref.implMethodASM,
samMethodType,
0);
}
writer.invokeDynamic(
ref.interfaceMethodName,
ref.factoryDescriptor,
LAMBDA_BOOTSTRAP_HANDLE,
ref.interfaceType,
ref.delegateClassName,
ref.delegateInvokeType,
ref.delegateMethodName,
ref.delegateType
);
} else {
// TODO: don't do this: its just to cutover :)
writer.push((String)null);

View File

@ -19,6 +19,7 @@
package org.elasticsearch.painless.node;
import org.elasticsearch.painless.AnalyzerCaster;
import org.elasticsearch.painless.Definition;
import org.elasticsearch.painless.Definition.Method;
import org.elasticsearch.painless.Definition.Type;
@ -31,7 +32,6 @@ import org.elasticsearch.painless.MethodWriter;
import org.elasticsearch.painless.node.SFunction.FunctionReserved;
import org.objectweb.asm.Opcodes;
import java.lang.invoke.LambdaMetafactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
@ -39,7 +39,10 @@ import java.util.List;
import java.util.Objects;
import java.util.Set;
import static org.elasticsearch.painless.Definition.VOID_TYPE;
import static org.elasticsearch.painless.WriterConstants.CLASS_NAME;
import static org.elasticsearch.painless.WriterConstants.LAMBDA_BOOTSTRAP_HANDLE;
import static org.objectweb.asm.Opcodes.H_INVOKESTATIC;
/**
* Lambda expression node.
@ -109,8 +112,15 @@ public final class ELambda extends AExpression implements ILambda {
interfaceMethod = null;
// we don't know anything: treat as def
returnType = Definition.DEF_TYPE;
// don't infer any types
actualParamTypeStrs = paramTypeStrs;
// don't infer any types, replace any null types with def
actualParamTypeStrs = new ArrayList<>();
for (String type : paramTypeStrs) {
if (type == null) {
actualParamTypeStrs.add("def");
} else {
actualParamTypeStrs.add(type);
}
}
} else {
// we know the method statically, infer return type and any unknown/def types
interfaceMethod = expected.struct.getFunctionalMethod();
@ -128,11 +138,11 @@ public final class ELambda extends AExpression implements ILambda {
} else {
returnType = interfaceMethod.rtn;
}
// replace any def types with the actual type (which could still be def)
actualParamTypeStrs = new ArrayList<String>();
// replace any null types with the actual type
actualParamTypeStrs = new ArrayList<>();
for (int i = 0; i < paramTypeStrs.size(); i++) {
String paramType = paramTypeStrs.get(i);
if (paramType.equals(Definition.DEF_TYPE.name)) {
if (paramType == null) {
actualParamTypeStrs.add(interfaceMethod.arguments.get(i).name);
} else {
actualParamTypeStrs.add(paramType);
@ -180,6 +190,18 @@ public final class ELambda extends AExpression implements ILambda {
} catch (IllegalArgumentException e) {
throw createError(e);
}
// check casts between the interface method and the delegate method are legal
for (int i = 0; i < interfaceMethod.arguments.size(); ++i) {
Type from = interfaceMethod.arguments.get(i);
Type to = desugared.parameters.get(i + captures.size()).type;
AnalyzerCaster.getLegalCast(location, from, to, false, true);
}
if (interfaceMethod.rtn != VOID_TYPE) {
AnalyzerCaster.getLegalCast(location, desugared.rtnType, interfaceMethod.rtn, false, true);
}
actual = expected;
}
}
@ -194,31 +216,17 @@ public final class ELambda extends AExpression implements ILambda {
for (Variable capture : captures) {
writer.visitVarInsn(capture.type.type.getOpcode(Opcodes.ILOAD), capture.getSlot());
}
// convert MethodTypes to asm Type for the constant pool.
String invokedType = ref.invokedType.toMethodDescriptorString();
org.objectweb.asm.Type samMethodType =
org.objectweb.asm.Type.getMethodType(ref.samMethodType.toMethodDescriptorString());
org.objectweb.asm.Type interfaceType =
org.objectweb.asm.Type.getMethodType(ref.interfaceMethodType.toMethodDescriptorString());
if (ref.needsBridges()) {
writer.invokeDynamic(ref.invokedName,
invokedType,
LAMBDA_BOOTSTRAP_HANDLE,
samMethodType,
ref.implMethodASM,
samMethodType,
LambdaMetafactory.FLAG_BRIDGES,
1,
interfaceType);
} else {
writer.invokeDynamic(ref.invokedName,
invokedType,
LAMBDA_BOOTSTRAP_HANDLE,
samMethodType,
ref.implMethodASM,
samMethodType,
0);
}
writer.invokeDynamic(
ref.interfaceMethodName,
ref.factoryDescriptor,
LAMBDA_BOOTSTRAP_HANDLE,
ref.interfaceType,
ref.delegateClassName,
ref.delegateInvokeType,
ref.delegateMethodName,
ref.delegateType
);
} else {
// placeholder
writer.push((String)null);

View File

@ -174,7 +174,7 @@ public final class SFunction extends AStatement {
/** Writes the function to given ClassVisitor. */
void write (ClassVisitor writer, CompilerSettings settings, Globals globals) {
int access = Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC;
int access = Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC;
if (synthetic) {
access |= Opcodes.ACC_SYNTHETIC;
}

View File

@ -155,4 +155,5 @@ class org.elasticsearch.painless.FeatureTest -> org.elasticsearch.painless.Featu
boolean overloadedStatic()
boolean overloadedStatic(boolean)
Object twoFunctionsOfX(Function,Function)
void listInput(List)
}

View File

@ -19,35 +19,29 @@
package org.elasticsearch.painless;
import org.apache.lucene.util.Constants;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
public class AugmentationTests extends ScriptTestCase {
public void testStatic() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1, exec("ArrayList l = new ArrayList(); l.add(1); return l.getLength();"));
assertEquals(1, exec("ArrayList l = new ArrayList(); l.add(1); return l.length;"));
}
public void testSubclass() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1, exec("List l = new ArrayList(); l.add(1); return l.getLength();"));
assertEquals(1, exec("List l = new ArrayList(); l.add(1); return l.length;"));
}
public void testDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1, exec("def l = new ArrayList(); l.add(1); return l.getLength();"));
assertEquals(1, exec("def l = new ArrayList(); l.add(1); return l.length;"));
}
public void testCapturingReference() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1, exec("int foo(Supplier t) { return t.get() }" +
"ArrayList l = new ArrayList(); l.add(1);" +
"return foo(l::getLength);"));
@ -58,164 +52,140 @@ public class AugmentationTests extends ScriptTestCase {
"def l = new ArrayList(); l.add(1);" +
"return foo(l::getLength);"));
}
public void testIterable_Any() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(true,
assertEquals(true,
exec("List l = new ArrayList(); l.add(1); l.any(x -> x == 1)"));
}
public void testIterable_AsCollection() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(true,
assertEquals(true,
exec("List l = new ArrayList(); return l.asCollection() === l"));
}
public void testIterable_AsList() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(true,
assertEquals(true,
exec("List l = new ArrayList(); return l.asList() === l"));
assertEquals(5,
assertEquals(5,
exec("Set l = new HashSet(); l.add(5); return l.asList()[0]"));
}
public void testIterable_Each() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1,
assertEquals(1,
exec("List l = new ArrayList(); l.add(1); List l2 = new ArrayList(); l.each(l2::add); return l2.size()"));
}
public void testIterable_EachWithIndex() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(0,
assertEquals(0,
exec("List l = new ArrayList(); l.add(2); Map m = new HashMap(); l.eachWithIndex(m::put); return m.get(2)"));
}
public void testIterable_Every() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(false, exec("List l = new ArrayList(); l.add(1); l.add(2); l.every(x -> x == 1)"));
}
public void testIterable_FindResults() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1,
assertEquals(1,
exec("List l = new ArrayList(); l.add(1); l.add(2); l.findResults(x -> x == 1 ? x : null).size()"));
}
public void testIterable_GroupBy() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2,
assertEquals(2,
exec("List l = new ArrayList(); l.add(1); l.add(-1); l.groupBy(x -> x < 0 ? 'negative' : 'positive').size()"));
}
public void testIterable_Join() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("test,ing",
assertEquals("test,ing",
exec("List l = new ArrayList(); l.add('test'); l.add('ing'); l.join(',')"));
}
public void testIterable_Sum() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(3.0D, exec("def l = [1,2]; return l.sum()"));
assertEquals(5.0D,
assertEquals(5.0D,
exec("List l = new ArrayList(); l.add(1); l.add(2); l.sum(x -> x + 1)"));
}
public void testCollection_Collect() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(Arrays.asList(2, 3),
assertEquals(Arrays.asList(2, 3),
exec("List l = new ArrayList(); l.add(1); l.add(2); l.collect(x -> x + 1)"));
assertEquals(asSet(2, 3),
assertEquals(asSet(2, 3),
exec("List l = new ArrayList(); l.add(1); l.add(2); l.collect(new HashSet(), x -> x + 1)"));
}
public void testCollection_Find() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2,
assertEquals(2,
exec("List l = new ArrayList(); l.add(1); l.add(2); return l.find(x -> x == 2)"));
}
public void testCollection_FindAll() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(Arrays.asList(2),
assertEquals(Arrays.asList(2),
exec("List l = new ArrayList(); l.add(1); l.add(2); return l.findAll(x -> x == 2)"));
}
public void testCollection_FindResult() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("found",
assertEquals("found",
exec("List l = new ArrayList(); l.add(1); l.add(2); return l.findResult(x -> x > 1 ? 'found' : null)"));
assertEquals("notfound",
assertEquals("notfound",
exec("List l = new ArrayList(); l.add(1); l.add(2); return l.findResult('notfound', x -> x > 10 ? 'found' : null)"));
}
public void testCollection_Split() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(Arrays.asList(Arrays.asList(2), Arrays.asList(1)),
assertEquals(Arrays.asList(Arrays.asList(2), Arrays.asList(1)),
exec("List l = new ArrayList(); l.add(1); l.add(2); return l.split(x -> x == 2)"));
}
public void testMap_Collect() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(Arrays.asList("one1", "two2"),
assertEquals(Arrays.asList("one1", "two2"),
exec("Map m = new TreeMap(); m.one = 1; m.two = 2; m.collect((key,value) -> key + value)"));
assertEquals(asSet("one1", "two2"),
assertEquals(asSet("one1", "two2"),
exec("Map m = new TreeMap(); m.one = 1; m.two = 2; m.collect(new HashSet(), (key,value) -> key + value)"));
}
public void testMap_Count() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1,
assertEquals(1,
exec("Map m = new TreeMap(); m.one = 1; m.two = 2; m.count((key,value) -> value == 2)"));
}
public void testMap_Each() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2,
assertEquals(2,
exec("Map m = new TreeMap(); m.one = 1; m.two = 2; Map m2 = new TreeMap(); m.each(m2::put); return m2.size()"));
}
public void testMap_Every() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(false,
assertEquals(false,
exec("Map m = new TreeMap(); m.one = 1; m.two = 2; m.every((key,value) -> value == 2)"));
}
public void testMap_Find() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("two",
assertEquals("two",
exec("Map m = new TreeMap(); m.one = 1; m.two = 2; return m.find((key,value) -> value == 2).key"));
}
public void testMap_FindAll() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(Collections.singletonMap("two", 2),
assertEquals(Collections.singletonMap("two", 2),
exec("Map m = new TreeMap(); m.one = 1; m.two = 2; return m.findAll((key,value) -> value == 2)"));
}
public void testMap_FindResult() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("found",
assertEquals("found",
exec("Map m = new TreeMap(); m.one = 1; m.two = 2; return m.findResult((key,value) -> value == 2 ? 'found' : null)"));
assertEquals("notfound",
exec("Map m = new TreeMap(); m.one = 1; m.two = 2; " +
assertEquals("notfound",
exec("Map m = new TreeMap(); m.one = 1; m.two = 2; " +
"return m.findResult('notfound', (key,value) -> value == 10 ? 'found' : null)"));
}
public void testMap_FindResults() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(Arrays.asList("negative", "positive"),
exec("Map m = new TreeMap(); m.a = -1; m.b = 1; " +
exec("Map m = new TreeMap(); m.a = -1; m.b = 1; " +
"return m.findResults((key,value) -> value < 0 ? 'negative' : 'positive')"));
}
public void testMap_GroupBy() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
Map<String,Map<String,Integer>> expected = new HashMap<>();
expected.put("negative", Collections.singletonMap("a", -1));
expected.put("positive", Collections.singletonMap("b", 1));
assertEquals(expected,
exec("Map m = new TreeMap(); m.a = -1; m.b = 1; " +
exec("Map m = new TreeMap(); m.a = -1; m.b = 1; " +
"return m.groupBy((key,value) -> value < 0 ? 'negative' : 'positive')"));
}
}

View File

@ -19,7 +19,6 @@
package org.elasticsearch.painless;
import org.apache.lucene.util.Constants;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
@ -33,39 +32,32 @@ import static org.hamcrest.Matchers.startsWith;
public class FunctionRefTests extends ScriptTestCase {
public void testStaticMethodReference() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1, exec("List l = new ArrayList(); l.add(2); l.add(1); l.sort(Integer::compare); return l.get(0);"));
}
public void testStaticMethodReferenceDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1, exec("def l = new ArrayList(); l.add(2); l.add(1); l.sort(Integer::compare); return l.get(0);"));
}
public void testVirtualMethodReference() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("List l = new ArrayList(); l.add(1); l.add(1); return l.stream().mapToInt(Integer::intValue).sum();"));
}
public void testVirtualMethodReferenceDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("def l = new ArrayList(); l.add(1); l.add(1); return l.stream().mapToInt(Integer::intValue).sum();"));
}
public void testQualifiedStaticMethodReference() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(true,
exec("List l = [true]; l.stream().map(org.elasticsearch.painless.FeatureTest::overloadedStatic).findFirst().get()"));
}
public void testQualifiedStaticMethodReferenceDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(true,
exec("def l = [true]; l.stream().map(org.elasticsearch.painless.FeatureTest::overloadedStatic).findFirst().get()"));
}
public void testQualifiedVirtualMethodReference() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
long instant = randomLong();
assertEquals(instant, exec(
"List l = [params.d]; return l.stream().mapToLong(org.joda.time.ReadableDateTime::getMillis).sum()",
@ -73,7 +65,6 @@ public class FunctionRefTests extends ScriptTestCase {
}
public void testQualifiedVirtualMethodReferenceDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
long instant = randomLong();
assertEquals(instant, exec(
"def l = [params.d]; return l.stream().mapToLong(org.joda.time.ReadableDateTime::getMillis).sum()",
@ -81,129 +72,112 @@ public class FunctionRefTests extends ScriptTestCase {
}
public void testCtorMethodReference() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(3.0D,
exec("List l = new ArrayList(); l.add(1.0); l.add(2.0); " +
"DoubleStream doubleStream = l.stream().mapToDouble(Double::doubleValue);" +
assertEquals(3.0D,
exec("List l = new ArrayList(); l.add(1.0); l.add(2.0); " +
"DoubleStream doubleStream = l.stream().mapToDouble(Double::doubleValue);" +
"DoubleSummaryStatistics stats = doubleStream.collect(DoubleSummaryStatistics::new, " +
"DoubleSummaryStatistics::accept, " +
"DoubleSummaryStatistics::combine); " +
"DoubleSummaryStatistics::combine); " +
"return stats.getSum()"));
}
public void testCtorMethodReferenceDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(3.0D,
exec("def l = new ArrayList(); l.add(1.0); l.add(2.0); " +
"def doubleStream = l.stream().mapToDouble(Double::doubleValue);" +
assertEquals(3.0D,
exec("def l = new ArrayList(); l.add(1.0); l.add(2.0); " +
"def doubleStream = l.stream().mapToDouble(Double::doubleValue);" +
"def stats = doubleStream.collect(DoubleSummaryStatistics::new, " +
"DoubleSummaryStatistics::accept, " +
"DoubleSummaryStatistics::combine); " +
"DoubleSummaryStatistics::combine); " +
"return stats.getSum()"));
}
public void testArrayCtorMethodRef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1.0D,
exec("List l = new ArrayList(); l.add(1.0); l.add(2.0); " +
"def[] array = l.stream().toArray(Double[]::new);" +
assertEquals(1.0D,
exec("List l = new ArrayList(); l.add(1.0); l.add(2.0); " +
"def[] array = l.stream().toArray(Double[]::new);" +
"return array[0];"));
}
public void testArrayCtorMethodRefDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1.0D,
exec("def l = new ArrayList(); l.add(1.0); l.add(2.0); " +
"def[] array = l.stream().toArray(Double[]::new);" +
assertEquals(1.0D,
exec("def l = new ArrayList(); l.add(1.0); l.add(2.0); " +
"def[] array = l.stream().toArray(Double[]::new);" +
"return array[0];"));
}
public void testCapturingMethodReference() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("5", exec("Integer x = Integer.valueOf(5); return Optional.empty().orElseGet(x::toString);"));
assertEquals("[]", exec("List l = new ArrayList(); return Optional.empty().orElseGet(l::toString);"));
}
public void testCapturingMethodReferenceDefImpl() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("5", exec("def x = Integer.valueOf(5); return Optional.empty().orElseGet(x::toString);"));
assertEquals("[]", exec("def l = new ArrayList(); return Optional.empty().orElseGet(l::toString);"));
}
public void testCapturingMethodReferenceDefInterface() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("5", exec("Integer x = Integer.valueOf(5); def opt = Optional.empty(); return opt.orElseGet(x::toString);"));
assertEquals("[]", exec("List l = new ArrayList(); def opt = Optional.empty(); return opt.orElseGet(l::toString);"));
}
public void testCapturingMethodReferenceDefEverywhere() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("5", exec("def x = Integer.valueOf(5); def opt = Optional.empty(); return opt.orElseGet(x::toString);"));
assertEquals("[]", exec("def l = new ArrayList(); def opt = Optional.empty(); return opt.orElseGet(l::toString);"));
}
public void testCapturingMethodReferenceMultipleLambdas() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("testingcdefg", exec(
"String x = 'testing';" +
"String y = 'abcdefg';" +
"org.elasticsearch.painless.FeatureTest test = new org.elasticsearch.painless.FeatureTest(2,3);" +
"String y = 'abcdefg';" +
"org.elasticsearch.painless.FeatureTest test = new org.elasticsearch.painless.FeatureTest(2,3);" +
"return test.twoFunctionsOfX(x::concat, y::substring);"));
}
public void testCapturingMethodReferenceMultipleLambdasDefImpls() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("testingcdefg", exec(
"def x = 'testing';" +
"def y = 'abcdefg';" +
"org.elasticsearch.painless.FeatureTest test = new org.elasticsearch.painless.FeatureTest(2,3);" +
"def y = 'abcdefg';" +
"org.elasticsearch.painless.FeatureTest test = new org.elasticsearch.painless.FeatureTest(2,3);" +
"return test.twoFunctionsOfX(x::concat, y::substring);"));
}
public void testCapturingMethodReferenceMultipleLambdasDefInterface() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("testingcdefg", exec(
"String x = 'testing';" +
"String y = 'abcdefg';" +
"def test = new org.elasticsearch.painless.FeatureTest(2,3);" +
"String y = 'abcdefg';" +
"def test = new org.elasticsearch.painless.FeatureTest(2,3);" +
"return test.twoFunctionsOfX(x::concat, y::substring);"));
}
public void testCapturingMethodReferenceMultipleLambdasDefEverywhere() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("testingcdefg", exec(
"def x = 'testing';" +
"def y = 'abcdefg';" +
"def test = new org.elasticsearch.painless.FeatureTest(2,3);" +
"def y = 'abcdefg';" +
"def test = new org.elasticsearch.painless.FeatureTest(2,3);" +
"return test.twoFunctionsOfX(x::concat, y::substring);"));
}
public void testOwnStaticMethodReference() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("int mycompare(int i, int j) { j - i } " +
"List l = new ArrayList(); l.add(2); l.add(1); l.sort(this::mycompare); return l.get(0);"));
}
public void testOwnStaticMethodReferenceDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("int mycompare(int i, int j) { j - i } " +
"def l = new ArrayList(); l.add(2); l.add(1); l.sort(this::mycompare); return l.get(0);"));
}
public void testInterfaceDefaultMethod() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("bar", exec("String f(BiFunction function) { function.apply('foo', 'bar') }" +
assertEquals("bar", exec("String f(BiFunction function) { function.apply('foo', 'bar') }" +
"Map map = new HashMap(); f(map::getOrDefault)"));
}
public void testInterfaceDefaultMethodDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("bar", exec("String f(BiFunction function) { function.apply('foo', 'bar') }" +
assertEquals("bar", exec("String f(BiFunction function) { function.apply('foo', 'bar') }" +
"def map = new HashMap(); f(map::getOrDefault)"));
}
public void testMethodMissing() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
Exception e = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("List l = [2, 1]; l.sort(Integer::bogus); return l.get(0);");
});
@ -211,7 +185,6 @@ public class FunctionRefTests extends ScriptTestCase {
}
public void testQualifiedMethodMissing() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
Exception e = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("List l = [2, 1]; l.sort(org.joda.time.ReadableDateTime::bogus); return l.get(0);", false);
});
@ -219,7 +192,6 @@ public class FunctionRefTests extends ScriptTestCase {
}
public void testClassMissing() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
Exception e = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("List l = [2, 1]; l.sort(Bogus::bogus); return l.get(0);", false);
});
@ -227,7 +199,6 @@ public class FunctionRefTests extends ScriptTestCase {
}
public void testQualifiedClassMissing() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
Exception e = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("List l = [2, 1]; l.sort(org.joda.time.BogusDateTime::bogus); return l.get(0);", false);
});
@ -237,7 +208,6 @@ public class FunctionRefTests extends ScriptTestCase {
}
public void testNotFunctionalInterface() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("List l = new ArrayList(); l.add(2); l.add(1); l.add(Integer::bogus); return l.get(0);");
});
@ -245,38 +215,33 @@ public class FunctionRefTests extends ScriptTestCase {
}
public void testIncompatible() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
expectScriptThrows(BootstrapMethodError.class, () -> {
exec("List l = new ArrayList(); l.add(2); l.add(1); l.sort(String::startsWith); return l.get(0);");
});
}
public void testWrongArity() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("Optional.empty().orElseGet(String::startsWith);");
});
assertThat(expected.getMessage(), containsString("Unknown reference"));
}
public void testWrongArityNotEnough() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("List l = new ArrayList(); l.add(2); l.add(1); l.sort(String::isEmpty);");
});
assertTrue(expected.getMessage().contains("Unknown reference"));
}
public void testWrongArityDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("def y = Optional.empty(); return y.orElseGet(String::startsWith);");
});
assertThat(expected.getMessage(), containsString("Unknown reference"));
}
public void testWrongArityNotEnoughDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("def l = new ArrayList(); l.add(2); l.add(1); l.sort(String::isEmpty);");
});
@ -284,29 +249,26 @@ public class FunctionRefTests extends ScriptTestCase {
}
public void testReturnVoid() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
Throwable expected = expectScriptThrows(BootstrapMethodError.class, () -> {
exec("StringBuilder b = new StringBuilder(); List l = [1, 2]; l.stream().mapToLong(b::setLength);");
Throwable expected = expectScriptThrows(ClassCastException.class, () -> {
exec("StringBuilder b = new StringBuilder(); List l = [1, 2]; l.stream().mapToLong(b::setLength).sum();");
});
assertThat(expected.getCause().getMessage(),
containsString("Type mismatch for lambda expected return: void is not convertible to long"));
assertThat(expected.getMessage(), containsString("Cannot cast from [void] to [long]."));
}
public void testReturnVoidDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
Exception expected = expectScriptThrows(LambdaConversionException.class, () -> {
exec("StringBuilder b = new StringBuilder(); def l = [1, 2]; l.stream().mapToLong(b::setLength);");
});
assertThat(expected.getMessage(), containsString("Type mismatch for lambda expected return: void is not convertible to long"));
assertThat(expected.getMessage(), containsString("lambda expects return type [long], but found return type [void]"));
expected = expectScriptThrows(LambdaConversionException.class, () -> {
exec("def b = new StringBuilder(); def l = [1, 2]; l.stream().mapToLong(b::setLength);");
});
assertThat(expected.getMessage(), containsString("Type mismatch for lambda expected return: void is not convertible to long"));
assertThat(expected.getMessage(), containsString("lambda expects return type [long], but found return type [void]"));
expected = expectScriptThrows(LambdaConversionException.class, () -> {
exec("def b = new StringBuilder(); List l = [1, 2]; l.stream().mapToLong(b::setLength);");
});
assertThat(expected.getMessage(), containsString("Type mismatch for lambda expected return: void is not convertible to long"));
assertThat(expected.getMessage(), containsString("lambda expects return type [long], but found return type [void]"));
}
}

View File

@ -19,8 +19,6 @@
package org.elasticsearch.painless;
import org.apache.lucene.util.Constants;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@ -30,87 +28,72 @@ import static org.hamcrest.Matchers.containsString;
public class LambdaTests extends ScriptTestCase {
public void testNoArgLambda() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1, exec("Optional.empty().orElseGet(() -> 1);"));
}
public void testNoArgLambdaDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1, exec("def x = Optional.empty(); x.orElseGet(() -> 1);"));
}
public void testLambdaWithArgs() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("short", exec("List l = new ArrayList(); l.add('looooong'); l.add('short'); "
+ "l.sort((a, b) -> a.length() - b.length()); return l.get(0)"));
}
public void testLambdaWithTypedArgs() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("short", exec("List l = new ArrayList(); l.add('looooong'); l.add('short'); "
+ "l.sort((String a, String b) -> a.length() - b.length()); return l.get(0)"));
}
public void testPrimitiveLambdas() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(4, exec("List l = new ArrayList(); l.add(1); l.add(1); "
+ "return l.stream().mapToInt(x -> x + 1).sum();"));
}
public void testPrimitiveLambdasWithTypedArgs() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(4, exec("List l = new ArrayList(); l.add(1); l.add(1); "
+ "return l.stream().mapToInt(int x -> x + 1).sum();"));
}
public void testPrimitiveLambdasDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(4, exec("def l = new ArrayList(); l.add(1); l.add(1); "
+ "return l.stream().mapToInt(x -> x + 1).sum();"));
}
public void testPrimitiveLambdasWithTypedArgsDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(4, exec("def l = new ArrayList(); l.add(1); l.add(1); "
+ "return l.stream().mapToInt(int x -> x + 1).sum();"));
}
public void testPrimitiveLambdasConvertible() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("List l = new ArrayList(); l.add(1); l.add(1); "
+ "return l.stream().mapToInt(byte x -> x).sum();"));
assertEquals(2, exec("List l = new ArrayList(); l.add((short)1); l.add(1); "
+ "return l.stream().mapToInt(long x -> (int)1).sum();"));
}
public void testPrimitiveArgs() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("int applyOne(IntFunction arg) { arg.apply(1) } applyOne(x -> x + 1)"));
}
public void testPrimitiveArgsTyped() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("int applyOne(IntFunction arg) { arg.apply(1) } applyOne(int x -> x + 1)"));
}
public void testPrimitiveArgsTypedOddly() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2L, exec("long applyOne(IntFunction arg) { arg.apply(1) } applyOne(long x -> x + 1)"));
}
public void testMultipleStatements() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("int applyOne(IntFunction arg) { arg.apply(1) } applyOne(x -> { def y = x + 1; return y })"));
}
public void testUnneededCurlyStatements() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("int applyOne(IntFunction arg) { arg.apply(1) } applyOne(x -> { x + 1 })"));
}
/** interface ignores return value */
public void testVoidReturn() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("List list = new ArrayList(); "
+ "list.add(2); "
+ "List list2 = new ArrayList(); "
@ -120,7 +103,6 @@ public class LambdaTests extends ScriptTestCase {
/** interface ignores return value */
public void testVoidReturnDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("def list = new ArrayList(); "
+ "list.add(2); "
+ "List list2 = new ArrayList(); "
@ -129,19 +111,16 @@ public class LambdaTests extends ScriptTestCase {
}
public void testTwoLambdas() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("testingcdefg", exec(
"org.elasticsearch.painless.FeatureTest test = new org.elasticsearch.painless.FeatureTest(2,3);" +
"return test.twoFunctionsOfX(x -> 'testing'.concat(x), y -> 'abcdefg'.substring(y))"));
}
public void testNestedLambdas() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1, exec("Optional.empty().orElseGet(() -> Optional.empty().orElseGet(() -> 1));"));
}
public void testLambdaInLoop() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(100, exec("int sum = 0; " +
"for (int i = 0; i < 100; i++) {" +
" sum += Optional.empty().orElseGet(() -> 1);" +
@ -150,17 +129,14 @@ public class LambdaTests extends ScriptTestCase {
}
public void testCapture() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(5, exec("int x = 5; return Optional.empty().orElseGet(() -> x);"));
}
public void testTwoCaptures() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals("1test", exec("int x = 1; String y = 'test'; return Optional.empty().orElseGet(() -> x + y);"));
}
public void testCapturesAreReadOnly() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("List l = new ArrayList(); l.add(1); l.add(1); "
+ "return l.stream().mapToInt(x -> { l = null; return x + 1 }).sum();");
@ -170,14 +146,12 @@ public class LambdaTests extends ScriptTestCase {
@AwaitsFix(bugUrl = "def type tracking")
public void testOnlyCapturesAreReadOnly() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(4, exec("List l = new ArrayList(); l.add(1); l.add(1); "
+ "return l.stream().mapToInt(x -> { x += 1; return x }).sum();"));
}
/** Lambda parameters shouldn't be able to mask a variable already in scope */
public void testNoParamMasking() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("int x = 0; List l = new ArrayList(); l.add(1); l.add(1); "
+ "return l.stream().mapToInt(x -> { x += 1; return x }).sum();");
@ -186,24 +160,20 @@ public class LambdaTests extends ScriptTestCase {
}
public void testCaptureDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(5, exec("int x = 5; def y = Optional.empty(); y.orElseGet(() -> x);"));
}
public void testNestedCapture() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(1, exec("boolean x = false; int y = 1;" +
"return Optional.empty().orElseGet(() -> x ? 5 : Optional.empty().orElseGet(() -> y));"));
}
public void testNestedCaptureParams() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(2, exec("int foo(Function f) { return f.apply(1) }" +
"return foo(x -> foo(y -> x + 1))"));
}
public void testWrongArity() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, false, () -> {
exec("Optional.empty().orElseGet(x -> x);");
});
@ -211,7 +181,6 @@ public class LambdaTests extends ScriptTestCase {
}
public void testWrongArityDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("def y = Optional.empty(); return y.orElseGet(x -> x);");
});
@ -219,7 +188,6 @@ public class LambdaTests extends ScriptTestCase {
}
public void testWrongArityNotEnough() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, false, () -> {
exec("List l = new ArrayList(); l.add(1); l.add(1); "
+ "return l.stream().mapToInt(() -> 5).sum();");
@ -228,7 +196,6 @@ public class LambdaTests extends ScriptTestCase {
}
public void testWrongArityNotEnoughDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("def l = new ArrayList(); l.add(1); l.add(1); "
+ "return l.stream().mapToInt(() -> 5).sum();");
@ -237,17 +204,14 @@ public class LambdaTests extends ScriptTestCase {
}
public void testLambdaInFunction() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(5, exec("def foo() { Optional.empty().orElseGet(() -> 5) } return foo();"));
}
public void testLambdaCaptureFunctionParam() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
assertEquals(5, exec("def foo(int x) { Optional.empty().orElseGet(() -> x) } return foo(5);"));
}
public void testReservedCapture() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
String compare = "boolean compare(Supplier s, def v) {s.get() == v}";
assertEquals(true, exec(compare + "compare(() -> new ArrayList(), new ArrayList())"));
assertEquals(true, exec(compare + "compare(() -> { new ArrayList() }, new ArrayList())"));
@ -272,7 +236,6 @@ public class LambdaTests extends ScriptTestCase {
}
public void testReturnVoid() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
Throwable expected = expectScriptThrows(ClassCastException.class, () -> {
exec("StringBuilder b = new StringBuilder(); List l = [1, 2]; l.stream().mapToLong(i -> b.setLength(i))");
});
@ -280,7 +243,6 @@ public class LambdaTests extends ScriptTestCase {
}
public void testReturnVoidDef() {
assumeFalse("JDK is JDK 9", Constants.JRE_IS_MINIMUM_JAVA9);
// If we can catch the error at compile time we do
Exception expected = expectScriptThrows(ClassCastException.class, () -> {
exec("StringBuilder b = new StringBuilder(); def l = [1, 2]; l.stream().mapToLong(i -> b.setLength(i))");

View File

@ -240,7 +240,7 @@ public class NodeToStringTests extends ESTestCase {
+ "}).sum()");
assertToString(
"(SSource (SReturn (PCallInvoke (PCallInvoke (PCallInvoke (EListInit (ENumeric 1) (ENumeric 2) (ENumeric 3)) stream) "
+ "mapToInt (Args (ELambda (Pair def x)\n"
+ "mapToInt (Args (ELambda (Pair null x)\n"
+ " (SReturn (EBinary (EVariable x) + (ENumeric 1)))))) sum)))",
"return [1, 2, 3].stream().mapToInt(x -> x + 1).sum()");
assertToString(
@ -250,7 +250,7 @@ public class NodeToStringTests extends ESTestCase {
+ " return a.length() - b.length()\n"
+ "})");
assertToString(
"(SSource (SReturn (PCallInvoke (EListInit (EString 'a') (EString 'b')) sort (Args (ELambda (Pair def a) (Pair def b)\n"
"(SSource (SReturn (PCallInvoke (EListInit (EString 'a') (EString 'b')) sort (Args (ELambda (Pair null a) (Pair null b)\n"
+ " (SReturn (EBinary (PCallInvoke (EVariable a) length) - (PCallInvoke (EVariable b) length))))))))",
"return ['a', 'b'].sort((a, b) -> a.length() - b.length())");
assertToString(
@ -371,14 +371,14 @@ public class NodeToStringTests extends ESTestCase {
assertToString(
"(SSource\n"
+ " (SDeclBlock (SDeclaration int[] a (ENewArray int dims (Args (ENumeric 10)))))\n"
+ " (SReturn (PField (EVariable a) length)))",
+ " (SReturn (PField (EVariable a) length)))",
"int[] a = new int[10];\n"
+ "return a.length");
assertToString(
"(SSource\n"
+ " (SDeclBlock (SDeclaration org.elasticsearch.painless.FeatureTest a (ENewObj org.elasticsearch.painless.FeatureTest)))\n"
+ " (SExpression (EAssignment (PField (EVariable a) x) = (ENumeric 10)))\n"
+ " (SReturn (PField (EVariable a) x)))",
+ " (SReturn (PField (EVariable a) x)))",
"org.elasticsearch.painless.FeatureTest a = new org.elasticsearch.painless.FeatureTest();\n"
+ "a.x = 10;\n"
+ "return a.x");