Infer lambda arguments/return type

This commit is contained in:
Robert Muir 2016-06-20 14:54:45 -04:00
parent 09305a0f98
commit 1cc0264827
5 changed files with 75 additions and 11 deletions

View File

@ -69,8 +69,9 @@ public final class Locals {
* <p>
* This is just like {@link #newFunctionScope}, except the captured parameters are made read-only.
*/
public static Locals newLambdaScope(Locals programScope, List<Parameter> parameters, int captureCount, int maxLoopCounter) {
Locals locals = new Locals(programScope, Definition.DEF_TYPE);
public static Locals newLambdaScope(Locals programScope, Type returnType, List<Parameter> parameters,
int captureCount, int maxLoopCounter) {
Locals locals = new Locals(programScope, returnType);
for (int i = 0; i < parameters.size(); i++) {
Parameter parameter = parameters.get(i);
boolean isCapture = i < captureCount;

View File

@ -100,6 +100,44 @@ public class ELambda extends AExpression implements ILambda {
@Override
void analyze(Locals locals) {
final Definition.Type returnType;
final List<String> actualParamTypeStrs;
Method interfaceMethod;
// inspect the target first, set interface method if we know it.
if (expected == null) {
interfaceMethod = null;
// we don't know anything: treat as def
returnType = Definition.DEF_TYPE;
// don't infer any types
actualParamTypeStrs = paramTypeStrs;
} else {
// we know the method statically, infer return type and any unknown/def types
interfaceMethod = expected.struct.getFunctionalMethod();
if (interfaceMethod == null) {
throw createError(new IllegalArgumentException("Cannot pass lambda to [" + expected.name +
"], not a functional interface"));
}
// check arity before we manipulate parameters
if (interfaceMethod.arguments.size() != paramTypeStrs.size())
throw new IllegalArgumentException("Incorrect number of parameters for [" + interfaceMethod.name +
"] in [" + expected.clazz + "]");
// for method invocation, its allowed to ignore the return value
if (interfaceMethod.rtn == Definition.VOID_TYPE) {
returnType = Definition.DEF_TYPE;
} else {
returnType = interfaceMethod.rtn;
}
// replace any def types with the actual type (which could still be def)
actualParamTypeStrs = new ArrayList<String>();
for (int i = 0; i < paramTypeStrs.size(); i++) {
String paramType = paramTypeStrs.get(i);
if (paramType.equals(Definition.DEF_TYPE.name)) {
actualParamTypeStrs.add(interfaceMethod.arguments.get(i).name);
} else {
actualParamTypeStrs.add(paramType);
}
}
}
// gather any variables used by the lambda body first.
Set<String> variables = new HashSet<>();
for (AStatement statement : statements) {
@ -119,14 +157,14 @@ public class ELambda extends AExpression implements ILambda {
paramTypes.add(var.type.name);
paramNames.add(var.name);
}
paramTypes.addAll(paramTypeStrs);
paramTypes.addAll(actualParamTypeStrs);
paramNames.addAll(paramNameStrs);
// desugar lambda body into a synthetic method
desugared = new SFunction(reserved, location, "def", name,
desugared = new SFunction(reserved, location, returnType.name, name,
paramTypes, paramNames, statements, true);
desugared.generate();
desugared.analyze(Locals.newLambdaScope(locals.getProgramScope(), desugared.parameters,
desugared.analyze(Locals.newLambdaScope(locals.getProgramScope(), returnType, desugared.parameters,
captures.size(), reserved.getMaxLoopCounter()));
// setup method reference to synthetic method
@ -137,11 +175,6 @@ public class ELambda extends AExpression implements ILambda {
} else {
defPointer = null;
try {
Method interfaceMethod = expected.struct.getFunctionalMethod();
if (interfaceMethod == null) {
throw new IllegalArgumentException("Cannot pass lambda to [" + expected.name +
"], not a functional interface");
}
Class<?> captureClasses[] = new Class<?>[captures.size()];
for (int i = 0; i < captures.size(); i++) {
captureClasses[i] = captures.get(i).type.clazz;

View File

@ -21,13 +21,16 @@ package org.elasticsearch.painless.node;
import org.elasticsearch.painless.Definition.MethodKey;
import org.elasticsearch.painless.Location;
import org.elasticsearch.painless.Definition;
import org.elasticsearch.painless.Definition.Method;
import org.elasticsearch.painless.Definition.Sort;
import org.elasticsearch.painless.Definition.Struct;
import org.elasticsearch.painless.Definition.Type;
import org.elasticsearch.painless.Globals;
import org.elasticsearch.painless.Locals;
import org.elasticsearch.painless.MethodWriter;
import java.lang.invoke.MethodType;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@ -41,6 +44,8 @@ public final class LCallInvoke extends ALink {
final List<AExpression> arguments;
Method method = null;
boolean box = false; // true for primitive types
public LCallInvoke(Location location, String name, List<AExpression> arguments) {
super(location, -1);
@ -68,6 +73,12 @@ public final class LCallInvoke extends ALink {
MethodKey methodKey = new MethodKey(name, arguments.size());
Struct struct = before.struct;
if (before.clazz.isPrimitive()) {
Class<?> wrapper = MethodType.methodType(before.clazz).wrap().returnType();
Type boxed = Definition.getType(wrapper.getSimpleName());
struct = boxed.struct;
box = true;
}
method = statik ? struct.staticMethods.get(methodKey) : struct.methods.get(methodKey);
if (method != null) {
@ -103,6 +114,10 @@ public final class LCallInvoke extends ALink {
@Override
void load(MethodWriter writer, Globals globals) {
writer.writeDebugInfo(location);
if (box) {
writer.box(before.type);
}
for (AExpression argument : arguments) {
argument.write(writer, globals);

View File

@ -19,7 +19,6 @@
package org.elasticsearch.painless;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@ -119,4 +118,10 @@ public class BasicAPITests extends ScriptTestCase {
assertEquals("{}", exec("Map map = new HashMap(); return map.toString();"));
assertEquals("{}", exec("def map = new HashMap(); return map.toString();"));
}
public void testPrimitivesHaveMethods() {
assertEquals(5, exec("int x = 5; return x.intValue();"));
assertEquals("5", exec("int x = 5; return x.toString();"));
assertEquals(0, exec("int x = 5; return x.compareTo(5);"));
}
}

View File

@ -452,4 +452,14 @@ public class DefOptimizationTests extends ScriptTestCase {
assertBytecodeExists("def x = 1; double y = +x; return y",
"INVOKEDYNAMIC plus(Ljava/lang/Object;)D");
}
public void testLambdaReturnType() {
assertBytecodeExists("List l = new ArrayList(); l.removeIf(x -> x < 10)",
"synthetic lambda$0(Ljava/lang/Object;)Z");
}
public void testLambdaArguments() {
assertBytecodeExists("List l = new ArrayList(); l.stream().mapToDouble(Double::valueOf).map(x -> x + 1)",
"synthetic lambda$0(D)D");
}
}