Infer lambda arguments/return type
This commit is contained in:
parent
09305a0f98
commit
1cc0264827
|
@ -69,8 +69,9 @@ public final class Locals {
|
|||
* <p>
|
||||
* This is just like {@link #newFunctionScope}, except the captured parameters are made read-only.
|
||||
*/
|
||||
public static Locals newLambdaScope(Locals programScope, List<Parameter> parameters, int captureCount, int maxLoopCounter) {
|
||||
Locals locals = new Locals(programScope, Definition.DEF_TYPE);
|
||||
public static Locals newLambdaScope(Locals programScope, Type returnType, List<Parameter> parameters,
|
||||
int captureCount, int maxLoopCounter) {
|
||||
Locals locals = new Locals(programScope, returnType);
|
||||
for (int i = 0; i < parameters.size(); i++) {
|
||||
Parameter parameter = parameters.get(i);
|
||||
boolean isCapture = i < captureCount;
|
||||
|
|
|
@ -100,6 +100,44 @@ public class ELambda extends AExpression implements ILambda {
|
|||
|
||||
@Override
|
||||
void analyze(Locals locals) {
|
||||
final Definition.Type returnType;
|
||||
final List<String> actualParamTypeStrs;
|
||||
Method interfaceMethod;
|
||||
// inspect the target first, set interface method if we know it.
|
||||
if (expected == null) {
|
||||
interfaceMethod = null;
|
||||
// we don't know anything: treat as def
|
||||
returnType = Definition.DEF_TYPE;
|
||||
// don't infer any types
|
||||
actualParamTypeStrs = paramTypeStrs;
|
||||
} else {
|
||||
// we know the method statically, infer return type and any unknown/def types
|
||||
interfaceMethod = expected.struct.getFunctionalMethod();
|
||||
if (interfaceMethod == null) {
|
||||
throw createError(new IllegalArgumentException("Cannot pass lambda to [" + expected.name +
|
||||
"], not a functional interface"));
|
||||
}
|
||||
// check arity before we manipulate parameters
|
||||
if (interfaceMethod.arguments.size() != paramTypeStrs.size())
|
||||
throw new IllegalArgumentException("Incorrect number of parameters for [" + interfaceMethod.name +
|
||||
"] in [" + expected.clazz + "]");
|
||||
// for method invocation, its allowed to ignore the return value
|
||||
if (interfaceMethod.rtn == Definition.VOID_TYPE) {
|
||||
returnType = Definition.DEF_TYPE;
|
||||
} else {
|
||||
returnType = interfaceMethod.rtn;
|
||||
}
|
||||
// replace any def types with the actual type (which could still be def)
|
||||
actualParamTypeStrs = new ArrayList<String>();
|
||||
for (int i = 0; i < paramTypeStrs.size(); i++) {
|
||||
String paramType = paramTypeStrs.get(i);
|
||||
if (paramType.equals(Definition.DEF_TYPE.name)) {
|
||||
actualParamTypeStrs.add(interfaceMethod.arguments.get(i).name);
|
||||
} else {
|
||||
actualParamTypeStrs.add(paramType);
|
||||
}
|
||||
}
|
||||
}
|
||||
// gather any variables used by the lambda body first.
|
||||
Set<String> variables = new HashSet<>();
|
||||
for (AStatement statement : statements) {
|
||||
|
@ -119,14 +157,14 @@ public class ELambda extends AExpression implements ILambda {
|
|||
paramTypes.add(var.type.name);
|
||||
paramNames.add(var.name);
|
||||
}
|
||||
paramTypes.addAll(paramTypeStrs);
|
||||
paramTypes.addAll(actualParamTypeStrs);
|
||||
paramNames.addAll(paramNameStrs);
|
||||
|
||||
// desugar lambda body into a synthetic method
|
||||
desugared = new SFunction(reserved, location, "def", name,
|
||||
desugared = new SFunction(reserved, location, returnType.name, name,
|
||||
paramTypes, paramNames, statements, true);
|
||||
desugared.generate();
|
||||
desugared.analyze(Locals.newLambdaScope(locals.getProgramScope(), desugared.parameters,
|
||||
desugared.analyze(Locals.newLambdaScope(locals.getProgramScope(), returnType, desugared.parameters,
|
||||
captures.size(), reserved.getMaxLoopCounter()));
|
||||
|
||||
// setup method reference to synthetic method
|
||||
|
@ -137,11 +175,6 @@ public class ELambda extends AExpression implements ILambda {
|
|||
} else {
|
||||
defPointer = null;
|
||||
try {
|
||||
Method interfaceMethod = expected.struct.getFunctionalMethod();
|
||||
if (interfaceMethod == null) {
|
||||
throw new IllegalArgumentException("Cannot pass lambda to [" + expected.name +
|
||||
"], not a functional interface");
|
||||
}
|
||||
Class<?> captureClasses[] = new Class<?>[captures.size()];
|
||||
for (int i = 0; i < captures.size(); i++) {
|
||||
captureClasses[i] = captures.get(i).type.clazz;
|
||||
|
|
|
@ -21,13 +21,16 @@ package org.elasticsearch.painless.node;
|
|||
|
||||
import org.elasticsearch.painless.Definition.MethodKey;
|
||||
import org.elasticsearch.painless.Location;
|
||||
import org.elasticsearch.painless.Definition;
|
||||
import org.elasticsearch.painless.Definition.Method;
|
||||
import org.elasticsearch.painless.Definition.Sort;
|
||||
import org.elasticsearch.painless.Definition.Struct;
|
||||
import org.elasticsearch.painless.Definition.Type;
|
||||
import org.elasticsearch.painless.Globals;
|
||||
import org.elasticsearch.painless.Locals;
|
||||
import org.elasticsearch.painless.MethodWriter;
|
||||
|
||||
import java.lang.invoke.MethodType;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
@ -41,6 +44,8 @@ public final class LCallInvoke extends ALink {
|
|||
final List<AExpression> arguments;
|
||||
|
||||
Method method = null;
|
||||
|
||||
boolean box = false; // true for primitive types
|
||||
|
||||
public LCallInvoke(Location location, String name, List<AExpression> arguments) {
|
||||
super(location, -1);
|
||||
|
@ -68,6 +73,12 @@ public final class LCallInvoke extends ALink {
|
|||
|
||||
MethodKey methodKey = new MethodKey(name, arguments.size());
|
||||
Struct struct = before.struct;
|
||||
if (before.clazz.isPrimitive()) {
|
||||
Class<?> wrapper = MethodType.methodType(before.clazz).wrap().returnType();
|
||||
Type boxed = Definition.getType(wrapper.getSimpleName());
|
||||
struct = boxed.struct;
|
||||
box = true;
|
||||
}
|
||||
method = statik ? struct.staticMethods.get(methodKey) : struct.methods.get(methodKey);
|
||||
|
||||
if (method != null) {
|
||||
|
@ -103,6 +114,10 @@ public final class LCallInvoke extends ALink {
|
|||
@Override
|
||||
void load(MethodWriter writer, Globals globals) {
|
||||
writer.writeDebugInfo(location);
|
||||
|
||||
if (box) {
|
||||
writer.box(before.type);
|
||||
}
|
||||
|
||||
for (AExpression argument : arguments) {
|
||||
argument.write(writer, globals);
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
package org.elasticsearch.painless;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -119,4 +118,10 @@ public class BasicAPITests extends ScriptTestCase {
|
|||
assertEquals("{}", exec("Map map = new HashMap(); return map.toString();"));
|
||||
assertEquals("{}", exec("def map = new HashMap(); return map.toString();"));
|
||||
}
|
||||
|
||||
public void testPrimitivesHaveMethods() {
|
||||
assertEquals(5, exec("int x = 5; return x.intValue();"));
|
||||
assertEquals("5", exec("int x = 5; return x.toString();"));
|
||||
assertEquals(0, exec("int x = 5; return x.compareTo(5);"));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -452,4 +452,14 @@ public class DefOptimizationTests extends ScriptTestCase {
|
|||
assertBytecodeExists("def x = 1; double y = +x; return y",
|
||||
"INVOKEDYNAMIC plus(Ljava/lang/Object;)D");
|
||||
}
|
||||
|
||||
public void testLambdaReturnType() {
|
||||
assertBytecodeExists("List l = new ArrayList(); l.removeIf(x -> x < 10)",
|
||||
"synthetic lambda$0(Ljava/lang/Object;)Z");
|
||||
}
|
||||
|
||||
public void testLambdaArguments() {
|
||||
assertBytecodeExists("List l = new ArrayList(); l.stream().mapToDouble(Double::valueOf).map(x -> x + 1)",
|
||||
"synthetic lambda$0(D)D");
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue