diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java index 980225c900a..e3a33ba6331 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java @@ -241,19 +241,21 @@ public final class DefBootstrap { switch(flavor) { case UNARY_OPERATOR: case SHIFT_OPERATOR: - if ((flags & OPERATOR_COMPOUND_ASSIGNMENT) != 0) { - return lookupGeneric(); // XXX: optimize better. - } // shifts are treated as unary, as java allows long arguments without a cast (but bits are ignored) - return DefMath.lookupUnary(args[0].getClass(), name); - case BINARY_OPERATOR: + MethodHandle unary = DefMath.lookupUnary(args[0].getClass(), name); if ((flags & OPERATOR_COMPOUND_ASSIGNMENT) != 0) { - return lookupGeneric(); // XXX: optimize better. + unary = DefMath.cast(args[0].getClass(), unary); } + return unary; + case BINARY_OPERATOR: if (args[0] == null || args[1] == null) { - return lookupGeneric(); // can handle nulls, if supported + return lookupGeneric(); // can handle nulls, casts if supported } else { - return DefMath.lookupBinary(args[0].getClass(), args[1].getClass(), name); + MethodHandle binary = DefMath.lookupBinary(args[0].getClass(), args[1].getClass(), name); + if ((flags & OPERATOR_COMPOUND_ASSIGNMENT) != 0) { + binary = DefMath.cast(args[0].getClass(), binary); + } + return binary; } default: throw new AssertionError(); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefMath.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefMath.java index 665f5acbba3..f56256f6cb2 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefMath.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefMath.java @@ -1133,7 +1133,7 @@ public class DefMath { return handle; } - /** Returns an appropriate method handle for a binary operator, based only promotion of the LHS and RHS arguments */ + /** Returns an appropriate method handle for a binary operator, based on promotion of the LHS and RHS arguments */ public static MethodHandle lookupBinary(Class classA, Class classB, String name) { MethodHandle handle = TYPE_OP_MAPPING.get(promote(promote(unbox(classA)), promote(unbox(classB)))).get(name); if (handle == null) { @@ -1156,6 +1156,9 @@ public class DefMath { static Object dynamicCast(Object returnValue, Object lhs) { if (lhs != null) { Class c = lhs.getClass(); + if (c == returnValue.getClass()) { + return returnValue; + } if (c == Integer.class) { return getNumber(returnValue).intValue(); } else if (c == Long.class) { @@ -1212,4 +1215,18 @@ public class DefMath { // combine: f(x,y) -> g(f(x,y), x, y); return MethodHandles.foldArguments(cast, generic); } + + /** Forces a cast to class A for target (only if types differ) */ + public static MethodHandle cast(Class classA, MethodHandle target) { + MethodType newType = MethodType.methodType(classA).unwrap(); + MethodType targetType = MethodType.methodType(target.type().returnType()).unwrap(); + + if (newType.returnType() == targetType.returnType()) { + return target; // no conversion + } + + // this is safe for our uses of it here only, because we change just the return value, + // the original method itself does all the type checks correctly. + return MethodHandles.explicitCastArguments(target, target.type().changeReturnType(newType.returnType())); + } }