From 7ecb4ca4e99ef2b1fed12284a135b1c7fd113ad9 Mon Sep 17 00:00:00 2001 From: Robert Muir Date: Mon, 13 Jun 2016 18:56:28 -0400 Subject: [PATCH] Refactor def math --- .../java/org/elasticsearch/painless/Def.java | 619 --------- .../elasticsearch/painless/DefBootstrap.java | 139 +- .../org/elasticsearch/painless/DefMath.java | 1149 +++++++++++++++++ .../elasticsearch/painless/MethodWriter.java | 60 +- .../org/elasticsearch/painless/Utility.java | 8 - .../painless/WriterConstants.java | 22 +- .../elasticsearch/painless/node/EBinary.java | 57 +- .../painless/node/ECapturingFunctionRef.java | 2 +- .../elasticsearch/painless/node/EComp.java | 32 +- .../elasticsearch/painless/node/EUnary.java | 18 +- .../painless/node/LDefArray.java | 4 +- .../elasticsearch/painless/node/LDefCall.java | 2 +- .../painless/node/LDefField.java | 4 +- .../elasticsearch/painless/node/SEach.java | 2 +- .../org/elasticsearch/painless/AndTests.java | 9 + .../painless/BinaryOperatorTests.java | 37 +- .../painless/DefBootstrapTests.java | 15 + .../painless/DefOperationTests.java | 293 ++--- .../org/elasticsearch/painless/OrTests.java | 9 + .../elasticsearch/painless/UnaryTests.java | 10 + .../org/elasticsearch/painless/XorTests.java | 9 + 21 files changed, 1587 insertions(+), 913 deletions(-) create mode 100644 modules/lang-painless/src/main/java/org/elasticsearch/painless/DefMath.java diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java index 48775a88a20..158086ab429 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java @@ -655,625 +655,6 @@ public final class Def { } } - // NOTE: Below methods are not cached, instead invoked directly because they are performant. - // We also check for Long values first when possible since the type is more - // likely to be a Long than a Float. - - public static Object not(final Object unary) { - if (unary instanceof Double || unary instanceof Long || unary instanceof Float) { - return ~((Number)unary).longValue(); - } else if (unary instanceof Number) { - return ~((Number)unary).intValue(); - } else if (unary instanceof Character) { - return ~(int)(char)unary; - } - - throw new ClassCastException("Cannot apply [~] operation to type " + - "[" + unary.getClass().getCanonicalName() + "]."); - } - - public static Object neg(final Object unary) { - if (unary instanceof Double) { - return -(double)unary; - } else if (unary instanceof Float) { - return -(float)unary; - } else if (unary instanceof Long) { - return -(long)unary; - } else if (unary instanceof Number) { - return -((Number)unary).intValue(); - } else if (unary instanceof Character) { - return -(char)unary; - } - - throw new ClassCastException("Cannot apply [-] operation to type " + - "[" + unary.getClass().getCanonicalName() + "]."); - } - - public static Object mul(final Object left, final Object right) { - if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double) { - return ((Number)left).doubleValue() * ((Number)right).doubleValue(); - } else if (left instanceof Float || right instanceof Float) { - return ((Number)left).floatValue() * ((Number)right).floatValue(); - } else if (left instanceof Long || right instanceof Long) { - return ((Number)left).longValue() * ((Number)right).longValue(); - } else { - return ((Number)left).intValue() * ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double) { - return ((Number)left).doubleValue() * (char)right; - } else if (left instanceof Long) { - return ((Number)left).longValue() * (char)right; - } else if (left instanceof Float) { - return ((Number)left).floatValue() * (char)right; - } else { - return ((Number)left).intValue() * (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double) { - return (char)left * ((Number)right).doubleValue(); - } else if (right instanceof Long) { - return (char)left * ((Number)right).longValue(); - } else if (right instanceof Float) { - return (char)left * ((Number)right).floatValue(); - } else { - return (char)left * ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left * (char)right; - } - } - - throw new ClassCastException("Cannot apply [*] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static Object div(final Object left, final Object right) { - if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double) { - return ((Number)left).doubleValue() / ((Number)right).doubleValue(); - } else if (left instanceof Float || right instanceof Float) { - return ((Number)left).floatValue() / ((Number)right).floatValue(); - } else if (left instanceof Long || right instanceof Long) { - return ((Number)left).longValue() / ((Number)right).longValue(); - } else { - return ((Number)left).intValue() / ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double) { - return ((Number)left).doubleValue() / (char)right; - } else if (left instanceof Long) { - return ((Number)left).longValue() / (char)right; - } else if (left instanceof Float) { - return ((Number)left).floatValue() / (char)right; - } else { - return ((Number)left).intValue() / (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double) { - return (char)left / ((Number)right).doubleValue(); - } else if (right instanceof Long) { - return (char)left / ((Number)right).longValue(); - } else if (right instanceof Float) { - return (char)left / ((Number)right).floatValue(); - } else { - return (char)left / ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left / (char)right; - } - } - - throw new ClassCastException("Cannot apply [/] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static Object rem(final Object left, final Object right) { - if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double) { - return ((Number)left).doubleValue() % ((Number)right).doubleValue(); - } else if (left instanceof Float || right instanceof Float) { - return ((Number)left).floatValue() % ((Number)right).floatValue(); - } else if (left instanceof Long || right instanceof Long) { - return ((Number)left).longValue() % ((Number)right).longValue(); - } else { - return ((Number)left).intValue() % ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double) { - return ((Number)left).doubleValue() % (char)right; - } else if (left instanceof Long) { - return ((Number)left).longValue() % (char)right; - } else if (left instanceof Float) { - return ((Number)left).floatValue() % (char)right; - } else { - return ((Number)left).intValue() % (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double) { - return (char)left % ((Number)right).doubleValue(); - } else if (right instanceof Long) { - return (char)left % ((Number)right).longValue(); - } else if (right instanceof Float) { - return (char)left % ((Number)right).floatValue(); - } else { - return (char)left % ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left % (char)right; - } - } - - throw new ClassCastException("Cannot apply [%] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static Object add(final Object left, final Object right) { - if (left instanceof String || right instanceof String) { - return "" + left + right; - } else if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double) { - return ((Number)left).doubleValue() + ((Number)right).doubleValue(); - } else if (left instanceof Float || right instanceof Float) { - return ((Number)left).floatValue() + ((Number)right).floatValue(); - } else if (left instanceof Long || right instanceof Long) { - return ((Number)left).longValue() + ((Number)right).longValue(); - } else { - return ((Number)left).intValue() + ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double) { - return ((Number)left).doubleValue() + (char)right; - } else if (left instanceof Long) { - return ((Number)left).longValue() + (char)right; - } else if (left instanceof Float) { - return ((Number)left).floatValue() + (char)right; - } else { - return ((Number)left).intValue() + (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double) { - return (char)left + ((Number)right).doubleValue(); - } else if (right instanceof Long) { - return (char)left + ((Number)right).longValue(); - } else if (right instanceof Float) { - return (char)left + ((Number)right).floatValue(); - } else { - return (char)left + ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left + (char)right; - } - } - - throw new ClassCastException("Cannot apply [+] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static Object sub(final Object left, final Object right) { - if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double) { - return ((Number)left).doubleValue() - ((Number)right).doubleValue(); - } else if (left instanceof Float || right instanceof Float) { - return ((Number)left).floatValue() - ((Number)right).floatValue(); - } else if (left instanceof Long || right instanceof Long) { - return ((Number)left).longValue() - ((Number)right).longValue(); - } else { - return ((Number)left).intValue() - ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double) { - return ((Number)left).doubleValue() - (char)right; - } else if (left instanceof Long) { - return ((Number)left).longValue() - (char)right; - } else if (left instanceof Float) { - return ((Number)left).floatValue() - (char)right; - } else { - return ((Number)left).intValue() - (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double) { - return (char)left - ((Number)right).doubleValue(); - } else if (right instanceof Long) { - return (char)left - ((Number)right).longValue(); - } else if (right instanceof Float) { - return (char)left - ((Number)right).floatValue(); - } else { - return (char)left - ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left - (char)right; - } - } - - throw new ClassCastException("Cannot apply [-] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static Object lsh(final Object left, final int right) { - if (left instanceof Double || left instanceof Long || left instanceof Float) { - return ((Number)left).longValue() << right; - } else if (left instanceof Number) { - return ((Number)left).intValue() << right; - } else if (left instanceof Character) { - return (char)left << right; - } - - throw new ClassCastException("Cannot apply [<<] operation to types [" + left.getClass().getCanonicalName() + "] and [int]."); - } - - public static Object rsh(final Object left, final int right) { - if (left instanceof Double || left instanceof Long || left instanceof Float) { - return ((Number)left).longValue() >> right; - } else if (left instanceof Number) { - return ((Number)left).intValue() >> right; - } else if (left instanceof Character) { - return (char)left >> right; - } - - throw new ClassCastException("Cannot apply [>>] operation to types [" + left.getClass().getCanonicalName() + "] and [int]."); - } - - public static Object ush(final Object left, final int right) { - if (left instanceof Double || left instanceof Long || left instanceof Float) { - return ((Number)left).longValue() >>> right; - } else if (left instanceof Number) { - return ((Number)left).intValue() >>> right; - } else if (left instanceof Character) { - return (char)left >>> right; - } - - throw new ClassCastException("Cannot apply [>>>] operation to types [" + left.getClass().getCanonicalName() + "] and [int]."); - } - - public static Object and(final Object left, final Object right) { - if (left instanceof Boolean && right instanceof Boolean) { - return (boolean)left && (boolean)right; - } else if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double || - left instanceof Long || right instanceof Long || - left instanceof Float || right instanceof Float) { - return ((Number)left).longValue() & ((Number)right).longValue(); - } else { - return ((Number)left).intValue() & ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double || left instanceof Long || left instanceof Float) { - return ((Number)left).longValue() & (char)right; - } else { - return ((Number)left).intValue() & (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double || right instanceof Long || right instanceof Float) { - return (char)left & ((Number)right).longValue(); - } else { - return (char)left & ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left & (char)right; - } - } - - throw new ClassCastException("Cannot apply [&] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static Object xor(final Object left, final Object right) { - if (left instanceof Boolean && right instanceof Boolean) { - return (boolean)left ^ (boolean)right; - } else if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double || - left instanceof Long || right instanceof Long || - left instanceof Float || right instanceof Float) { - return ((Number)left).longValue() ^ ((Number)right).longValue(); - } else { - return ((Number)left).intValue() ^ ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double || left instanceof Long || left instanceof Float) { - return ((Number)left).longValue() ^ (char)right; - } else { - return ((Number)left).intValue() ^ (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double || right instanceof Long || right instanceof Float) { - return (char)left ^ ((Number)right).longValue(); - } else { - return (char)left ^ ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left ^ (char)right; - } - } - - throw new ClassCastException("Cannot apply [^] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static Object or(final Object left, final Object right) { - if (left instanceof Boolean && right instanceof Boolean) { - return (boolean)left || (boolean)right; - } else if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double || - left instanceof Long || right instanceof Long || - left instanceof Float || right instanceof Float) { - return ((Number)left).longValue() | ((Number)right).longValue(); - } else { - return ((Number)left).intValue() | ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double || left instanceof Long || left instanceof Float) { - return ((Number)left).longValue() | (char)right; - } else { - return ((Number)left).intValue() | (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double || right instanceof Long || right instanceof Float) { - return (char)left | ((Number)right).longValue(); - } else { - return (char)left | ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left | (char)right; - } - } - - throw new ClassCastException("Cannot apply [|] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static boolean eq(final Object left, final Object right) { - if (left != null && right != null) { - if (left instanceof Double) { - if (right instanceof Number) { - return (double)left == ((Number)right).doubleValue(); - } else if (right instanceof Character) { - return (double)left == (char)right; - } - } else if (right instanceof Double) { - if (left instanceof Number) { - return ((Number)left).doubleValue() == (double)right; - } else if (left instanceof Character) { - return (char)left == ((Number)right).doubleValue(); - } - } else if (left instanceof Float) { - if (right instanceof Number) { - return (float)left == ((Number)right).floatValue(); - } else if (right instanceof Character) { - return (float)left == (char)right; - } - } else if (right instanceof Float) { - if (left instanceof Number) { - return ((Number)left).floatValue() == (float)right; - } else if (left instanceof Character) { - return (char)left == ((Number)right).floatValue(); - } - } else if (left instanceof Long) { - if (right instanceof Number) { - return (long)left == ((Number)right).longValue(); - } else if (right instanceof Character) { - return (long)left == (char)right; - } - } else if (right instanceof Long) { - if (left instanceof Number) { - return ((Number)left).longValue() == (long)right; - } else if (left instanceof Character) { - return (char)left == ((Number)right).longValue(); - } - } else if (left instanceof Number) { - if (right instanceof Number) { - return ((Number)left).intValue() == ((Number)right).intValue(); - } else if (right instanceof Character) { - return ((Number)left).intValue() == (char)right; - } - } else if (right instanceof Number && left instanceof Character) { - return (char)left == ((Number)right).intValue(); - } else if (left instanceof Character && right instanceof Character) { - return (char)left == (char)right; - } - - return left.equals(right); - } - - return left == null && right == null; - } - - public static boolean lt(final Object left, final Object right) { - if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double) { - return ((Number)left).doubleValue() < ((Number)right).doubleValue(); - } else if (left instanceof Float || right instanceof Float) { - return ((Number)left).floatValue() < ((Number)right).floatValue(); - } else if (left instanceof Long || right instanceof Long) { - return ((Number)left).longValue() < ((Number)right).longValue(); - } else { - return ((Number)left).intValue() < ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double) { - return ((Number)left).doubleValue() < (char)right; - } else if (left instanceof Long) { - return ((Number)left).longValue() < (char)right; - } else if (left instanceof Float) { - return ((Number)left).floatValue() < (char)right; - } else { - return ((Number)left).intValue() < (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double) { - return (char)left < ((Number)right).doubleValue(); - } else if (right instanceof Long) { - return (char)left < ((Number)right).longValue(); - } else if (right instanceof Float) { - return (char)left < ((Number)right).floatValue(); - } else { - return (char)left < ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left < (char)right; - } - } - - throw new ClassCastException("Cannot apply [<] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static boolean lte(final Object left, final Object right) { - if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double) { - return ((Number)left).doubleValue() <= ((Number)right).doubleValue(); - } else if (left instanceof Float || right instanceof Float) { - return ((Number)left).floatValue() <= ((Number)right).floatValue(); - } else if (left instanceof Long || right instanceof Long) { - return ((Number)left).longValue() <= ((Number)right).longValue(); - } else { - return ((Number)left).intValue() <= ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double) { - return ((Number)left).doubleValue() <= (char)right; - } else if (left instanceof Long) { - return ((Number)left).longValue() <= (char)right; - } else if (left instanceof Float) { - return ((Number)left).floatValue() <= (char)right; - } else { - return ((Number)left).intValue() <= (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double) { - return (char)left <= ((Number)right).doubleValue(); - } else if (right instanceof Long) { - return (char)left <= ((Number)right).longValue(); - } else if (right instanceof Float) { - return (char)left <= ((Number)right).floatValue(); - } else { - return (char)left <= ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left <= (char)right; - } - } - - throw new ClassCastException("Cannot apply [<=] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static boolean gt(final Object left, final Object right) { - if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double) { - return ((Number)left).doubleValue() > ((Number)right).doubleValue(); - } else if (left instanceof Float || right instanceof Float) { - return ((Number)left).floatValue() > ((Number)right).floatValue(); - } else if (left instanceof Long || right instanceof Long) { - return ((Number)left).longValue() > ((Number)right).longValue(); - } else { - return ((Number)left).intValue() > ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double) { - return ((Number)left).doubleValue() > (char)right; - } else if (left instanceof Long) { - return ((Number)left).longValue() > (char)right; - } else if (left instanceof Float) { - return ((Number)left).floatValue() > (char)right; - } else { - return ((Number)left).intValue() > (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double) { - return (char)left > ((Number)right).doubleValue(); - } else if (right instanceof Long) { - return (char)left > ((Number)right).longValue(); - } else if (right instanceof Float) { - return (char)left > ((Number)right).floatValue(); - } else { - return (char)left > ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left > (char)right; - } - } - - throw new ClassCastException("Cannot apply [>] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } - - public static boolean gte(final Object left, final Object right) { - if (left instanceof Number) { - if (right instanceof Number) { - if (left instanceof Double || right instanceof Double) { - return ((Number)left).doubleValue() >= ((Number)right).doubleValue(); - } else if (left instanceof Float || right instanceof Float) { - return ((Number)left).floatValue() >= ((Number)right).floatValue(); - } else if (left instanceof Long || right instanceof Long) { - return ((Number)left).longValue() >= ((Number)right).longValue(); - } else { - return ((Number)left).intValue() >= ((Number)right).intValue(); - } - } else if (right instanceof Character) { - if (left instanceof Double) { - return ((Number)left).doubleValue() >= (char)right; - } else if (left instanceof Long) { - return ((Number)left).longValue() >= (char)right; - } else if (left instanceof Float) { - return ((Number)left).floatValue() >= (char)right; - } else { - return ((Number)left).intValue() >= (char)right; - } - } - } else if (left instanceof Character) { - if (right instanceof Number) { - if (right instanceof Double) { - return (char)left >= ((Number)right).doubleValue(); - } else if (right instanceof Long) { - return (char)left >= ((Number)right).longValue(); - } else if (right instanceof Float) { - return (char)left >= ((Number)right).floatValue(); - } else { - return (char)left >= ((Number)right).intValue(); - } - } else if (right instanceof Character) { - return (char)left >= (char)right; - } - } - - throw new ClassCastException("Cannot apply [>] operation to types " + - "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); - } // Conversion methods for Def to primitive types. diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java index 040008c0d62..c688b314243 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java @@ -31,9 +31,9 @@ import java.lang.invoke.MutableCallSite; /** * Painless invokedynamic bootstrap for the call site. *

- * Has 5 flavors (passed as static bootstrap parameters): dynamic method call, + * Has 7 flavors (passed as static bootstrap parameters): dynamic method call, * dynamic field load (getter), and dynamic field store (setter), dynamic array load, - * and dynamic array store. + * dynamic array store, iterator, and method reference. *

* When a new type is encountered at the call site, we lookup from the appropriate * whitelist, and cache with a guard. If we encounter too many types, we stop caching. @@ -62,6 +62,12 @@ public final class DefBootstrap { public static final int ITERATOR = 5; /** static bootstrap parameter indicating a dynamic method reference, e.g. foo::bar */ public static final int REFERENCE = 6; + /** static bootstrap parameter indicating a unary math operator, e.g. ~foo */ + public static final int UNARY_OPERATOR = 7; + /** static bootstrap parameter indicating a binary math operator, e.g. foo / bar */ + public static final int BINARY_OPERATOR = 8; + /** static bootstrap parameter indicating a shift operator, e.g. foo >> bar */ + public static final int SHIFT_OPERATOR = 9; /** * CallSite that implements the polymorphic inlining cache (PIC). @@ -82,6 +88,12 @@ public final class DefBootstrap { this.name = name; this.flavor = flavor; this.args = args; + + // For operators use a monomorphic cache, fallback is fast. + // Just start with a depth of MAX-1, to keep it a constant. + if (flavor == UNARY_OPERATOR || flavor == BINARY_OPERATOR || flavor == SHIFT_OPERATOR) { + depth = MAX_DEPTH - 1; + } final MethodHandle fallback = FALLBACK.bindTo(this) .asCollector(Object[].class, type.parameterCount()) @@ -97,29 +109,69 @@ public final class DefBootstrap { static boolean checkClass(Class clazz, Object receiver) { return receiver.getClass() == clazz; } + + /** + * guard method for inline caching: checks the receiver's class and the first argument + * are the same as the cached receiver and first argument. + */ + static boolean checkBinary(Class left, Class right, Object leftObject, Object rightObject) { + return leftObject.getClass() == left && rightObject.getClass() == right; + } + + /** + * guard method for inline caching: checks the first argument is the same + * as the cached first argument. + */ + static boolean checkBinaryArg(Class left, Class right, Object leftObject, Object rightObject) { + return rightObject.getClass() == right; + } /** * Does a slow lookup against the whitelist. */ - private MethodHandle lookup(int flavor, Class clazz, String name, Object[] args) throws Throwable { + private MethodHandle lookup(int flavor, String name, Object[] args) throws Throwable { switch(flavor) { case METHOD_CALL: - return Def.lookupMethod(lookup, type(), clazz, name, args, (Long) this.args[0]); + return Def.lookupMethod(lookup, type(), args[0].getClass(), name, args, (Long) this.args[0]); case LOAD: - return Def.lookupGetter(clazz, name); + return Def.lookupGetter(args[0].getClass(), name); case STORE: - return Def.lookupSetter(clazz, name); + return Def.lookupSetter(args[0].getClass(), name); case ARRAY_LOAD: - return Def.lookupArrayLoad(clazz); + return Def.lookupArrayLoad(args[0].getClass()); case ARRAY_STORE: - return Def.lookupArrayStore(clazz); + return Def.lookupArrayStore(args[0].getClass()); case ITERATOR: - return Def.lookupIterator(clazz); + return Def.lookupIterator(args[0].getClass()); case REFERENCE: - return Def.lookupReference(lookup, (String) this.args[0], clazz, name); + return Def.lookupReference(lookup, (String) this.args[0], args[0].getClass(), name); + case UNARY_OPERATOR: + case SHIFT_OPERATOR: + // shifts are treated as unary, as java allows long arguments without a cast (but bits are ignored) + return DefMath.lookupUnary(args[0].getClass(), name); + case BINARY_OPERATOR: + if (args[0] == null || args[1] == null) { + return getGeneric(flavor, name); // can handle nulls + } else { + return DefMath.lookupBinary(args[0].getClass(), args[1].getClass(), name); + } default: throw new AssertionError(); } } + + /** + * Installs a permanent, generic solution that works with any parameter types, if possible. + */ + private MethodHandle getGeneric(int flavor, String name) throws Throwable { + switch(flavor) { + case UNARY_OPERATOR: + case BINARY_OPERATOR: + case SHIFT_OPERATOR: + return DefMath.lookupGeneric(name); + default: + return null; + } + } /** * Called when a new type is encountered (or, when we have encountered more than {@code MAX_DEPTH} @@ -127,21 +179,56 @@ public final class DefBootstrap { */ @SuppressForbidden(reason = "slow path") Object fallback(Object[] args) throws Throwable { - final MethodType type = type(); - final Object receiver = args[0]; - final Class receiverClass = receiver.getClass(); - final MethodHandle target = lookup(flavor, receiverClass, name, args).asType(type); - if (depth >= MAX_DEPTH) { - // revert to a vtable call - setTarget(target); - return target.invokeWithArguments(args); + // caching defeated + MethodHandle generic = getGeneric(flavor, name); + if (generic != null) { + setTarget(generic.asType(type())); + return generic.invokeWithArguments(args); + } else { + return lookup(flavor, name, args).invokeWithArguments(args); + } + } + + final MethodType type = type(); + final MethodHandle target = lookup(flavor, name, args).asType(type); + + final MethodHandle test; + if (flavor == BINARY_OPERATOR || flavor == SHIFT_OPERATOR) { + // some binary operators support nulls, we handle them separate + Class clazz0 = args[0] == null ? null : args[0].getClass(); + Class clazz1 = args[1] == null ? null : args[1].getClass(); + if (type.parameterType(1) != Object.class) { + // case 1: only the receiver is unknown, just check that + MethodHandle unaryTest = CHECK_CLASS.bindTo(clazz0); + test = unaryTest.asType(unaryTest.type() + .changeParameterType(0, type.parameterType(0))); + } else if (type.parameterType(0) != Object.class) { + // case 2: only the argument is unknown, just check that + MethodHandle unaryTest = CHECK_BINARY_ARG.bindTo(clazz0).bindTo(clazz1); + test = unaryTest.asType(unaryTest.type() + .changeParameterType(0, type.parameterType(0)) + .changeParameterType(1, type.parameterType(1))); + } else { + // case 3: check both receiver and argument + MethodHandle binaryTest = CHECK_BINARY.bindTo(clazz0).bindTo(clazz1); + test = binaryTest.asType(binaryTest.type() + .changeParameterType(0, type.parameterType(0)) + .changeParameterType(1, type.parameterType(1))); + } + } else { + MethodHandle receiverTest = CHECK_CLASS.bindTo(args[0].getClass()); + test = receiverTest.asType(receiverTest.type() + .changeParameterType(0, type.parameterType(0))); } - MethodHandle test = CHECK_CLASS.bindTo(receiverClass); - test = test.asType(test.type().changeParameterType(0, type.parameterType(0))); - - final MethodHandle guard = MethodHandles.guardWithTest(test, target, getTarget()); + MethodHandle guard = MethodHandles.guardWithTest(test, target, getTarget()); + // very special cases, where even the receiver can be null (see JLS rules for string concat) + // we wrap + with an NPE catcher, and use our generic method in that case. + if (flavor == BINARY_OPERATOR && "add".equals(name) || "eq".equals(name)) { + MethodHandle handler = MethodHandles.dropArguments(getGeneric(flavor, name).asType(type()), 0, NullPointerException.class); + guard = MethodHandles.catchException(guard, NullPointerException.class, handler); + } depth++; @@ -150,12 +237,18 @@ public final class DefBootstrap { } private static final MethodHandle CHECK_CLASS; + private static final MethodHandle CHECK_BINARY; + private static final MethodHandle CHECK_BINARY_ARG; private static final MethodHandle FALLBACK; static { final Lookup lookup = MethodHandles.lookup(); try { CHECK_CLASS = lookup.findStatic(lookup.lookupClass(), "checkClass", - MethodType.methodType(boolean.class, Class.class, Object.class)); + MethodType.methodType(boolean.class, Class.class, Object.class)); + CHECK_BINARY = lookup.findStatic(lookup.lookupClass(), "checkBinary", + MethodType.methodType(boolean.class, Class.class, Class.class, Object.class, Object.class)); + CHECK_BINARY_ARG = lookup.findStatic(lookup.lookupClass(), "checkBinaryArg", + MethodType.methodType(boolean.class, Class.class, Class.class, Object.class, Object.class)); FALLBACK = lookup.findVirtual(lookup.lookupClass(), "fallback", MethodType.methodType(Object.class, Object[].class)); } catch (ReflectiveOperationException e) { diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefMath.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefMath.java new file mode 100644 index 00000000000..e6b3a8c6003 --- /dev/null +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/DefMath.java @@ -0,0 +1,1149 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.painless; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.lang.invoke.MethodHandles.Lookup; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Dynamic operators for painless. + *

+ * Each operator must "support" the following types: + * {@code int,long,float,double,boolean,Object}. Operators can throw exceptions if + * the type is illegal. The {@code Object} type must be a "generic" handler that + * handles all legal types: it must be convertible to every possible legal signature. + */ +@SuppressWarnings("unused") +public class DefMath { + + // Unary not: only applicable to integral types + + private static int not(int v) { + return ~v; + } + + private static long not(long v) { + return ~v; + } + + private static float not(float v) { + throw new ClassCastException("Cannot apply not [~] to type [float]"); + } + + private static double not(double v) { + throw new ClassCastException("Cannot apply not [~] to type [double]"); + } + + private static boolean not(boolean v) { + throw new ClassCastException("Cannot apply not [~] to type [boolean]"); + } + + private static Object not(Object unary) { + if (unary instanceof Long) { + return ~(Long)unary; + } else if (unary instanceof Integer) { + return ~(Integer)unary; + } else if (unary instanceof Short) { + return ~(Short)unary; + } else if (unary instanceof Character) { + return ~(Character)unary; + } else if (unary instanceof Byte) { + return ~(Byte)unary; + } + + throw new ClassCastException("Cannot apply [~] operation to type " + + "[" + unary.getClass().getCanonicalName() + "]."); + } + + // unary negation and plus: applicable to all numeric types + + private static int neg(int v) { + return -v; + } + + private static long neg(long v) { + return -v; + } + + private static float neg(float v) { + return -v; + } + + private static double neg(double v) { + return -v; + } + + private static boolean neg(boolean v) { + throw new ClassCastException("Cannot apply [-] operation to type [boolean]"); + } + + private static Object neg(final Object unary) { + if (unary instanceof Double) { + return -(double)unary; + } else if (unary instanceof Long) { + return -(long)unary; + } else if (unary instanceof Integer) { + return -(int)unary; + } else if (unary instanceof Float) { + return -(float)unary; + } else if (unary instanceof Short) { + return -(short)unary; + } else if (unary instanceof Character) { + return -(char)unary; + } else if (unary instanceof Byte) { + return -(byte)unary; + } + + throw new ClassCastException("Cannot apply [-] operation to type " + + "[" + unary.getClass().getCanonicalName() + "]."); + } + + private static int plus(int v) { + return +v; + } + + private static long plus(long v) { + return +v; + } + + private static float plus(float v) { + return +v; + } + + private static double plus(double v) { + return +v; + } + + private static boolean plus(boolean v) { + throw new ClassCastException("Cannot apply [+] operation to type [boolean]"); + } + + private static Object plus(final Object unary) { + if (unary instanceof Double) { + return +(double)unary; + } else if (unary instanceof Long) { + return +(long)unary; + } else if (unary instanceof Integer) { + return +(int)unary; + } else if (unary instanceof Float) { + return +(float)unary; + } else if (unary instanceof Short) { + return +(short)unary; + } else if (unary instanceof Character) { + return +(char)unary; + } else if (unary instanceof Byte) { + return +(byte)unary; + } + + throw new ClassCastException("Cannot apply [+] operation to type " + + "[" + unary.getClass().getCanonicalName() + "]."); + } + + // multiplication/division/remainder/subtraction: applicable to all integer types + + private static int mul(int a, int b) { + return a * b; + } + + private static long mul(long a, long b) { + return a * b; + } + + private static float mul(float a, float b) { + return a * b; + } + + private static double mul(double a, double b) { + return a * b; + } + + private static boolean mul(boolean a, boolean b) { + throw new ClassCastException("Cannot apply [*] operation to type [boolean]"); + } + + private static Object mul(Object left, Object right) { + if (left instanceof Number) { + if (right instanceof Number) { + if (left instanceof Double || right instanceof Double) { + return ((Number)left).doubleValue() * ((Number)right).doubleValue(); + } else if (left instanceof Float || right instanceof Float) { + return ((Number)left).floatValue() * ((Number)right).floatValue(); + } else if (left instanceof Long || right instanceof Long) { + return ((Number)left).longValue() * ((Number)right).longValue(); + } else { + return ((Number)left).intValue() * ((Number)right).intValue(); + } + } else if (right instanceof Character) { + if (left instanceof Double) { + return ((Number)left).doubleValue() * (char)right; + } else if (left instanceof Long) { + return ((Number)left).longValue() * (char)right; + } else if (left instanceof Float) { + return ((Number)left).floatValue() * (char)right; + } else { + return ((Number)left).intValue() * (char)right; + } + } + } else if (left instanceof Character) { + if (right instanceof Number) { + if (right instanceof Double) { + return (char)left * ((Number)right).doubleValue(); + } else if (right instanceof Long) { + return (char)left * ((Number)right).longValue(); + } else if (right instanceof Float) { + return (char)left * ((Number)right).floatValue(); + } else { + return (char)left * ((Number)right).intValue(); + } + } else if (right instanceof Character) { + return (char)left * (char)right; + } + } + + throw new ClassCastException("Cannot apply [*] operation to types " + + "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); + } + + private static int div(int a, int b) { + return a / b; + } + + private static long div(long a, long b) { + return a / b; + } + + private static float div(float a, float b) { + return a / b; + } + + private static double div(double a, double b) { + return a / b; + } + + private static boolean div(boolean a, boolean b) { + throw new ClassCastException("Cannot apply [/] operation to type [boolean]"); + } + + private static Object div(Object left, Object right) { + if (left instanceof Number) { + if (right instanceof Number) { + if (left instanceof Double || right instanceof Double) { + return ((Number)left).doubleValue() / ((Number)right).doubleValue(); + } else if (left instanceof Float || right instanceof Float) { + return ((Number)left).floatValue() / ((Number)right).floatValue(); + } else if (left instanceof Long || right instanceof Long) { + return ((Number)left).longValue() / ((Number)right).longValue(); + } else { + return ((Number)left).intValue() / ((Number)right).intValue(); + } + } else if (right instanceof Character) { + if (left instanceof Double) { + return ((Number)left).doubleValue() / (char)right; + } else if (left instanceof Long) { + return ((Number)left).longValue() / (char)right; + } else if (left instanceof Float) { + return ((Number)left).floatValue() / (char)right; + } else { + return ((Number)left).intValue() / (char)right; + } + } + } else if (left instanceof Character) { + if (right instanceof Number) { + if (right instanceof Double) { + return (char)left / ((Number)right).doubleValue(); + } else if (right instanceof Long) { + return (char)left / ((Number)right).longValue(); + } else if (right instanceof Float) { + return (char)left / ((Number)right).floatValue(); + } else { + return (char)left / ((Number)right).intValue(); + } + } else if (right instanceof Character) { + return (char)left / (char)right; + } + } + + throw new ClassCastException("Cannot apply [/] operation to types " + + "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); + } + + private static int rem(int a, int b) { + return a % b; + } + + private static long rem(long a, long b) { + return a % b; + } + + private static float rem(float a, float b) { + return a % b; + } + + private static double rem(double a, double b) { + return a % b; + } + + private static boolean rem(boolean a, boolean b) { + throw new ClassCastException("Cannot apply [%] operation to type [boolean]"); + } + + private static Object rem(Object left, Object right) { + if (left instanceof Number) { + if (right instanceof Number) { + if (left instanceof Double || right instanceof Double) { + return ((Number)left).doubleValue() % ((Number)right).doubleValue(); + } else if (left instanceof Float || right instanceof Float) { + return ((Number)left).floatValue() % ((Number)right).floatValue(); + } else if (left instanceof Long || right instanceof Long) { + return ((Number)left).longValue() % ((Number)right).longValue(); + } else { + return ((Number)left).intValue() % ((Number)right).intValue(); + } + } else if (right instanceof Character) { + if (left instanceof Double) { + return ((Number)left).doubleValue() % (char)right; + } else if (left instanceof Long) { + return ((Number)left).longValue() % (char)right; + } else if (left instanceof Float) { + return ((Number)left).floatValue() % (char)right; + } else { + return ((Number)left).intValue() % (char)right; + } + } + } else if (left instanceof Character) { + if (right instanceof Number) { + if (right instanceof Double) { + return (char)left % ((Number)right).doubleValue(); + } else if (right instanceof Long) { + return (char)left % ((Number)right).longValue(); + } else if (right instanceof Float) { + return (char)left % ((Number)right).floatValue(); + } else { + return (char)left % ((Number)right).intValue(); + } + } else if (right instanceof Character) { + return (char)left % (char)right; + } + } + + throw new ClassCastException("Cannot apply [%] operation to types " + + "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); + } + + // addition: applicable to all numeric types. + // additionally, if either type is a string, the other type can be any arbitrary type (including null) + + private static int add(int a, int b) { + return a + b; + } + + private static long add(long a, long b) { + return a + b; + } + + private static float add(float a, float b) { + return a + b; + } + + private static double add(double a, double b) { + return a + b; + } + + private static boolean add(boolean a, boolean b) { + throw new ClassCastException("Cannot apply [+] operation to type [boolean]"); + } + + private static Object add(Object left, Object right) { + if (left instanceof String || right instanceof String) { + return "" + left + right; + } else if (left instanceof Number) { + if (right instanceof Number) { + if (left instanceof Double || right instanceof Double) { + return ((Number)left).doubleValue() + ((Number)right).doubleValue(); + } else if (left instanceof Float || right instanceof Float) { + return ((Number)left).floatValue() + ((Number)right).floatValue(); + } else if (left instanceof Long || right instanceof Long) { + return ((Number)left).longValue() + ((Number)right).longValue(); + } else { + return ((Number)left).intValue() + ((Number)right).intValue(); + } + } else if (right instanceof Character) { + if (left instanceof Double) { + return ((Number)left).doubleValue() + (char)right; + } else if (left instanceof Long) { + return ((Number)left).longValue() + (char)right; + } else if (left instanceof Float) { + return ((Number)left).floatValue() + (char)right; + } else { + return ((Number)left).intValue() + (char)right; + } + } + } else if (left instanceof Character) { + if (right instanceof Number) { + if (right instanceof Double) { + return (char)left + ((Number)right).doubleValue(); + } else if (right instanceof Long) { + return (char)left + ((Number)right).longValue(); + } else if (right instanceof Float) { + return (char)left + ((Number)right).floatValue(); + } else { + return (char)left + ((Number)right).intValue(); + } + } else if (right instanceof Character) { + return (char)left + (char)right; + } + } + + throw new ClassCastException("Cannot apply [+] operation to types " + + "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); + } + + private static int sub(int a, int b) { + return a - b; + } + + private static long sub(long a, long b) { + return a - b; + } + + private static float sub(float a, float b) { + return a - b; + } + + private static double sub(double a, double b) { + return a - b; + } + + private static boolean sub(boolean a, boolean b) { + throw new ClassCastException("Cannot apply [-] operation to type [boolean]"); + } + + private static Object sub(Object left, Object right) { + if (left instanceof Number) { + if (right instanceof Number) { + if (left instanceof Double || right instanceof Double) { + return ((Number)left).doubleValue() - ((Number)right).doubleValue(); + } else if (left instanceof Float || right instanceof Float) { + return ((Number)left).floatValue() - ((Number)right).floatValue(); + } else if (left instanceof Long || right instanceof Long) { + return ((Number)left).longValue() - ((Number)right).longValue(); + } else { + return ((Number)left).intValue() - ((Number)right).intValue(); + } + } else if (right instanceof Character) { + if (left instanceof Double) { + return ((Number)left).doubleValue() - (char)right; + } else if (left instanceof Long) { + return ((Number)left).longValue() - (char)right; + } else if (left instanceof Float) { + return ((Number)left).floatValue() - (char)right; + } else { + return ((Number)left).intValue() - (char)right; + } + } + } else if (left instanceof Character) { + if (right instanceof Number) { + if (right instanceof Double) { + return (char)left - ((Number)right).doubleValue(); + } else if (right instanceof Long) { + return (char)left - ((Number)right).longValue(); + } else if (right instanceof Float) { + return (char)left - ((Number)right).floatValue(); + } else { + return (char)left - ((Number)right).intValue(); + } + } else if (right instanceof Character) { + return (char)left - (char)right; + } + } + + throw new ClassCastException("Cannot apply [-] operation to types " + + "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); + } + + // eq: applicable to any arbitrary type, including nulls for both arguments!!! + + private static boolean eq(int a, int b) { + return a == b; + } + + private static boolean eq(long a, long b) { + return a == b; + } + + private static boolean eq(float a, float b) { + return a == b; + } + + private static boolean eq(double a, double b) { + return a == b; + } + + private static boolean eq(boolean a, boolean b) { + return a == b; + } + + private static boolean eq(Object left, Object right) { + if (left != null && right != null) { + if (left instanceof Double) { + if (right instanceof Number) { + return (double)left == ((Number)right).doubleValue(); + } else if (right instanceof Character) { + return (double)left == (char)right; + } + } else if (right instanceof Double) { + if (left instanceof Number) { + return ((Number)left).doubleValue() == (double)right; + } else if (left instanceof Character) { + return (char)left == ((Number)right).doubleValue(); + } + } else if (left instanceof Float) { + if (right instanceof Number) { + return (float)left == ((Number)right).floatValue(); + } else if (right instanceof Character) { + return (float)left == (char)right; + } + } else if (right instanceof Float) { + if (left instanceof Number) { + return ((Number)left).floatValue() == (float)right; + } else if (left instanceof Character) { + return (char)left == ((Number)right).floatValue(); + } + } else if (left instanceof Long) { + if (right instanceof Number) { + return (long)left == ((Number)right).longValue(); + } else if (right instanceof Character) { + return (long)left == (char)right; + } + } else if (right instanceof Long) { + if (left instanceof Number) { + return ((Number)left).longValue() == (long)right; + } else if (left instanceof Character) { + return (char)left == ((Number)right).longValue(); + } + } else if (left instanceof Number) { + if (right instanceof Number) { + return ((Number)left).intValue() == ((Number)right).intValue(); + } else if (right instanceof Character) { + return ((Number)left).intValue() == (char)right; + } + } else if (right instanceof Number && left instanceof Character) { + return (char)left == ((Number)right).intValue(); + } else if (left instanceof Character && right instanceof Character) { + return (char)left == (char)right; + } + + return left.equals(right); + } + + return left == null && right == null; + } + + // comparison operators: applicable for any numeric type + + private static boolean lt(int a, int b) { + return a < b; + } + + private static boolean lt(long a, long b) { + return a < b; + } + + private static boolean lt(float a, float b) { + return a < b; + } + + private static boolean lt(double a, double b) { + return a < b; + } + + private static boolean lt(boolean a, boolean b) { + throw new ClassCastException("Cannot apply [<] operation to type [boolean]"); + } + + private static boolean lt(Object left, Object right) { + if (left instanceof Number) { + if (right instanceof Number) { + if (left instanceof Double || right instanceof Double) { + return ((Number)left).doubleValue() < ((Number)right).doubleValue(); + } else if (left instanceof Float || right instanceof Float) { + return ((Number)left).floatValue() < ((Number)right).floatValue(); + } else if (left instanceof Long || right instanceof Long) { + return ((Number)left).longValue() < ((Number)right).longValue(); + } else { + return ((Number)left).intValue() < ((Number)right).intValue(); + } + } else if (right instanceof Character) { + if (left instanceof Double) { + return ((Number)left).doubleValue() < (char)right; + } else if (left instanceof Long) { + return ((Number)left).longValue() < (char)right; + } else if (left instanceof Float) { + return ((Number)left).floatValue() < (char)right; + } else { + return ((Number)left).intValue() < (char)right; + } + } + } else if (left instanceof Character) { + if (right instanceof Number) { + if (right instanceof Double) { + return (char)left < ((Number)right).doubleValue(); + } else if (right instanceof Long) { + return (char)left < ((Number)right).longValue(); + } else if (right instanceof Float) { + return (char)left < ((Number)right).floatValue(); + } else { + return (char)left < ((Number)right).intValue(); + } + } else if (right instanceof Character) { + return (char)left < (char)right; + } + } + + throw new ClassCastException("Cannot apply [<] operation to types " + + "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); + } + + private static boolean lte(int a, int b) { + return a <= b; + } + + private static boolean lte(long a, long b) { + return a <= b; + } + + private static boolean lte(float a, float b) { + return a <= b; + } + + private static boolean lte(double a, double b) { + return a <= b; + } + + private static boolean lte(boolean a, boolean b) { + throw new ClassCastException("Cannot apply [<=] operation to type [boolean]"); + } + + private static boolean lte(Object left, Object right) { + if (left instanceof Number) { + if (right instanceof Number) { + if (left instanceof Double || right instanceof Double) { + return ((Number)left).doubleValue() <= ((Number)right).doubleValue(); + } else if (left instanceof Float || right instanceof Float) { + return ((Number)left).floatValue() <= ((Number)right).floatValue(); + } else if (left instanceof Long || right instanceof Long) { + return ((Number)left).longValue() <= ((Number)right).longValue(); + } else { + return ((Number)left).intValue() <= ((Number)right).intValue(); + } + } else if (right instanceof Character) { + if (left instanceof Double) { + return ((Number)left).doubleValue() <= (char)right; + } else if (left instanceof Long) { + return ((Number)left).longValue() <= (char)right; + } else if (left instanceof Float) { + return ((Number)left).floatValue() <= (char)right; + } else { + return ((Number)left).intValue() <= (char)right; + } + } + } else if (left instanceof Character) { + if (right instanceof Number) { + if (right instanceof Double) { + return (char)left <= ((Number)right).doubleValue(); + } else if (right instanceof Long) { + return (char)left <= ((Number)right).longValue(); + } else if (right instanceof Float) { + return (char)left <= ((Number)right).floatValue(); + } else { + return (char)left <= ((Number)right).intValue(); + } + } else if (right instanceof Character) { + return (char)left <= (char)right; + } + } + + throw new ClassCastException("Cannot apply [<=] operation to types " + + "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); + } + + private static boolean gt(int a, int b) { + return a > b; + } + + private static boolean gt(long a, long b) { + return a > b; + } + + private static boolean gt(float a, float b) { + return a > b; + } + + private static boolean gt(double a, double b) { + return a > b; + } + + private static boolean gt(boolean a, boolean b) { + throw new ClassCastException("Cannot apply [>] operation to type [boolean]"); + } + + private static boolean gt(Object left, Object right) { + if (left instanceof Number) { + if (right instanceof Number) { + if (left instanceof Double || right instanceof Double) { + return ((Number)left).doubleValue() > ((Number)right).doubleValue(); + } else if (left instanceof Float || right instanceof Float) { + return ((Number)left).floatValue() > ((Number)right).floatValue(); + } else if (left instanceof Long || right instanceof Long) { + return ((Number)left).longValue() > ((Number)right).longValue(); + } else { + return ((Number)left).intValue() > ((Number)right).intValue(); + } + } else if (right instanceof Character) { + if (left instanceof Double) { + return ((Number)left).doubleValue() > (char)right; + } else if (left instanceof Long) { + return ((Number)left).longValue() > (char)right; + } else if (left instanceof Float) { + return ((Number)left).floatValue() > (char)right; + } else { + return ((Number)left).intValue() > (char)right; + } + } + } else if (left instanceof Character) { + if (right instanceof Number) { + if (right instanceof Double) { + return (char)left > ((Number)right).doubleValue(); + } else if (right instanceof Long) { + return (char)left > ((Number)right).longValue(); + } else if (right instanceof Float) { + return (char)left > ((Number)right).floatValue(); + } else { + return (char)left > ((Number)right).intValue(); + } + } else if (right instanceof Character) { + return (char)left > (char)right; + } + } + + throw new ClassCastException("Cannot apply [>] operation to types " + + "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); + } + + private static boolean gte(int a, int b) { + return a >= b; + } + + private static boolean gte(long a, long b) { + return a >= b; + } + + private static boolean gte(float a, float b) { + return a >= b; + } + + private static boolean gte(double a, double b) { + return a >= b; + } + + private static boolean gte(boolean a, boolean b) { + throw new ClassCastException("Cannot apply [>=] operation to type [boolean]"); + } + + private static boolean gte(Object left, Object right) { + if (left instanceof Number) { + if (right instanceof Number) { + if (left instanceof Double || right instanceof Double) { + return ((Number)left).doubleValue() >= ((Number)right).doubleValue(); + } else if (left instanceof Float || right instanceof Float) { + return ((Number)left).floatValue() >= ((Number)right).floatValue(); + } else if (left instanceof Long || right instanceof Long) { + return ((Number)left).longValue() >= ((Number)right).longValue(); + } else { + return ((Number)left).intValue() >= ((Number)right).intValue(); + } + } else if (right instanceof Character) { + if (left instanceof Double) { + return ((Number)left).doubleValue() >= (char)right; + } else if (left instanceof Long) { + return ((Number)left).longValue() >= (char)right; + } else if (left instanceof Float) { + return ((Number)left).floatValue() >= (char)right; + } else { + return ((Number)left).intValue() >= (char)right; + } + } + } else if (left instanceof Character) { + if (right instanceof Number) { + if (right instanceof Double) { + return (char)left >= ((Number)right).doubleValue(); + } else if (right instanceof Long) { + return (char)left >= ((Number)right).longValue(); + } else if (right instanceof Float) { + return (char)left >= ((Number)right).floatValue(); + } else { + return (char)left >= ((Number)right).intValue(); + } + } else if (right instanceof Character) { + return (char)left >= (char)right; + } + } + + throw new ClassCastException("Cannot apply [>] operation to types " + + "[" + left.getClass().getCanonicalName() + "] and [" + right.getClass().getCanonicalName() + "]."); + } + + // helper methods to convert an integral according to numeric promotion + // this is used by the generic code for bitwise and shift operators + + private static long longIntegralValue(Object o) { + if (o instanceof Long) { + return (long)o; + } else if (o instanceof Integer || o instanceof Short || o instanceof Byte) { + return ((Number)o).longValue(); + } else if (o instanceof Character) { + return (char)o; + } else { + throw new ClassCastException("Cannot convert [" + o.getClass().getCanonicalName() + "] to an integral value."); + } + } + + private static int intIntegralValue(Object o) { + if (o instanceof Integer || o instanceof Short || o instanceof Byte) { + return ((Number)o).intValue(); + } else if (o instanceof Character) { + return (char)o; + } else { + throw new ClassCastException("Cannot convert [" + o.getClass().getCanonicalName() + "] to an integral value."); + } + } + + // bitwise operators: valid only for integral types + + private static int and(int a, int b) { + return a & b; + } + + private static long and(long a, long b) { + return a & b; + } + + private static float and(float a, float b) { + throw new ClassCastException("Cannot apply [&] operation to type [float]"); + } + + private static double and(double a, double b) { + throw new ClassCastException("Cannot apply [&] operation to type [float]"); + } + + private static boolean and(boolean a, boolean b) { + return a & b; + } + + private static Object and(Object left, Object right) { + if (left instanceof Boolean && right instanceof Boolean) { + return (boolean)left & (boolean)right; + } else if (left instanceof Long || right instanceof Long) { + return longIntegralValue(left) & longIntegralValue(right); + } else { + return intIntegralValue(left) & intIntegralValue(right); + } + } + + private static int xor(int a, int b) { + return a ^ b; + } + + private static long xor(long a, long b) { + return a ^ b; + } + + private static float xor(float a, float b) { + throw new ClassCastException("Cannot apply [^] operation to type [float]"); + } + + private static double xor(double a, double b) { + throw new ClassCastException("Cannot apply [^] operation to type [float]"); + } + + private static boolean xor(boolean a, boolean b) { + return a ^ b; + } + + private static Object xor(Object left, Object right) { + if (left instanceof Boolean && right instanceof Boolean) { + return (boolean)left ^ (boolean)right; + } else if (left instanceof Long || right instanceof Long) { + return longIntegralValue(left) ^ longIntegralValue(right); + } else { + return intIntegralValue(left) ^ intIntegralValue(right); + } + } + + private static int or(int a, int b) { + return a | b; + } + + private static long or(long a, long b) { + return a | b; + } + + private static float or(float a, float b) { + throw new ClassCastException("Cannot apply [|] operation to type [float]"); + } + + private static double or(double a, double b) { + throw new ClassCastException("Cannot apply [|] operation to type [float]"); + } + + private static boolean or(boolean a, boolean b) { + return a | b; + } + + private static Object or(Object left, Object right) { + if (left instanceof Boolean && right instanceof Boolean) { + return (boolean)left | (boolean)right; + } else if (left instanceof Long || right instanceof Long) { + return longIntegralValue(left) | longIntegralValue(right); + } else { + return intIntegralValue(left) | intIntegralValue(right); + } + } + + // shift operators, valid for any integral types, but does not promote. + // we implement all shifts as long shifts, because the extra bits are ignored anyway. + + private static int lsh(int a, long b) { + return a << b; + } + + private static long lsh(long a, long b) { + return a << b; + } + + private static float lsh(float a, long b) { + throw new ClassCastException("Cannot apply [<<] operation to type [float]"); + } + + private static double lsh(double a, long b) { + throw new ClassCastException("Cannot apply [<<] operation to type [double]"); + } + + private static boolean lsh(boolean a, long b) { + throw new ClassCastException("Cannot apply [<<] operation to type [boolean]"); + } + + public static Object lsh(Object left, long right) { + if (left instanceof Long) { + return (long)(left) << right; + } else { + return intIntegralValue(left) << right; + } + } + + private static int rsh(int a, long b) { + return a >> b; + } + + private static long rsh(long a, long b) { + return a >> b; + } + + private static float rsh(float a, long b) { + throw new ClassCastException("Cannot apply [>>] operation to type [float]"); + } + + private static double rsh(double a, long b) { + throw new ClassCastException("Cannot apply [>>] operation to type [double]"); + } + + private static boolean rsh(boolean a, long b) { + throw new ClassCastException("Cannot apply [>>] operation to type [boolean]"); + } + + public static Object rsh(Object left, long right) { + if (left instanceof Long) { + return (long)left >> right; + } else { + return intIntegralValue(left) >> right; + } + } + + private static int ush(int a, long b) { + return a >>> b; + } + + private static long ush(long a, long b) { + return a >>> b; + } + + private static float ush(float a, long b) { + throw new ClassCastException("Cannot apply [>>>] operation to type [float]"); + } + + private static double ush(double a, long b) { + throw new ClassCastException("Cannot apply [>>>] operation to type [double]"); + } + + private static boolean ush(boolean a, long b) { + throw new ClassCastException("Cannot apply [>>>] operation to type [boolean]"); + } + + public static Object ush(Object left, long right) { + if (left instanceof Long) { + return (long)(left) >>> right; + } else { + return intIntegralValue(left) >>> right; + } + } + + /** + * unboxes a class to its primitive type, or returns the original + * class if its not a boxed type. + */ + private static Class unbox(Class clazz) { + if (clazz == Boolean.class) { + return boolean.class; + } else if (clazz == Byte.class) { + return byte.class; + } else if (clazz == Short.class) { + return short.class; + } else if (clazz == Character.class) { + return char.class; + } else if (clazz == Integer.class) { + return int.class; + } else if (clazz == Long.class) { + return long.class; + } else if (clazz == Float.class) { + return float.class; + } else if (clazz == Double.class) { + return double.class; + } else { + return clazz; + } + } + + /** Unary promotion. All Objects are promoted to Object. */ + private static Class promote(Class clazz) { + // if either is a non-primitive type -> Object. + if (clazz.isPrimitive() == false) { + return Object.class; + } + // always promoted to integer + if (clazz == byte.class || clazz == short.class || clazz == char.class || clazz == int.class) { + return int.class; + } else { + return clazz; + } + } + + /** Binary promotion. */ + private static Class promote(Class a, Class b) { + // if either is a non-primitive type -> Object. + if (a.isPrimitive() == false || b.isPrimitive() == false) { + return Object.class; + } + + // boolean -> boolean + if (a == boolean.class && b == boolean.class) { + return boolean.class; + } + + // ordinary numeric promotion + if (a == double.class || b == double.class) { + return double.class; + } else if (a == float.class || b == float.class) { + return float.class; + } else if (a == long.class || b == long.class) { + return long.class; + } else { + return int.class; + } + } + + private static final Lookup PRIV_LOOKUP = MethodHandles.lookup(); + + private static final Map,Map> TYPE_OP_MAPPING = Collections.unmodifiableMap( + Stream.of(boolean.class, int.class, long.class, float.class, double.class, Object.class) + .collect(Collectors.toMap(Function.identity(), type -> { + try { + Map map = new HashMap<>(); + MethodType unary = MethodType.methodType(type, type); + MethodType binary = MethodType.methodType(type, type, type); + MethodType comparison = MethodType.methodType(boolean.class, type, type); + MethodType shift = MethodType.methodType(type, type, long.class); + Class clazz = PRIV_LOOKUP.lookupClass(); + map.put("not", PRIV_LOOKUP.findStatic(clazz, "not", unary)); + map.put("neg", PRIV_LOOKUP.findStatic(clazz, "neg", unary)); + map.put("plus", PRIV_LOOKUP.findStatic(clazz, "plus", unary)); + map.put("mul", PRIV_LOOKUP.findStatic(clazz, "mul", binary)); + map.put("div", PRIV_LOOKUP.findStatic(clazz, "div", binary)); + map.put("rem", PRIV_LOOKUP.findStatic(clazz, "rem", binary)); + map.put("add", PRIV_LOOKUP.findStatic(clazz, "add", binary)); + map.put("sub", PRIV_LOOKUP.findStatic(clazz, "sub", binary)); + map.put("and", PRIV_LOOKUP.findStatic(clazz, "and", binary)); + map.put("or", PRIV_LOOKUP.findStatic(clazz, "or", binary)); + map.put("xor", PRIV_LOOKUP.findStatic(clazz, "xor", binary)); + map.put("eq", PRIV_LOOKUP.findStatic(clazz, "eq", comparison)); + map.put("lt", PRIV_LOOKUP.findStatic(clazz, "lt", comparison)); + map.put("lte", PRIV_LOOKUP.findStatic(clazz, "lte", comparison)); + map.put("gt", PRIV_LOOKUP.findStatic(clazz, "gt", comparison)); + map.put("gte", PRIV_LOOKUP.findStatic(clazz, "gte", comparison)); + map.put("lsh", PRIV_LOOKUP.findStatic(clazz, "lsh", shift)); + map.put("rsh", PRIV_LOOKUP.findStatic(clazz, "rsh", shift)); + map.put("ush", PRIV_LOOKUP.findStatic(clazz, "ush", shift)); + return map; + } catch (ReflectiveOperationException e) { + throw new AssertionError(e); + } + })) + ); + + /** Returns an appropriate method handle for a unary or shift operator, based only on the receiver (LHS) */ + public static MethodHandle lookupUnary(Class receiverClass, String name) { + MethodHandle handle = TYPE_OP_MAPPING.get(promote(unbox(receiverClass))).get(name); + if (handle == null) { + throw new ClassCastException("Cannot apply operator [" + name + "] to type [" + receiverClass + "]"); + } + return handle; + } + + /** Returns an appropriate method handle for a binary operator, based only promotion of the LHS and RHS arguments */ + public static MethodHandle lookupBinary(Class classA, Class classB, String name) { + MethodHandle handle = TYPE_OP_MAPPING.get(promote(promote(unbox(classA)), promote(unbox(classB)))).get(name); + if (handle == null) { + throw new ClassCastException("Cannot apply operator [" + name + "] to types [" + classA + "] and [" + classB + "]"); + } + return handle; + } + + /** Returns a generic method handle for any operator, that can handle all valid signatures, nulls, corner cases */ + public static MethodHandle lookupGeneric(String name) { + return TYPE_OP_MAPPING.get(Object.class).get(name); + } +} diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java index 2abcace5d6d..ae46ee54f35 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java @@ -35,15 +35,7 @@ import java.util.Deque; import java.util.List; import static org.elasticsearch.painless.WriterConstants.CHAR_TO_STRING; -import static org.elasticsearch.painless.WriterConstants.DEF_ADD_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_AND_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_DIV_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_LSH_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_MUL_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_OR_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_REM_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_RSH_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_SUB_CALL; +import static org.elasticsearch.painless.WriterConstants.DEF_BOOTSTRAP_HANDLE; import static org.elasticsearch.painless.WriterConstants.DEF_TO_BOOLEAN; import static org.elasticsearch.painless.WriterConstants.DEF_TO_BYTE_EXPLICIT; import static org.elasticsearch.painless.WriterConstants.DEF_TO_BYTE_IMPLICIT; @@ -59,9 +51,7 @@ import static org.elasticsearch.painless.WriterConstants.DEF_TO_LONG_EXPLICIT; import static org.elasticsearch.painless.WriterConstants.DEF_TO_LONG_IMPLICIT; import static org.elasticsearch.painless.WriterConstants.DEF_TO_SHORT_EXPLICIT; import static org.elasticsearch.painless.WriterConstants.DEF_TO_SHORT_IMPLICIT; -import static org.elasticsearch.painless.WriterConstants.DEF_USH_CALL; import static org.elasticsearch.painless.WriterConstants.DEF_UTIL_TYPE; -import static org.elasticsearch.painless.WriterConstants.DEF_XOR_CALL; import static org.elasticsearch.painless.WriterConstants.INDY_STRING_CONCAT_BOOTSTRAP_HANDLE; import static org.elasticsearch.painless.WriterConstants.MAX_INDY_STRING_CONCAT_ARGS; import static org.elasticsearch.painless.WriterConstants.PAINLESS_ERROR_TYPE; @@ -283,18 +273,44 @@ public final class MethodWriter extends GeneratorAdapter { } if (sort == Sort.DEF) { + // XXX: move this out, so we can populate descriptor with what we really have (instead of casts/boxing!) + org.objectweb.asm.Type objectType = org.objectweb.asm.Type.getType(Object.class); + org.objectweb.asm.Type descriptor = org.objectweb.asm.Type.getMethodType(objectType, objectType, objectType); + switch (operation) { - case MUL: invokeStatic(DEF_UTIL_TYPE, DEF_MUL_CALL); break; - case DIV: invokeStatic(DEF_UTIL_TYPE, DEF_DIV_CALL); break; - case REM: invokeStatic(DEF_UTIL_TYPE, DEF_REM_CALL); break; - case ADD: invokeStatic(DEF_UTIL_TYPE, DEF_ADD_CALL); break; - case SUB: invokeStatic(DEF_UTIL_TYPE, DEF_SUB_CALL); break; - case LSH: invokeStatic(DEF_UTIL_TYPE, DEF_LSH_CALL); break; - case USH: invokeStatic(DEF_UTIL_TYPE, DEF_RSH_CALL); break; - case RSH: invokeStatic(DEF_UTIL_TYPE, DEF_USH_CALL); break; - case BWAND: invokeStatic(DEF_UTIL_TYPE, DEF_AND_CALL); break; - case XOR: invokeStatic(DEF_UTIL_TYPE, DEF_XOR_CALL); break; - case BWOR: invokeStatic(DEF_UTIL_TYPE, DEF_OR_CALL); break; + case MUL: + invokeDynamic("mul", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); + break; + case DIV: + invokeDynamic("div", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); + break; + case REM: + invokeDynamic("rem", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); + break; + case ADD: + invokeDynamic("add", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); + break; + case SUB: + invokeDynamic("sub", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); + break; + case LSH: + invokeDynamic("lsh", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.SHIFT_OPERATOR); + break; + case USH: + invokeDynamic("ush", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.SHIFT_OPERATOR); + break; + case RSH: + invokeDynamic("rsh", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.SHIFT_OPERATOR); + break; + case BWAND: + invokeDynamic("and", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); + break; + case XOR: + invokeDynamic("xor", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); + break; + case BWOR: + invokeDynamic("or", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); + break; default: throw location.createError(new IllegalStateException("Illegal tree structure.")); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Utility.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Utility.java index 5ab3450db7e..b965f25bf0e 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Utility.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Utility.java @@ -37,13 +37,5 @@ public class Utility { return value.charAt(0); } - public static boolean checkEquals(final Object left, final Object right) { - if (left != null) { - return left.equals(right); - } - - return right == null || right.equals(null); - } - private Utility() {} } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/WriterConstants.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/WriterConstants.java index 521287a02d7..68814a2e1a6 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/WriterConstants.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/WriterConstants.java @@ -34,6 +34,7 @@ import java.lang.invoke.MethodType; import java.util.BitSet; import java.util.Iterator; import java.util.Map; +import java.util.Objects; import java.util.regex.Pattern; /** @@ -102,24 +103,6 @@ public final class WriterConstants { public final static Method DEF_TO_LONG_EXPLICIT = getAsmMethod(long.class , "DefTolongExplicit" , Object.class); public final static Method DEF_TO_FLOAT_EXPLICIT = getAsmMethod(float.class , "DefTofloatExplicit" , Object.class); public final static Method DEF_TO_DOUBLE_EXPLICIT = getAsmMethod(double.class , "DefTodoubleExplicit", Object.class); - public final static Method DEF_NOT_CALL = getAsmMethod(Object.class , "not", Object.class); - public final static Method DEF_NEG_CALL = getAsmMethod(Object.class , "neg", Object.class); - public final static Method DEF_MUL_CALL = getAsmMethod(Object.class , "mul", Object.class, Object.class); - public final static Method DEF_DIV_CALL = getAsmMethod(Object.class , "div", Object.class, Object.class); - public final static Method DEF_REM_CALL = getAsmMethod(Object.class , "rem", Object.class, Object.class); - public final static Method DEF_ADD_CALL = getAsmMethod(Object.class , "add", Object.class, Object.class); - public final static Method DEF_SUB_CALL = getAsmMethod(Object.class , "sub", Object.class, Object.class); - public final static Method DEF_LSH_CALL = getAsmMethod(Object.class , "lsh", Object.class, int.class); - public final static Method DEF_RSH_CALL = getAsmMethod(Object.class , "rsh", Object.class, int.class); - public final static Method DEF_USH_CALL = getAsmMethod(Object.class , "ush", Object.class, int.class); - public final static Method DEF_AND_CALL = getAsmMethod(Object.class , "and", Object.class, Object.class); - public final static Method DEF_XOR_CALL = getAsmMethod(Object.class , "xor", Object.class, Object.class); - public final static Method DEF_OR_CALL = getAsmMethod(Object.class , "or" , Object.class, Object.class); - public final static Method DEF_EQ_CALL = getAsmMethod(boolean.class, "eq" , Object.class, Object.class); - public final static Method DEF_LT_CALL = getAsmMethod(boolean.class, "lt" , Object.class, Object.class); - public final static Method DEF_LTE_CALL = getAsmMethod(boolean.class, "lte", Object.class, Object.class); - public final static Method DEF_GT_CALL = getAsmMethod(boolean.class, "gt" , Object.class, Object.class); - public final static Method DEF_GTE_CALL = getAsmMethod(boolean.class, "gte", Object.class, Object.class); /** invokedynamic bootstrap for lambda expression/method references */ public final static MethodType LAMBDA_BOOTSTRAP_TYPE = @@ -163,7 +146,8 @@ public final class WriterConstants { public final static Method STRINGBUILDER_APPEND_OBJECT = getAsmMethod(StringBuilder.class, "append", Object.class); public final static Method STRINGBUILDER_TOSTRING = getAsmMethod(String.class, "toString"); - public final static Method CHECKEQUALS = getAsmMethod(boolean.class, "checkEquals", Object.class, Object.class); + public final static Type OBJECTS_TYPE = Type.getType(Objects.class); + public final static Method EQUALS = getAsmMethod(boolean.class, "equals", Object.class, Object.class); private static Method getAsmMethod(final Class rtype, final String name, final Class... ptypes) { return new Method(name, MethodType.methodType(rtype, ptypes).toMethodDescriptorString()); diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EBinary.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EBinary.java index 55c2cc18210..9f8f1553770 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EBinary.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EBinary.java @@ -284,22 +284,27 @@ public final class EBinary extends AExpression { left.analyze(variables); right.analyze(variables); - Type promote = AnalyzerCaster.promoteNumeric(left.actual, false); + Type lhspromote = AnalyzerCaster.promoteNumeric(left.actual, false); + Type rhspromote = AnalyzerCaster.promoteNumeric(right.actual, false); - if (promote == null) { + if (lhspromote == null || rhspromote == null) { throw createError(new ClassCastException("Cannot apply left shift [<<] to types " + "[" + left.actual.name + "] and [" + right.actual.name + "].")); } - left.expected = promote; - right.expected = Definition.INT_TYPE; - right.explicit = true; + left.expected = lhspromote; + if (rhspromote.sort == Sort.LONG) { + right.expected = Definition.INT_TYPE; + right.explicit = true; + } else { + right.expected = rhspromote; + } left = left.cast(variables); right = right.cast(variables); if (left.constant != null && right.constant != null) { - Sort sort = promote.sort; + Sort sort = lhspromote.sort; if (sort == Sort.INT) { constant = (int)left.constant << (int)right.constant; @@ -310,29 +315,34 @@ public final class EBinary extends AExpression { } } - actual = promote; + actual = lhspromote; } private void analyzeRSH(Locals variables) { left.analyze(variables); right.analyze(variables); - Type promote = AnalyzerCaster.promoteNumeric(left.actual, false); + Type lhspromote = AnalyzerCaster.promoteNumeric(left.actual, false); + Type rhspromote = AnalyzerCaster.promoteNumeric(right.actual, false); - if (promote == null) { + if (lhspromote == null || rhspromote == null) { throw createError(new ClassCastException("Cannot apply right shift [>>] to types " + "[" + left.actual.name + "] and [" + right.actual.name + "].")); } - left.expected = promote; - right.expected = Definition.INT_TYPE; - right.explicit = true; + left.expected = lhspromote; + if (rhspromote.sort == Sort.LONG) { + right.expected = Definition.INT_TYPE; + right.explicit = true; + } else { + right.expected = rhspromote; + } left = left.cast(variables); right = right.cast(variables); if (left.constant != null && right.constant != null) { - Sort sort = promote.sort; + Sort sort = lhspromote.sort; if (sort == Sort.INT) { constant = (int)left.constant >> (int)right.constant; @@ -343,29 +353,34 @@ public final class EBinary extends AExpression { } } - actual = promote; + actual = lhspromote; } private void analyzeUSH(Locals variables) { left.analyze(variables); right.analyze(variables); - Type promote = AnalyzerCaster.promoteNumeric(left.actual, false); + Type lhspromote = AnalyzerCaster.promoteNumeric(left.actual, false); + Type rhspromote = AnalyzerCaster.promoteNumeric(right.actual, false); - if (promote == null) { + if (lhspromote == null || rhspromote == null) { throw createError(new ClassCastException("Cannot apply unsigned shift [>>>] to types " + "[" + left.actual.name + "] and [" + right.actual.name + "].")); } - left.expected = promote; - right.expected = Definition.INT_TYPE; - right.explicit = true; + left.expected = lhspromote; + if (rhspromote.sort == Sort.LONG) { + right.expected = Definition.INT_TYPE; + right.explicit = true; + } else { + right.expected = rhspromote; + } left = left.cast(variables); right = right.cast(variables); if (left.constant != null && right.constant != null) { - Sort sort = promote.sort; + Sort sort = lhspromote.sort; if (sort == Sort.INT) { constant = (int)left.constant >>> (int)right.constant; @@ -376,7 +391,7 @@ public final class EBinary extends AExpression { } } - actual = promote; + actual = lhspromote; } private void analyzeBWAnd(Locals variables) { diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ECapturingFunctionRef.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ECapturingFunctionRef.java index 3e35602a3a2..ac316f07fbd 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ECapturingFunctionRef.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ECapturingFunctionRef.java @@ -87,7 +87,7 @@ public class ECapturingFunctionRef extends AExpression { // typed interface, dynamic implementation writer.visitVarInsn(captured.type.type.getOpcode(Opcodes.ILOAD), captured.slot); String descriptor = Type.getMethodType(expected.type, captured.type.type).getDescriptor(); - writer.invokeDynamic(call, descriptor, DEF_BOOTSTRAP_HANDLE, (Object)DefBootstrap.REFERENCE, expected.name); + writer.invokeDynamic(call, descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.REFERENCE, expected.name); } else { // typed interface, typed implementation writer.visitVarInsn(captured.type.type.getOpcode(Opcodes.ILOAD), captured.slot); diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EComp.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EComp.java index ec84600d323..a126cd7d936 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EComp.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EComp.java @@ -24,19 +24,15 @@ import org.elasticsearch.painless.Definition.Sort; import org.elasticsearch.painless.Definition.Type; import org.elasticsearch.painless.Location; import org.elasticsearch.painless.AnalyzerCaster; +import org.elasticsearch.painless.DefBootstrap; import org.elasticsearch.painless.Operation; import org.elasticsearch.painless.Locals; import org.objectweb.asm.Label; import org.elasticsearch.painless.MethodWriter; -import static org.elasticsearch.painless.WriterConstants.CHECKEQUALS; -import static org.elasticsearch.painless.WriterConstants.DEF_EQ_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_GTE_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_GT_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_LTE_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_LT_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_UTIL_TYPE; -import static org.elasticsearch.painless.WriterConstants.UTILITY_TYPE; +import static org.elasticsearch.painless.WriterConstants.OBJECTS_TYPE; +import static org.elasticsearch.painless.WriterConstants.EQUALS; +import static org.elasticsearch.painless.WriterConstants.DEF_BOOTSTRAP_HANDLE; /** * Represents a comparison expression. @@ -456,11 +452,15 @@ public final class EComp extends AExpression { break; case DEF: + // XXX: move this out, so we can populate descriptor with what we really have (instead of casts/boxing!) + org.objectweb.asm.Type booleanType = org.objectweb.asm.Type.getType(boolean.class); + org.objectweb.asm.Type objectType = org.objectweb.asm.Type.getType(Object.class); + org.objectweb.asm.Type descriptor = org.objectweb.asm.Type.getMethodType(booleanType, objectType, objectType); if (eq) { if (right.isNull) { writer.ifNull(jump); } else if (!left.isNull && (operation == Operation.EQ || operation == Operation.NE)) { - writer.invokeStatic(DEF_UTIL_TYPE, DEF_EQ_CALL); + writer.invokeDynamic("eq", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); writejump = false; } else { writer.ifCmp(rtype, MethodWriter.EQ, jump); @@ -469,22 +469,22 @@ public final class EComp extends AExpression { if (right.isNull) { writer.ifNonNull(jump); } else if (!left.isNull && (operation == Operation.EQ || operation == Operation.NE)) { - writer.invokeStatic(DEF_UTIL_TYPE, DEF_EQ_CALL); + writer.invokeDynamic("eq", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); writer.ifZCmp(MethodWriter.EQ, jump); } else { writer.ifCmp(rtype, MethodWriter.NE, jump); } } else if (lt) { - writer.invokeStatic(DEF_UTIL_TYPE, DEF_LT_CALL); + writer.invokeDynamic("lt", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); writejump = false; } else if (lte) { - writer.invokeStatic(DEF_UTIL_TYPE, DEF_LTE_CALL); + writer.invokeDynamic("lte", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); writejump = false; } else if (gt) { - writer.invokeStatic(DEF_UTIL_TYPE, DEF_GT_CALL); + writer.invokeDynamic("gt", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); writejump = false; } else if (gte) { - writer.invokeStatic(DEF_UTIL_TYPE, DEF_GTE_CALL); + writer.invokeDynamic("gte", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); writejump = false; } else { throw createError(new IllegalStateException("Illegal tree structure.")); @@ -500,7 +500,7 @@ public final class EComp extends AExpression { if (right.isNull) { writer.ifNull(jump); } else if (operation == Operation.EQ || operation == Operation.NE) { - writer.invokeStatic(UTILITY_TYPE, CHECKEQUALS); + writer.invokeStatic(OBJECTS_TYPE, EQUALS); if (branch) { writer.ifZCmp(MethodWriter.NE, jump); @@ -514,7 +514,7 @@ public final class EComp extends AExpression { if (right.isNull) { writer.ifNonNull(jump); } else if (operation == Operation.EQ || operation == Operation.NE) { - writer.invokeStatic(UTILITY_TYPE, CHECKEQUALS); + writer.invokeStatic(OBJECTS_TYPE, EQUALS); writer.ifZCmp(MethodWriter.EQ, jump); } else { writer.ifCmp(rtype, MethodWriter.NE, jump); diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EUnary.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EUnary.java index 2d7d8d4fd49..d184961fd28 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EUnary.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EUnary.java @@ -24,14 +24,13 @@ import org.elasticsearch.painless.Location; import org.elasticsearch.painless.Definition.Sort; import org.elasticsearch.painless.Definition.Type; import org.elasticsearch.painless.AnalyzerCaster; +import org.elasticsearch.painless.DefBootstrap; import org.elasticsearch.painless.Operation; import org.elasticsearch.painless.Locals; import org.objectweb.asm.Label; import org.elasticsearch.painless.MethodWriter; -import static org.elasticsearch.painless.WriterConstants.DEF_NEG_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_NOT_CALL; -import static org.elasticsearch.painless.WriterConstants.DEF_UTIL_TYPE; +import static org.elasticsearch.painless.WriterConstants.DEF_BOOTSTRAP_HANDLE; /** * Represents a unary math expression. @@ -194,7 +193,8 @@ public final class EUnary extends AExpression { if (operation == Operation.BWNOT) { if (sort == Sort.DEF) { - writer.invokeStatic(DEF_UTIL_TYPE, DEF_NOT_CALL); + org.objectweb.asm.Type descriptor = org.objectweb.asm.Type.getMethodType(expected.type, child.actual.type); + writer.invokeDynamic("not", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.UNARY_OPERATOR); } else { if (sort == Sort.INT) { writer.push(-1); @@ -208,11 +208,17 @@ public final class EUnary extends AExpression { } } else if (operation == Operation.SUB) { if (sort == Sort.DEF) { - writer.invokeStatic(DEF_UTIL_TYPE, DEF_NEG_CALL); + org.objectweb.asm.Type descriptor = org.objectweb.asm.Type.getMethodType(expected.type, child.actual.type); + writer.invokeDynamic("neg", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.UNARY_OPERATOR); } else { writer.math(MethodWriter.NEG, type); } - } else if (operation != Operation.ADD) { + } else if (operation == Operation.ADD) { + if (sort == Sort.DEF) { + org.objectweb.asm.Type descriptor = org.objectweb.asm.Type.getMethodType(expected.type, child.actual.type); + writer.invokeDynamic("plus", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.UNARY_OPERATOR); + } + } else { throw createError(new IllegalStateException("Illegal tree structure.")); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefArray.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefArray.java index dc8890e5122..5a5333aabe8 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefArray.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefArray.java @@ -62,7 +62,7 @@ final class LDefArray extends ALink implements IDefLink { writer.writeDebugInfo(location); String desc = Type.getMethodDescriptor(after.type, Definition.DEF_TYPE.type, index.actual.type); - writer.invokeDynamic("arrayLoad", desc, DEF_BOOTSTRAP_HANDLE, (Object)DefBootstrap.ARRAY_LOAD); + writer.invokeDynamic("arrayLoad", desc, DEF_BOOTSTRAP_HANDLE, DefBootstrap.ARRAY_LOAD); } @Override @@ -70,6 +70,6 @@ final class LDefArray extends ALink implements IDefLink { writer.writeDebugInfo(location); String desc = Type.getMethodDescriptor(Definition.VOID_TYPE.type, Definition.DEF_TYPE.type, index.actual.type, after.type); - writer.invokeDynamic("arrayStore", desc, DEF_BOOTSTRAP_HANDLE, (Object)DefBootstrap.ARRAY_STORE); + writer.invokeDynamic("arrayStore", desc, DEF_BOOTSTRAP_HANDLE, DefBootstrap.ARRAY_STORE); } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefCall.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefCall.java index 554144c2999..5301dd2b08d 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefCall.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefCall.java @@ -105,7 +105,7 @@ final class LDefCall extends ALink implements IDefLink { // return value signature.append(after.type.getDescriptor()); - writer.invokeDynamic(name, signature.toString(), DEF_BOOTSTRAP_HANDLE, (Object)DefBootstrap.METHOD_CALL, recipe); + writer.invokeDynamic(name, signature.toString(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.METHOD_CALL, recipe); } @Override diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefField.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefField.java index 91ee8e0f03d..09b48a94f86 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefField.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LDefField.java @@ -59,7 +59,7 @@ final class LDefField extends ALink implements IDefLink { writer.writeDebugInfo(location); String desc = Type.getMethodDescriptor(after.type, Definition.DEF_TYPE.type); - writer.invokeDynamic(value, desc, DEF_BOOTSTRAP_HANDLE, (Object)DefBootstrap.LOAD); + writer.invokeDynamic(value, desc, DEF_BOOTSTRAP_HANDLE, DefBootstrap.LOAD); } @Override @@ -67,6 +67,6 @@ final class LDefField extends ALink implements IDefLink { writer.writeDebugInfo(location); String desc = Type.getMethodDescriptor(Definition.VOID_TYPE.type, Definition.DEF_TYPE.type, after.type); - writer.invokeDynamic(value, desc, DEF_BOOTSTRAP_HANDLE, (Object)DefBootstrap.STORE); + writer.invokeDynamic(value, desc, DEF_BOOTSTRAP_HANDLE, DefBootstrap.STORE); } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java index a6156324873..48a3f6e3eae 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java @@ -195,7 +195,7 @@ public class SEach extends AStatement { if (method == null) { Type itr = Definition.getType("Iterator"); String desc = org.objectweb.asm.Type.getMethodDescriptor(itr.type, Definition.DEF_TYPE.type); - writer.invokeDynamic("iterator", desc, DEF_BOOTSTRAP_HANDLE, (Object)DefBootstrap.ITERATOR); + writer.invokeDynamic("iterator", desc, DEF_BOOTSTRAP_HANDLE, DefBootstrap.ITERATOR); } else if (java.lang.reflect.Modifier.isInterface(method.owner.clazz.getModifiers())) { writer.invokeInterface(method.owner.type, method.method); } else { diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/AndTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/AndTests.java index 2c86250da83..6be068a10d1 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/AndTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/AndTests.java @@ -45,4 +45,13 @@ public class AndTests extends ScriptTestCase { assertEquals(5L & -12L, exec("return 5L & -12L;")); assertEquals(7L & 15L & 3L, exec("return 7L & 15L & 3L;")); } + + public void testIllegal() throws Exception { + expectScriptThrows(ClassCastException.class, () -> { + exec("float x = (float)4; int y = 1; return x & y"); + }); + expectScriptThrows(ClassCastException.class, () -> { + exec("double x = (double)4; int y = 1; return x & y"); + }); + } } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/BinaryOperatorTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/BinaryOperatorTests.java index 54a36234fd8..61bac8bd6b5 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/BinaryOperatorTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/BinaryOperatorTests.java @@ -45,13 +45,12 @@ public class BinaryOperatorTests extends ScriptTestCase { } public void testLongShifts() { - // note: we always promote the results of shifts too (unlike java) assertEquals(1L << 2, exec("long x = 1L; int y = 2; return x << y;")); - assertEquals(1L << 2L, exec("long x = 1L; long y = 2L; return x << y;")); - assertEquals(4L >> 2L, exec("long x = 4L; long y = 2L; return x >> y;")); + assertEquals(1 << 2L, exec("int x = 1; long y = 2L; return x << y;")); + assertEquals(4 >> 2L, exec("int x = 4; long y = 2L; return x >> y;")); assertEquals(4L >> 2, exec("long x = 4L; int y = 2; return x >> y;")); assertEquals(-1L >>> 29, exec("long x = -1L; int y = 29; return x >>> y;")); - assertEquals(-1L >>> 29L, exec("long x = -1L; long y = 29L; return x >>> y;")); + assertEquals(-1 >>> 29L, exec("int x = -1; long y = 29L; return x >>> y;")); } public void testLongShiftsConst() { @@ -62,6 +61,36 @@ public class BinaryOperatorTests extends ScriptTestCase { assertEquals(-1L >>> 29, exec("return -1L >>> 29;")); assertEquals(-1 >>> 29L, exec("return -1 >>> 29L;")); } + + public void testBogusShifts() { + expectScriptThrows(ClassCastException.class, ()-> { + exec("long x = 1L; float y = 2; return x << y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("int x = 1; double y = 2L; return x << y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("float x = 1F; int y = 2; return x << y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("double x = 1D; int y = 2L; return x << y;"); + }); + } + + public void testBogusShiftsConst() { + expectScriptThrows(ClassCastException.class, ()-> { + exec("return 1L << 2F;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("return 1L << 2.0;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("return 1F << 2;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("return 1D << 2L"); + }); + } public void testMixedTypes() { assertEquals(8, exec("int x = 4; char y = 2; return x*y;")); diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/DefBootstrapTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/DefBootstrapTests.java index 4330c613e14..12104844e7c 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/DefBootstrapTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/DefBootstrapTests.java @@ -23,6 +23,8 @@ import java.lang.invoke.CallSite; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; +import java.util.Arrays; +import java.util.Collections; import org.elasticsearch.test.ESTestCase; @@ -90,6 +92,19 @@ public class DefBootstrapTests extends ESTestCase { assertDepthEquals(site, 5); } + /** test that we really revert to a "generic" method that can handle any receiver types */ + public void testMegamorphic() throws Throwable { + DefBootstrap.PIC site = (DefBootstrap.PIC) DefBootstrap.bootstrap(MethodHandles.publicLookup(), + "size", + MethodType.methodType(int.class, Object.class), + DefBootstrap.METHOD_CALL, 0L); + site.depth = DefBootstrap.PIC.MAX_DEPTH; // mark megamorphic + MethodHandle handle = site.dynamicInvoker(); + // arguments are cast to object here, or IDE compilers eat it :) + assertEquals(2, handle.invoke((Object) Arrays.asList("1", "2"))); + assertEquals(1, handle.invoke((Object) Collections.singletonMap("a", "b"))); + } + static void assertDepthEquals(CallSite site, int expected) { DefBootstrap.PIC dsite = (DefBootstrap.PIC) site; assertEquals(expected, dsite.depth); diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/DefOperationTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/DefOperationTests.java index ca611760f07..bb720943f9d 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/DefOperationTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/DefOperationTests.java @@ -49,6 +49,16 @@ public class DefOperationTests extends ScriptTestCase { assertEquals(-1.0F, exec("def x = 1F; return -x")); assertEquals(-1.0, exec("def x = 1.0; return -x")); } + + public void testPlus() { + assertEquals(-1, exec("def x = (byte)-1; return +x")); + assertEquals(-1, exec("def x = (short)-1; return +x")); + assertEquals(65535, exec("def x = (char)-1; return +x")); + assertEquals(-1, exec("def x = -1; return +x")); + assertEquals(-1L, exec("def x = -1L; return +x")); + assertEquals(-1.0F, exec("def x = -1F; return +x")); + assertEquals(-1.0D, exec("def x = -1.0; return +x")); + } public void testMul() { assertEquals(4, exec("def x = (byte)2; def y = (byte)2; return x * y")); @@ -313,6 +323,29 @@ public class DefOperationTests extends ScriptTestCase { assertEquals(2F, exec("def x = (float)1; def y = (float)1; return x + y")); assertEquals(2D, exec("def x = (double)1; def y = (double)1; return x + y")); } + + public void testAddConcat() { + assertEquals("a" + (byte)2, exec("def x = 'a'; def y = (byte)2; return x + y")); + assertEquals("a" + (short)2, exec("def x = 'a'; def y = (short)2; return x + y")); + assertEquals("a" + (char)2, exec("def x = 'a'; def y = (char)2; return x + y")); + assertEquals("a" + 2, exec("def x = 'a'; def y = (int)2; return x + y")); + assertEquals("a" + 2L, exec("def x = 'a'; def y = (long)2; return x + y")); + assertEquals("a" + 2F, exec("def x = 'a'; def y = (float)2; return x + y")); + assertEquals("a" + 2D, exec("def x = 'a'; def y = (double)2; return x + y")); + assertEquals("ab", exec("def x = 'a'; def y = 'b'; return x + y")); + assertEquals((byte)2 + "a", exec("def x = 'a'; def y = (byte)2; return y + x")); + assertEquals((short)2 + "a", exec("def x = 'a'; def y = (short)2; return y + x")); + assertEquals((char)2 + "a", exec("def x = 'a'; def y = (char)2; return y + x")); + assertEquals(2 + "a", exec("def x = 'a'; def y = (int)2; return y + x")); + assertEquals(2L + "a", exec("def x = 'a'; def y = (long)2; return y + x")); + assertEquals(2F + "a", exec("def x = 'a'; def y = (float)2; return y + x")); + assertEquals(2D + "a", exec("def x = 'a'; def y = (double)2; return y + x")); + assertEquals("anull", exec("def x = 'a'; def y = null; return x + y")); + assertEquals("nullb", exec("def x = null; def y = 'b'; return x + y")); + expectScriptThrows(NullPointerException.class, () -> { + exec("def x = null; def y = null; return x + y"); + }); + } public void testSub() { assertEquals(0, exec("def x = (byte)1; def y = (byte)1; return x - y")); @@ -386,64 +419,36 @@ public class DefOperationTests extends ScriptTestCase { assertEquals(2, exec("def x = (char)1; def y = (byte)1; return x << y")); assertEquals(2, exec("def x = (int)1; def y = (byte)1; return x << y")); assertEquals(2L, exec("def x = (long)1; def y = (byte)1; return x << y")); - assertEquals(2L, exec("def x = (float)1; def y = (byte)1; return x << y")); - assertEquals(2L, exec("def x = (double)1; def y = (byte)1; return x << y")); assertEquals(2, exec("def x = (byte)1; def y = (short)1; return x << y")); assertEquals(2, exec("def x = (short)1; def y = (short)1; return x << y")); assertEquals(2, exec("def x = (char)1; def y = (short)1; return x << y")); assertEquals(2, exec("def x = (int)1; def y = (short)1; return x << y")); assertEquals(2L, exec("def x = (long)1; def y = (short)1; return x << y")); - assertEquals(2L, exec("def x = (float)1; def y = (short)1; return x << y")); - assertEquals(2L, exec("def x = (double)1; def y = (short)1; return x << y")); assertEquals(2, exec("def x = (byte)1; def y = (char)1; return x << y")); assertEquals(2, exec("def x = (short)1; def y = (char)1; return x << y")); assertEquals(2, exec("def x = (char)1; def y = (char)1; return x << y")); assertEquals(2, exec("def x = (int)1; def y = (char)1; return x << y")); assertEquals(2L, exec("def x = (long)1; def y = (char)1; return x << y")); - assertEquals(2L, exec("def x = (float)1; def y = (char)1; return x << y")); - assertEquals(2L, exec("def x = (double)1; def y = (char)1; return x << y")); assertEquals(2, exec("def x = (byte)1; def y = (int)1; return x << y")); assertEquals(2, exec("def x = (short)1; def y = (int)1; return x << y")); assertEquals(2, exec("def x = (char)1; def y = (int)1; return x << y")); assertEquals(2, exec("def x = (int)1; def y = (int)1; return x << y")); assertEquals(2L, exec("def x = (long)1; def y = (int)1; return x << y")); - assertEquals(2L, exec("def x = (float)1; def y = (int)1; return x << y")); - assertEquals(2L, exec("def x = (double)1; def y = (int)1; return x << y")); assertEquals(2, exec("def x = (byte)1; def y = (long)1; return x << y")); assertEquals(2, exec("def x = (short)1; def y = (long)1; return x << y")); assertEquals(2, exec("def x = (char)1; def y = (long)1; return x << y")); assertEquals(2, exec("def x = (int)1; def y = (long)1; return x << y")); assertEquals(2L, exec("def x = (long)1; def y = (long)1; return x << y")); - assertEquals(2L, exec("def x = (float)1; def y = (long)1; return x << y")); - assertEquals(2L, exec("def x = (double)1; def y = (long)1; return x << y")); - - assertEquals(2, exec("def x = (byte)1; def y = (float)1; return x << y")); - assertEquals(2, exec("def x = (short)1; def y = (float)1; return x << y")); - assertEquals(2, exec("def x = (char)1; def y = (float)1; return x << y")); - assertEquals(2, exec("def x = (int)1; def y = (float)1; return x << y")); - assertEquals(2L, exec("def x = (long)1; def y = (float)1; return x << y")); - assertEquals(2L, exec("def x = (float)1; def y = (float)1; return x << y")); - assertEquals(2L, exec("def x = (double)1; def y = (float)1; return x << y")); - - assertEquals(2, exec("def x = (byte)1; def y = (double)1; return x << y")); - assertEquals(2, exec("def x = (short)1; def y = (double)1; return x << y")); - assertEquals(2, exec("def x = (char)1; def y = (double)1; return x << y")); - assertEquals(2, exec("def x = (int)1; def y = (double)1; return x << y")); - assertEquals(2L, exec("def x = (long)1; def y = (double)1; return x << y")); - assertEquals(2L, exec("def x = (float)1; def y = (double)1; return x << y")); - assertEquals(2L, exec("def x = (double)1; def y = (double)1; return x << y")); assertEquals(2, exec("def x = (byte)1; def y = (byte)1; return x << y")); assertEquals(2, exec("def x = (short)1; def y = (short)1; return x << y")); assertEquals(2, exec("def x = (char)1; def y = (char)1; return x << y")); assertEquals(2, exec("def x = (int)1; def y = (int)1; return x << y")); assertEquals(2L, exec("def x = (long)1; def y = (long)1; return x << y")); - assertEquals(2L, exec("def x = (float)1; def y = (float)1; return x << y")); - assertEquals(2L, exec("def x = (double)1; def y = (double)1; return x << y")); } public void testRsh() { @@ -452,64 +457,36 @@ public class DefOperationTests extends ScriptTestCase { assertEquals(2, exec("def x = (char)4; def y = (byte)1; return x >> y")); assertEquals(2, exec("def x = (int)4; def y = (byte)1; return x >> y")); assertEquals(2L, exec("def x = (long)4; def y = (byte)1; return x >> y")); - assertEquals(2L, exec("def x = (float)4; def y = (byte)1; return x >> y")); - assertEquals(2L, exec("def x = (double)4; def y = (byte)1; return x >> y")); assertEquals(2, exec("def x = (byte)4; def y = (short)1; return x >> y")); assertEquals(2, exec("def x = (short)4; def y = (short)1; return x >> y")); assertEquals(2, exec("def x = (char)4; def y = (short)1; return x >> y")); assertEquals(2, exec("def x = (int)4; def y = (short)1; return x >> y")); assertEquals(2L, exec("def x = (long)4; def y = (short)1; return x >> y")); - assertEquals(2L, exec("def x = (float)4; def y = (short)1; return x >> y")); - assertEquals(2L, exec("def x = (double)4; def y = (short)1; return x >> y")); assertEquals(2, exec("def x = (byte)4; def y = (char)1; return x >> y")); assertEquals(2, exec("def x = (short)4; def y = (char)1; return x >> y")); assertEquals(2, exec("def x = (char)4; def y = (char)1; return x >> y")); assertEquals(2, exec("def x = (int)4; def y = (char)1; return x >> y")); assertEquals(2L, exec("def x = (long)4; def y = (char)1; return x >> y")); - assertEquals(2L, exec("def x = (float)4; def y = (char)1; return x >> y")); - assertEquals(2L, exec("def x = (double)4; def y = (char)1; return x >> y")); assertEquals(2, exec("def x = (byte)4; def y = (int)1; return x >> y")); assertEquals(2, exec("def x = (short)4; def y = (int)1; return x >> y")); assertEquals(2, exec("def x = (char)4; def y = (int)1; return x >> y")); assertEquals(2, exec("def x = (int)4; def y = (int)1; return x >> y")); assertEquals(2L, exec("def x = (long)4; def y = (int)1; return x >> y")); - assertEquals(2L, exec("def x = (float)4; def y = (int)1; return x >> y")); - assertEquals(2L, exec("def x = (double)4; def y = (int)1; return x >> y")); assertEquals(2, exec("def x = (byte)4; def y = (long)1; return x >> y")); assertEquals(2, exec("def x = (short)4; def y = (long)1; return x >> y")); assertEquals(2, exec("def x = (char)4; def y = (long)1; return x >> y")); assertEquals(2, exec("def x = (int)4; def y = (long)1; return x >> y")); assertEquals(2L, exec("def x = (long)4; def y = (long)1; return x >> y")); - assertEquals(2L, exec("def x = (float)4; def y = (long)1; return x >> y")); - assertEquals(2L, exec("def x = (double)4; def y = (long)1; return x >> y")); - - assertEquals(2, exec("def x = (byte)4; def y = (float)1; return x >> y")); - assertEquals(2, exec("def x = (short)4; def y = (float)1; return x >> y")); - assertEquals(2, exec("def x = (char)4; def y = (float)1; return x >> y")); - assertEquals(2, exec("def x = (int)4; def y = (float)1; return x >> y")); - assertEquals(2L, exec("def x = (long)4; def y = (float)1; return x >> y")); - assertEquals(2L, exec("def x = (float)4; def y = (float)1; return x >> y")); - assertEquals(2L, exec("def x = (double)4; def y = (float)1; return x >> y")); - - assertEquals(2, exec("def x = (byte)4; def y = (double)1; return x >> y")); - assertEquals(2, exec("def x = (short)4; def y = (double)1; return x >> y")); - assertEquals(2, exec("def x = (char)4; def y = (double)1; return x >> y")); - assertEquals(2, exec("def x = (int)4; def y = (double)1; return x >> y")); - assertEquals(2L, exec("def x = (long)4; def y = (double)1; return x >> y")); - assertEquals(2L, exec("def x = (float)4; def y = (double)1; return x >> y")); - assertEquals(2L, exec("def x = (double)4; def y = (double)1; return x >> y")); assertEquals(2, exec("def x = (byte)4; def y = (byte)1; return x >> y")); assertEquals(2, exec("def x = (short)4; def y = (short)1; return x >> y")); assertEquals(2, exec("def x = (char)4; def y = (char)1; return x >> y")); assertEquals(2, exec("def x = (int)4; def y = (int)1; return x >> y")); assertEquals(2L, exec("def x = (long)4; def y = (long)1; return x >> y")); - assertEquals(2L, exec("def x = (float)4; def y = (float)1; return x >> y")); - assertEquals(2L, exec("def x = (double)4; def y = (double)1; return x >> y")); } public void testUsh() { @@ -518,262 +495,224 @@ public class DefOperationTests extends ScriptTestCase { assertEquals(2, exec("def x = (char)4; def y = (byte)1; return x >>> y")); assertEquals(2, exec("def x = (int)4; def y = (byte)1; return x >>> y")); assertEquals(2L, exec("def x = (long)4; def y = (byte)1; return x >>> y")); - assertEquals(2L, exec("def x = (float)4; def y = (byte)1; return x >>> y")); - assertEquals(2L, exec("def x = (double)4; def y = (byte)1; return x >>> y")); assertEquals(2, exec("def x = (byte)4; def y = (short)1; return x >>> y")); assertEquals(2, exec("def x = (short)4; def y = (short)1; return x >>> y")); assertEquals(2, exec("def x = (char)4; def y = (short)1; return x >>> y")); assertEquals(2, exec("def x = (int)4; def y = (short)1; return x >>> y")); assertEquals(2L, exec("def x = (long)4; def y = (short)1; return x >>> y")); - assertEquals(2L, exec("def x = (float)4; def y = (short)1; return x >>> y")); - assertEquals(2L, exec("def x = (double)4; def y = (short)1; return x >>> y")); assertEquals(2, exec("def x = (byte)4; def y = (char)1; return x >>> y")); assertEquals(2, exec("def x = (short)4; def y = (char)1; return x >>> y")); assertEquals(2, exec("def x = (char)4; def y = (char)1; return x >>> y")); assertEquals(2, exec("def x = (int)4; def y = (char)1; return x >>> y")); assertEquals(2L, exec("def x = (long)4; def y = (char)1; return x >>> y")); - assertEquals(2L, exec("def x = (float)4; def y = (char)1; return x >>> y")); - assertEquals(2L, exec("def x = (double)4; def y = (char)1; return x >>> y")); assertEquals(2, exec("def x = (byte)4; def y = (int)1; return x >>> y")); assertEquals(2, exec("def x = (short)4; def y = (int)1; return x >>> y")); assertEquals(2, exec("def x = (char)4; def y = (int)1; return x >>> y")); assertEquals(2, exec("def x = (int)4; def y = (int)1; return x >>> y")); assertEquals(2L, exec("def x = (long)4; def y = (int)1; return x >>> y")); - assertEquals(2L, exec("def x = (float)4; def y = (int)1; return x >>> y")); - assertEquals(2L, exec("def x = (double)4; def y = (int)1; return x >>> y")); assertEquals(2, exec("def x = (byte)4; def y = (long)1; return x >>> y")); assertEquals(2, exec("def x = (short)4; def y = (long)1; return x >>> y")); assertEquals(2, exec("def x = (char)4; def y = (long)1; return x >>> y")); assertEquals(2, exec("def x = (int)4; def y = (long)1; return x >>> y")); assertEquals(2L, exec("def x = (long)4; def y = (long)1; return x >>> y")); - assertEquals(2L, exec("def x = (float)4; def y = (long)1; return x >>> y")); - assertEquals(2L, exec("def x = (double)4; def y = (long)1; return x >>> y")); - - assertEquals(2, exec("def x = (byte)4; def y = (float)1; return x >>> y")); - assertEquals(2, exec("def x = (short)4; def y = (float)1; return x >>> y")); - assertEquals(2, exec("def x = (char)4; def y = (float)1; return x >>> y")); - assertEquals(2, exec("def x = (int)4; def y = (float)1; return x >>> y")); - assertEquals(2L, exec("def x = (long)4; def y = (float)1; return x >>> y")); - assertEquals(2L, exec("def x = (float)4; def y = (float)1; return x >>> y")); - assertEquals(2L, exec("def x = (double)4; def y = (float)1; return x >>> y")); - - assertEquals(2, exec("def x = (byte)4; def y = (double)1; return x >>> y")); - assertEquals(2, exec("def x = (short)4; def y = (double)1; return x >>> y")); - assertEquals(2, exec("def x = (char)4; def y = (double)1; return x >>> y")); - assertEquals(2, exec("def x = (int)4; def y = (double)1; return x >>> y")); - assertEquals(2L, exec("def x = (long)4; def y = (double)1; return x >>> y")); - assertEquals(2L, exec("def x = (float)4; def y = (double)1; return x >>> y")); - assertEquals(2L, exec("def x = (double)4; def y = (double)1; return x >>> y")); assertEquals(2, exec("def x = (byte)4; def y = (byte)1; return x >>> y")); assertEquals(2, exec("def x = (short)4; def y = (short)1; return x >>> y")); assertEquals(2, exec("def x = (char)4; def y = (char)1; return x >>> y")); assertEquals(2, exec("def x = (int)4; def y = (int)1; return x >>> y")); assertEquals(2L, exec("def x = (long)4; def y = (long)1; return x >>> y")); - assertEquals(2L, exec("def x = (float)4; def y = (float)1; return x >>> y")); - assertEquals(2L, exec("def x = (double)4; def y = (double)1; return x >>> y")); + } + + public void testBogusShifts() { + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1L; def y = 2F; return x << y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1; def y = 2D; return x << y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1F; def y = 2; return x << y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1D; def y = 2L; return x << y;"); + }); + + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1L; def y = 2F; return x >> y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1; def y = 2D; return x >> y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1F; def y = 2; return x >> y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1D; def y = 2L; return x >> y;"); + }); + + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1L; def y = 2F; return x >>> y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1; def y = 2D; return x >>> y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1F; def y = 2; return x >>> y;"); + }); + expectScriptThrows(ClassCastException.class, ()-> { + exec("def x = 1D; def y = 2L; return x >>> y;"); + }); } public void testAnd() { + expectScriptThrows(ClassCastException.class, () -> { + exec("def x = (float)4; def y = (byte)1; return x & y"); + }); + expectScriptThrows(ClassCastException.class, () -> { + exec("def x = (double)4; def y = (byte)1; return x & y"); + }); assertEquals(0, exec("def x = (byte)4; def y = (byte)1; return x & y")); assertEquals(0, exec("def x = (short)4; def y = (byte)1; return x & y")); assertEquals(0, exec("def x = (char)4; def y = (byte)1; return x & y")); assertEquals(0, exec("def x = (int)4; def y = (byte)1; return x & y")); assertEquals(0L, exec("def x = (long)4; def y = (byte)1; return x & y")); - assertEquals(0L, exec("def x = (float)4; def y = (byte)1; return x & y")); - assertEquals(0L, exec("def x = (double)4; def y = (byte)1; return x & y")); assertEquals(0, exec("def x = (byte)4; def y = (short)1; return x & y")); assertEquals(0, exec("def x = (short)4; def y = (short)1; return x & y")); assertEquals(0, exec("def x = (char)4; def y = (short)1; return x & y")); assertEquals(0, exec("def x = (int)4; def y = (short)1; return x & y")); assertEquals(0L, exec("def x = (long)4; def y = (short)1; return x & y")); - assertEquals(0L, exec("def x = (float)4; def y = (short)1; return x & y")); - assertEquals(0L, exec("def x = (double)4; def y = (short)1; return x & y")); assertEquals(0, exec("def x = (byte)4; def y = (char)1; return x & y")); assertEquals(0, exec("def x = (short)4; def y = (char)1; return x & y")); assertEquals(0, exec("def x = (char)4; def y = (char)1; return x & y")); assertEquals(0, exec("def x = (int)4; def y = (char)1; return x & y")); assertEquals(0L, exec("def x = (long)4; def y = (char)1; return x & y")); - assertEquals(0L, exec("def x = (float)4; def y = (char)1; return x & y")); - assertEquals(0L, exec("def x = (double)4; def y = (char)1; return x & y")); assertEquals(0, exec("def x = (byte)4; def y = (int)1; return x & y")); assertEquals(0, exec("def x = (short)4; def y = (int)1; return x & y")); assertEquals(0, exec("def x = (char)4; def y = (int)1; return x & y")); assertEquals(0, exec("def x = (int)4; def y = (int)1; return x & y")); assertEquals(0L, exec("def x = (long)4; def y = (int)1; return x & y")); - assertEquals(0L, exec("def x = (float)4; def y = (int)1; return x & y")); - assertEquals(0L, exec("def x = (double)4; def y = (int)1; return x & y")); assertEquals(0L, exec("def x = (byte)4; def y = (long)1; return x & y")); assertEquals(0L, exec("def x = (short)4; def y = (long)1; return x & y")); assertEquals(0L, exec("def x = (char)4; def y = (long)1; return x & y")); assertEquals(0L, exec("def x = (int)4; def y = (long)1; return x & y")); assertEquals(0L, exec("def x = (long)4; def y = (long)1; return x & y")); - assertEquals(0L, exec("def x = (float)4; def y = (long)1; return x & y")); - assertEquals(0L, exec("def x = (double)4; def y = (long)1; return x & y")); - - assertEquals(0L, exec("def x = (byte)4; def y = (float)1; return x & y")); - assertEquals(0L, exec("def x = (short)4; def y = (float)1; return x & y")); - assertEquals(0L, exec("def x = (char)4; def y = (float)1; return x & y")); - assertEquals(0L, exec("def x = (int)4; def y = (float)1; return x & y")); - assertEquals(0L, exec("def x = (long)4; def y = (float)1; return x & y")); - assertEquals(0L, exec("def x = (float)4; def y = (float)1; return x & y")); - assertEquals(0L, exec("def x = (double)4; def y = (float)1; return x & y")); - - assertEquals(0L, exec("def x = (byte)4; def y = (double)1; return x & y")); - assertEquals(0L, exec("def x = (short)4; def y = (double)1; return x & y")); - assertEquals(0L, exec("def x = (char)4; def y = (double)1; return x & y")); - assertEquals(0L, exec("def x = (int)4; def y = (double)1; return x & y")); - assertEquals(0L, exec("def x = (long)4; def y = (double)1; return x & y")); - assertEquals(0L, exec("def x = (float)4; def y = (double)1; return x & y")); - assertEquals(0L, exec("def x = (double)4; def y = (double)1; return x & y")); assertEquals(0, exec("def x = (byte)4; def y = (byte)1; return x & y")); assertEquals(0, exec("def x = (short)4; def y = (short)1; return x & y")); assertEquals(0, exec("def x = (char)4; def y = (char)1; return x & y")); assertEquals(0, exec("def x = (int)4; def y = (int)1; return x & y")); assertEquals(0L, exec("def x = (long)4; def y = (long)1; return x & y")); - assertEquals(0L, exec("def x = (float)4; def y = (float)1; return x & y")); - assertEquals(0L, exec("def x = (double)4; def y = (double)1; return x & y")); + + assertEquals(true, exec("def x = true; def y = true; return x & y")); + assertEquals(false, exec("def x = true; def y = false; return x & y")); + assertEquals(false, exec("def x = false; def y = true; return x & y")); + assertEquals(false, exec("def x = false; def y = false; return x & y")); } public void testXor() { + expectScriptThrows(ClassCastException.class, () -> { + exec("def x = (float)4; def y = (byte)1; return x ^ y"); + }); + expectScriptThrows(ClassCastException.class, () -> { + exec("def x = (double)4; def y = (byte)1; return x ^ y"); + }); assertEquals(5, exec("def x = (byte)4; def y = (byte)1; return x ^ y")); assertEquals(5, exec("def x = (short)4; def y = (byte)1; return x ^ y")); assertEquals(5, exec("def x = (char)4; def y = (byte)1; return x ^ y")); assertEquals(5, exec("def x = (int)4; def y = (byte)1; return x ^ y")); assertEquals(5L, exec("def x = (long)4; def y = (byte)1; return x ^ y")); - assertEquals(5L, exec("def x = (float)4; def y = (byte)1; return x ^ y")); - assertEquals(5L, exec("def x = (double)4; def y = (byte)1; return x ^ y")); assertEquals(5, exec("def x = (byte)4; def y = (short)1; return x ^ y")); assertEquals(5, exec("def x = (short)4; def y = (short)1; return x ^ y")); assertEquals(5, exec("def x = (char)4; def y = (short)1; return x ^ y")); assertEquals(5, exec("def x = (int)4; def y = (short)1; return x ^ y")); assertEquals(5L, exec("def x = (long)4; def y = (short)1; return x ^ y")); - assertEquals(5L, exec("def x = (float)4; def y = (short)1; return x ^ y")); - assertEquals(5L, exec("def x = (double)4; def y = (short)1; return x ^ y")); assertEquals(5, exec("def x = (byte)4; def y = (char)1; return x ^ y")); assertEquals(5, exec("def x = (short)4; def y = (char)1; return x ^ y")); assertEquals(5, exec("def x = (char)4; def y = (char)1; return x ^ y")); assertEquals(5, exec("def x = (int)4; def y = (char)1; return x ^ y")); assertEquals(5L, exec("def x = (long)4; def y = (char)1; return x ^ y")); - assertEquals(5L, exec("def x = (float)4; def y = (char)1; return x ^ y")); - assertEquals(5L, exec("def x = (double)4; def y = (char)1; return x ^ y")); assertEquals(5, exec("def x = (byte)4; def y = (int)1; return x ^ y")); assertEquals(5, exec("def x = (short)4; def y = (int)1; return x ^ y")); assertEquals(5, exec("def x = (char)4; def y = (int)1; return x ^ y")); assertEquals(5, exec("def x = (int)4; def y = (int)1; return x ^ y")); assertEquals(5L, exec("def x = (long)4; def y = (int)1; return x ^ y")); - assertEquals(5L, exec("def x = (float)4; def y = (int)1; return x ^ y")); - assertEquals(5L, exec("def x = (double)4; def y = (int)1; return x ^ y")); assertEquals(5L, exec("def x = (byte)4; def y = (long)1; return x ^ y")); assertEquals(5L, exec("def x = (short)4; def y = (long)1; return x ^ y")); assertEquals(5L, exec("def x = (char)4; def y = (long)1; return x ^ y")); assertEquals(5L, exec("def x = (int)4; def y = (long)1; return x ^ y")); assertEquals(5L, exec("def x = (long)4; def y = (long)1; return x ^ y")); - assertEquals(5L, exec("def x = (float)4; def y = (long)1; return x ^ y")); - assertEquals(5L, exec("def x = (double)4; def y = (long)1; return x ^ y")); - - assertEquals(5L, exec("def x = (byte)4; def y = (float)1; return x ^ y")); - assertEquals(5L, exec("def x = (short)4; def y = (float)1; return x ^ y")); - assertEquals(5L, exec("def x = (char)4; def y = (float)1; return x ^ y")); - assertEquals(5L, exec("def x = (int)4; def y = (float)1; return x ^ y")); - assertEquals(5L, exec("def x = (long)4; def y = (float)1; return x ^ y")); - assertEquals(5L, exec("def x = (float)4; def y = (float)1; return x ^ y")); - assertEquals(5L, exec("def x = (double)4; def y = (float)1; return x ^ y")); - - assertEquals(5L, exec("def x = (byte)4; def y = (double)1; return x ^ y")); - assertEquals(5L, exec("def x = (short)4; def y = (double)1; return x ^ y")); - assertEquals(5L, exec("def x = (char)4; def y = (double)1; return x ^ y")); - assertEquals(5L, exec("def x = (int)4; def y = (double)1; return x ^ y")); - assertEquals(5L, exec("def x = (long)4; def y = (double)1; return x ^ y")); - assertEquals(5L, exec("def x = (float)4; def y = (double)1; return x ^ y")); - assertEquals(5L, exec("def x = (double)4; def y = (double)1; return x ^ y")); assertEquals(5, exec("def x = (byte)4; def y = (byte)1; return x ^ y")); assertEquals(5, exec("def x = (short)4; def y = (short)1; return x ^ y")); assertEquals(5, exec("def x = (char)4; def y = (char)1; return x ^ y")); assertEquals(5, exec("def x = (int)4; def y = (int)1; return x ^ y")); assertEquals(5L, exec("def x = (long)4; def y = (long)1; return x ^ y")); - assertEquals(5L, exec("def x = (float)4; def y = (float)1; return x ^ y")); - assertEquals(5L, exec("def x = (double)4; def y = (double)1; return x ^ y")); + + assertEquals(false, exec("def x = true; def y = true; return x ^ y")); + assertEquals(true, exec("def x = true; def y = false; return x ^ y")); + assertEquals(true, exec("def x = false; def y = true; return x ^ y")); + assertEquals(false, exec("def x = false; def y = false; return x ^ y")); } public void testOr() { + expectScriptThrows(ClassCastException.class, () -> { + exec("def x = (float)4; def y = (byte)1; return x | y"); + }); + expectScriptThrows(ClassCastException.class, () -> { + exec("def x = (double)4; def y = (byte)1; return x | y"); + }); assertEquals(5, exec("def x = (byte)4; def y = (byte)1; return x | y")); assertEquals(5, exec("def x = (short)4; def y = (byte)1; return x | y")); assertEquals(5, exec("def x = (char)4; def y = (byte)1; return x | y")); assertEquals(5, exec("def x = (int)4; def y = (byte)1; return x | y")); assertEquals(5L, exec("def x = (long)4; def y = (byte)1; return x | y")); - assertEquals(5L, exec("def x = (float)4; def y = (byte)1; return x | y")); - assertEquals(5L, exec("def x = (double)4; def y = (byte)1; return x | y")); assertEquals(5, exec("def x = (byte)4; def y = (short)1; return x | y")); assertEquals(5, exec("def x = (short)4; def y = (short)1; return x | y")); assertEquals(5, exec("def x = (char)4; def y = (short)1; return x | y")); assertEquals(5, exec("def x = (int)4; def y = (short)1; return x | y")); assertEquals(5L, exec("def x = (long)4; def y = (short)1; return x | y")); - assertEquals(5L, exec("def x = (float)4; def y = (short)1; return x | y")); - assertEquals(5L, exec("def x = (double)4; def y = (short)1; return x | y")); assertEquals(5, exec("def x = (byte)4; def y = (char)1; return x | y")); assertEquals(5, exec("def x = (short)4; def y = (char)1; return x | y")); assertEquals(5, exec("def x = (char)4; def y = (char)1; return x | y")); assertEquals(5, exec("def x = (int)4; def y = (char)1; return x | y")); assertEquals(5L, exec("def x = (long)4; def y = (char)1; return x | y")); - assertEquals(5L, exec("def x = (float)4; def y = (char)1; return x | y")); - assertEquals(5L, exec("def x = (double)4; def y = (char)1; return x | y")); assertEquals(5, exec("def x = (byte)4; def y = (int)1; return x | y")); assertEquals(5, exec("def x = (short)4; def y = (int)1; return x | y")); assertEquals(5, exec("def x = (char)4; def y = (int)1; return x | y")); assertEquals(5, exec("def x = (int)4; def y = (int)1; return x | y")); assertEquals(5L, exec("def x = (long)4; def y = (int)1; return x | y")); - assertEquals(5L, exec("def x = (float)4; def y = (int)1; return x | y")); - assertEquals(5L, exec("def x = (double)4; def y = (int)1; return x | y")); assertEquals(5L, exec("def x = (byte)4; def y = (long)1; return x | y")); assertEquals(5L, exec("def x = (short)4; def y = (long)1; return x | y")); assertEquals(5L, exec("def x = (char)4; def y = (long)1; return x | y")); assertEquals(5L, exec("def x = (int)4; def y = (long)1; return x | y")); assertEquals(5L, exec("def x = (long)4; def y = (long)1; return x | y")); - assertEquals(5L, exec("def x = (float)4; def y = (long)1; return x | y")); - assertEquals(5L, exec("def x = (double)4; def y = (long)1; return x | y")); - - assertEquals(5L, exec("def x = (byte)4; def y = (float)1; return x | y")); - assertEquals(5L, exec("def x = (short)4; def y = (float)1; return x | y")); - assertEquals(5L, exec("def x = (char)4; def y = (float)1; return x | y")); - assertEquals(5L, exec("def x = (int)4; def y = (float)1; return x | y")); - assertEquals(5L, exec("def x = (long)4; def y = (float)1; return x | y")); - assertEquals(5L, exec("def x = (float)4; def y = (float)1; return x | y")); - assertEquals(5L, exec("def x = (double)4; def y = (float)1; return x | y")); - - assertEquals(5L, exec("def x = (byte)4; def y = (double)1; return x | y")); - assertEquals(5L, exec("def x = (short)4; def y = (double)1; return x | y")); - assertEquals(5L, exec("def x = (char)4; def y = (double)1; return x | y")); - assertEquals(5L, exec("def x = (int)4; def y = (double)1; return x | y")); - assertEquals(5L, exec("def x = (long)4; def y = (double)1; return x | y")); - assertEquals(5L, exec("def x = (float)4; def y = (double)1; return x | y")); - assertEquals(5L, exec("def x = (double)4; def y = (double)1; return x | y")); assertEquals(5, exec("def x = (byte)4; def y = (byte)1; return x | y")); assertEquals(5, exec("def x = (short)4; def y = (short)1; return x | y")); assertEquals(5, exec("def x = (char)4; def y = (char)1; return x | y")); assertEquals(5, exec("def x = (int)4; def y = (int)1; return x | y")); assertEquals(5L, exec("def x = (long)4; def y = (long)1; return x | y")); - assertEquals(5L, exec("def x = (float)4; def y = (float)1; return x | y")); - assertEquals(5L, exec("def x = (double)4; def y = (double)1; return x | y")); + + assertEquals(true, exec("def x = true; def y = true; return x | y")); + assertEquals(true, exec("def x = true; def y = false; return x | y")); + assertEquals(true, exec("def x = false; def y = true; return x | y")); + assertEquals(false, exec("def x = false; def y = false; return x | y")); } public void testEq() { @@ -792,11 +731,23 @@ public class DefOperationTests extends ScriptTestCase { assertEquals(false, exec("def x = (long)5; def y = (double)3; return x == y")); assertEquals(false, exec("def x = (float)6; def y = (double)2; return x == y")); assertEquals(false, exec("def x = (double)7; def y = (double)1; return x == y")); + + assertEquals(false, exec("def x = false; def y = true; return x == y")); + assertEquals(false, exec("def x = true; def y = false; return x == y")); + assertEquals(false, exec("def x = true; def y = null; return x == y")); + assertEquals(false, exec("def x = null; def y = true; return x == y")); + assertEquals(true, exec("def x = true; def y = true; return x == y")); + assertEquals(true, exec("def x = false; def y = false; return x == y")); assertEquals(true, exec("def x = new HashMap(); def y = new HashMap(); return x == y")); assertEquals(false, exec("def x = new HashMap(); x.put(3, 3); def y = new HashMap(); return x == y")); assertEquals(true, exec("def x = new HashMap(); x.put(3, 3); def y = new HashMap(); y.put(3, 3); return x == y")); assertEquals(true, exec("def x = new HashMap(); def y = x; x.put(3, 3); y.put(3, 3); return x == y")); + + assertEquals(true, exec("def x = true; def y = true; return x == y")); + assertEquals(false, exec("def x = true; def y = false; return x == y")); + assertEquals(false, exec("def x = false; def y = true; return x == y")); + assertEquals(true, exec("def x = false; def y = false; return x == y")); } public void testEqr() { @@ -807,6 +758,7 @@ public class DefOperationTests extends ScriptTestCase { assertEquals(false, exec("def x = (long)5; def y = (int)3; return x === y")); assertEquals(false, exec("def x = (float)6; def y = (int)2; return x === y")); assertEquals(false, exec("def x = (double)7; def y = (int)1; return x === y")); + assertEquals(false, exec("def x = false; def y = true; return x === y")); assertEquals(false, exec("def x = new HashMap(); def y = new HashMap(); return x === y")); assertEquals(false, exec("def x = new HashMap(); x.put(3, 3); def y = new HashMap(); return x === y")); @@ -835,6 +787,11 @@ public class DefOperationTests extends ScriptTestCase { assertEquals(true, exec("def x = new HashMap(); x.put(3, 3); def y = new HashMap(); return x != y")); assertEquals(false, exec("def x = new HashMap(); x.put(3, 3); def y = new HashMap(); y.put(3, 3); return x != y")); assertEquals(false, exec("def x = new HashMap(); def y = x; x.put(3, 3); y.put(3, 3); return x != y")); + + assertEquals(false, exec("def x = true; def y = true; return x != y")); + assertEquals(true, exec("def x = true; def y = false; return x != y")); + assertEquals(true, exec("def x = false; def y = true; return x != y")); + assertEquals(false, exec("def x = false; def y = false; return x != y")); } public void testNer() { diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/OrTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/OrTests.java index f287b1e4cf4..d27c019a798 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/OrTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/OrTests.java @@ -45,4 +45,13 @@ public class OrTests extends ScriptTestCase { assertEquals(5L | -12L, exec("return 5L | -12L;")); assertEquals(7L | 15L | 3L, exec("return 7L | 15L | 3L;")); } + + public void testIllegal() throws Exception { + expectScriptThrows(ClassCastException.class, () -> { + exec("float x = (float)4; int y = 1; return x | y"); + }); + expectScriptThrows(ClassCastException.class, () -> { + exec("double x = (double)4; int y = 1; return x | y"); + }); + } } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/UnaryTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/UnaryTests.java index e670e23925b..ad2876778db 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/UnaryTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/UnaryTests.java @@ -39,4 +39,14 @@ public class UnaryTests extends ScriptTestCase { assertEquals(1, exec("return -(-1);")); assertEquals(0, exec("return -0;")); } + + public void testPlus() { + assertEquals(-1, exec("byte x = (byte)-1; return +x")); + assertEquals(-1, exec("short x = (short)-1; return +x")); + assertEquals(65535, exec("char x = (char)-1; return +x")); + assertEquals(-1, exec("int x = -1; return +x")); + assertEquals(-1L, exec("long x = -1L; return +x")); + assertEquals(-1.0F, exec("float x = -1F; return +x")); + assertEquals(-1.0, exec("double x = -1.0; return +x")); + } } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/XorTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/XorTests.java index f5dd0a92011..16436240459 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/XorTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/XorTests.java @@ -59,4 +59,13 @@ public class XorTests extends ScriptTestCase { assertEquals(true, exec("return false ^ true;")); assertEquals(false, exec("return false ^ false;")); } + + public void testIllegal() throws Exception { + expectScriptThrows(ClassCastException.class, () -> { + exec("float x = (float)4; int y = 1; return x ^ y"); + }); + expectScriptThrows(ClassCastException.class, () -> { + exec("double x = (double)4; int y = 1; return x ^ y"); + }); + } }