From 80734c75b5f96fed862eb58c780bac861af74bee Mon Sep 17 00:00:00 2001 From: Robert Muir Date: Tue, 21 Jun 2016 08:35:12 -0400 Subject: [PATCH 1/4] get things started --- .../elasticsearch/painless/Augmentation.java | 126 ++++++++++++++++++ .../elasticsearch/painless/Definition.java | 126 ++++++++++++------ .../painless/node/LCallInvoke.java | 10 +- .../painless/node/LListShortcut.java | 12 +- .../painless/node/LMapShortcut.java | 12 +- .../painless/node/LShortcut.java | 12 +- .../elasticsearch/painless/node/SEach.java | 4 +- .../painless/node/SFunction.java | 2 +- .../org/elasticsearch/painless/java.lang.txt | 2 + .../painless/java.util.regex.txt | 2 +- .../org/elasticsearch/painless/java.util.txt | 3 +- .../painless/AugmentationTests.java | 35 +++++ 12 files changed, 262 insertions(+), 84 deletions(-) create mode 100644 modules/lang-painless/src/main/java/org/elasticsearch/painless/Augmentation.java create mode 100644 modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Augmentation.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Augmentation.java new file mode 100644 index 00000000000..456642eb8c5 --- /dev/null +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Augmentation.java @@ -0,0 +1,126 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.painless; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.ObjIntConsumer; +import java.util.function.Predicate; +import java.util.function.ToDoubleFunction; +import java.util.regex.Matcher; + +public class Augmentation { + public static int getLength(List receiver) { + return receiver.size(); + } + + public static String namedGroup(Matcher receiver, String name) { + return receiver.group(name); + } + + public static boolean any(Iterable receiver, Predicate predicate) { + for (T t : receiver) { + if (predicate.test(t)) { + return true; + } + } + return false; + } + + public static int count(Iterable receiver, Predicate predicate) { + int count = 0; + for (T t : receiver) { + if (predicate.test(t)) { + count++; + } + } + return count; + } + + public static Iterable each(Iterable receiver, Consumer consumer) { + receiver.forEach(consumer); + return receiver; + } + + public static Iterable eachWithIndex(Iterable receiver, ObjIntConsumer consumer) { + int count = 0; + for (T t : receiver) { + consumer.accept(t, count++); + } + return receiver; + } + + public static boolean every(Iterable receiver, Predicate predicate) { + for (T t : receiver) { + if (predicate.test(t) == false) { + return false; + } + } + return true; + } + + public static List findResults(Iterable receiver, Function filter) { + List list = new ArrayList<>(); + for (T t: receiver) { + U result = filter.apply(t); + if (result != null) { + list.add(result); + } + } + return list; + } + + public static Map> groupBy(Iterable receiver, Function mapper) { + Map> map = new LinkedHashMap<>(); + for (T t : receiver) { + U mapped = mapper.apply(t); + List results = map.get(mapped); + if (results == null) { + results = new ArrayList<>(); + map.put(mapped, results); + } + results.add(t); + } + return map; + } + + public static String join(Iterable receiver, String separator) { + StringBuilder sb = new StringBuilder(); + for (T t : receiver) { + if (sb.length() > 0) { + sb.append(separator); + } + sb.append(t); + } + return sb.toString(); + } + + public static double sum(Iterable receiver, ToDoubleFunction function) { + double sum = 0; + for (T t : receiver) { + sum += function.applyAsDouble(t); + } + return sum; + } +} diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java index 761a2afeeb1..376be44e28b 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java @@ -186,15 +186,17 @@ public final class Definition { public static class Method { public final String name; public final Struct owner; + public final boolean augmentation; public final Type rtn; public final List arguments; public final org.objectweb.asm.commons.Method method; public final int modifiers; public final MethodHandle handle; - public Method(String name, Struct owner, Type rtn, List arguments, + public Method(String name, Struct owner, boolean augmentation, Type rtn, List arguments, org.objectweb.asm.commons.Method method, int modifiers, MethodHandle handle) { this.name = name; + this.augmentation = augmentation; this.owner = owner; this.rtn = rtn; this.arguments = Collections.unmodifiableList(arguments); @@ -217,7 +219,15 @@ public final class Definition { // otherwise compute it final Class params[]; final Class returnValue; - if (Modifier.isStatic(modifiers)) { + if (augmentation) { + // virtual/interface method disguised as static + params = new Class[1 + arguments.size()]; + params[0] = Augmentation.class; + for (int i = 0; i < arguments.size(); i++) { + params[i + 1] = arguments.get(i).clazz; + } + returnValue = rtn.clazz; + } else if (Modifier.isStatic(modifiers)) { // static method: straightforward copy params = new Class[arguments.size()]; for (int i = 0; i < arguments.size(); i++) { @@ -242,6 +252,24 @@ public final class Definition { } return MethodType.methodType(returnValue, params); } + + public void write(MethodWriter writer) { + final org.objectweb.asm.Type type; + if (augmentation) { + assert java.lang.reflect.Modifier.isStatic(modifiers); + type = org.objectweb.asm.Type.getType(Augmentation.class); + } else { + type = owner.type; + } + + if (java.lang.reflect.Modifier.isStatic(modifiers)) { + writer.invokeStatic(type, method); + } else if (java.lang.reflect.Modifier.isInterface(owner.clazz.getModifiers())) { + writer.invokeInterface(type, method); + } else { + writer.invokeVirtual(type, method); + } + } } public static final class Field { @@ -690,7 +718,7 @@ public final class Definition { " with arguments " + Arrays.toString(classes) + "."); } - final Method constructor = new Method(name, owner, returnType, Arrays.asList(args), asm, reflect.getModifiers(), handle); + final Method constructor = new Method(name, owner, false, returnType, Arrays.asList(args), asm, reflect.getModifiers(), handle); owner.constructors.put(methodKey, constructor); } @@ -734,24 +762,20 @@ public final class Definition { } addConstructorInternal(className, "", args); } else { - if (methodName.indexOf('/') >= 0) { - String nameAndAlias[] = methodName.split("/"); - if (nameAndAlias.length != 2) { - throw new IllegalArgumentException("Currently only two aliases are allowed!"); - } - addMethodInternal(className, nameAndAlias[0], nameAndAlias[1], rtn, args); + if (methodName.indexOf("*") >= 0) { + addMethodInternal(className, methodName.substring(0, methodName.length() - 1), true, rtn, args); } else { - addMethodInternal(className, methodName, null, rtn, args); + addMethodInternal(className, methodName, false, rtn, args); } } } else { // field - addFieldInternal(className, elements[1], null, rtn); + addFieldInternal(className, elements[1], rtn); } } - private final void addMethodInternal(final String struct, final String name, final String alias, - final Type rtn, final Type[] args) { + private final void addMethodInternal(String struct, String name, boolean augmentation, + Type rtn, Type[] args) { final Struct owner = structsMap.get(struct); if (owner == null) { @@ -777,20 +801,32 @@ public final class Definition { "Duplicate method signature [" + methodKey + "] found within the struct [" + owner.name + "]."); } - final Class[] classes = new Class[args.length]; - - for (int count = 0; count < classes.length; ++count) { - classes[count] = args[count].clazz; + final Class implClass; + final Class[] params; + + if (augmentation == false) { + implClass = owner.clazz; + params = new Class[args.length]; + for (int count = 0; count < args.length; ++count) { + params[count] = args[count].clazz; + } + } else { + implClass = Augmentation.class; + params = new Class[args.length + 1]; + params[0] = owner.clazz; + for (int count = 0; count < args.length; ++count) { + params[count+1] = args[count].clazz; + } } - + final java.lang.reflect.Method reflect; try { - reflect = owner.clazz.getMethod(alias == null ? name : alias, classes); - } catch (final NoSuchMethodException exception) { - throw new IllegalArgumentException("Method [" + (alias == null ? name : alias) + + reflect = implClass.getMethod(name, params); + } catch (NoSuchMethodException exception) { + throw new IllegalArgumentException("Method [" + name + "] not found for class [" + owner.clazz.getName() + "]" + - " with arguments " + Arrays.toString(classes) + "."); + " with arguments " + Arrays.toString(params) + "."); } if (!reflect.getReturnType().equals(rtn.clazz)) { @@ -805,25 +841,24 @@ public final class Definition { MethodHandle handle; try { - handle = MethodHandles.publicLookup().in(owner.clazz).unreflect(reflect); + handle = MethodHandles.publicLookup().in(implClass).unreflect(reflect); } catch (final IllegalAccessException exception) { - throw new IllegalArgumentException("Method [" + (alias == null ? name : alias) + "]" + - " not found for class [" + owner.clazz.getName() + "]" + - " with arguments " + Arrays.toString(classes) + "."); + throw new IllegalArgumentException("Method [" + name + "]" + + " not found for class [" + implClass.getName() + "]" + + " with arguments " + Arrays.toString(params) + "."); } final int modifiers = reflect.getModifiers(); - final Method method = new Method(name, owner, rtn, Arrays.asList(args), asm, modifiers, handle); + final Method method = new Method(name, owner, augmentation, rtn, Arrays.asList(args), asm, modifiers, handle); - if (java.lang.reflect.Modifier.isStatic(modifiers)) { + if (augmentation == false && java.lang.reflect.Modifier.isStatic(modifiers)) { owner.staticMethods.put(methodKey, method); } else { owner.methods.put(methodKey, method); } } - private final void addFieldInternal(final String struct, final String name, final String alias, - final Type type) { + private final void addFieldInternal(String struct, String name, Type type) { final Struct owner = structsMap.get(struct); if (owner == null) { @@ -844,9 +879,9 @@ public final class Definition { java.lang.reflect.Field reflect; try { - reflect = owner.clazz.getField(alias == null ? name : alias); + reflect = owner.clazz.getField(name); } catch (final NoSuchFieldException exception) { - throw new IllegalArgumentException("Field [" + (alias == null ? name : alias) + "]" + + throw new IllegalArgumentException("Field [" + name + "]" + " not found for class [" + owner.clazz.getName() + "]."); } @@ -862,7 +897,7 @@ public final class Definition { setter = MethodHandles.publicLookup().unreflectSetter(reflect); } } catch (final IllegalAccessException exception) { - throw new IllegalArgumentException("Getter/Setter [" + (alias == null ? name : alias) + "]" + + throw new IllegalArgumentException("Getter/Setter [" + name + "]" + " not found for class [" + owner.clazz.getName() + "]."); } @@ -875,9 +910,9 @@ public final class Definition { " within the struct [" + owner.name + "] is not final."); } - owner.staticMembers.put(alias == null ? name : alias, field); + owner.staticMembers.put(name, field); } else { - owner.members.put(alias == null ? name : alias, field); + owner.members.put(name, field); } } @@ -915,11 +950,24 @@ public final class Definition { // https://bugs.openjdk.java.net/browse/JDK-8072746 } else { try { - Class arguments[] = new Class[method.arguments.size()]; - for (int i = 0; i < method.arguments.size(); i++) { - arguments[i] = method.arguments.get(i).clazz; + // TODO: we *have* to remove all these public members and use getter methods to encapsulate! + final Class impl; + final Class arguments[]; + if (method.augmentation) { + impl = Augmentation.class; + arguments = new Class[method.arguments.size() + 1]; + arguments[0] = method.owner.clazz; + for (int i = 0; i < method.arguments.size(); i++) { + arguments[i + 1] = method.arguments.get(i).clazz; + } + } else { + impl = owner.clazz; + arguments = new Class[method.arguments.size()]; + for (int i = 0; i < method.arguments.size(); i++) { + arguments[i] = method.arguments.get(i).clazz; + } } - java.lang.reflect.Method m = owner.clazz.getMethod(method.method.getName(), arguments); + java.lang.reflect.Method m = impl.getMethod(method.method.getName(), arguments); if (m.getReturnType() != method.rtn.clazz) { throw new IllegalStateException("missing covariant override for: " + m + " in " + owner.name); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LCallInvoke.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LCallInvoke.java index 21bfff65e8d..1056af2aaca 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LCallInvoke.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LCallInvoke.java @@ -122,14 +122,8 @@ public final class LCallInvoke extends ALink { for (AExpression argument : arguments) { argument.write(writer, globals); } - - if (java.lang.reflect.Modifier.isStatic(method.modifiers)) { - writer.invokeStatic(method.owner.type, method.method); - } else if (java.lang.reflect.Modifier.isInterface(method.owner.clazz.getModifiers())) { - writer.invokeInterface(method.owner.type, method.method); - } else { - writer.invokeVirtual(method.owner.type, method.method); - } + + method.write(writer); } @Override diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LListShortcut.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LListShortcut.java index f2863ce3396..6ef8aedb0bf 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LListShortcut.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LListShortcut.java @@ -92,11 +92,7 @@ final class LListShortcut extends ALink { void load(MethodWriter writer, Globals globals) { writer.writeDebugInfo(location); - if (java.lang.reflect.Modifier.isInterface(getter.owner.clazz.getModifiers())) { - writer.invokeInterface(getter.owner.type, getter.method); - } else { - writer.invokeVirtual(getter.owner.type, getter.method); - } + getter.write(writer); if (!getter.rtn.clazz.equals(getter.handle.type().returnType())) { writer.checkCast(getter.rtn.type); @@ -107,11 +103,7 @@ final class LListShortcut extends ALink { void store(MethodWriter writer, Globals globals) { writer.writeDebugInfo(location); - if (java.lang.reflect.Modifier.isInterface(setter.owner.clazz.getModifiers())) { - writer.invokeInterface(setter.owner.type, setter.method); - } else { - writer.invokeVirtual(setter.owner.type, setter.method); - } + setter.write(writer); writer.writePop(setter.rtn.sort.size); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LMapShortcut.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LMapShortcut.java index 3bc9ab57a37..52d66b0fe75 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LMapShortcut.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LMapShortcut.java @@ -91,11 +91,7 @@ final class LMapShortcut extends ALink { void load(MethodWriter writer, Globals globals) { writer.writeDebugInfo(location); - if (java.lang.reflect.Modifier.isInterface(getter.owner.clazz.getModifiers())) { - writer.invokeInterface(getter.owner.type, getter.method); - } else { - writer.invokeVirtual(getter.owner.type, getter.method); - } + getter.write(writer); if (!getter.rtn.clazz.equals(getter.handle.type().returnType())) { writer.checkCast(getter.rtn.type); @@ -106,11 +102,7 @@ final class LMapShortcut extends ALink { void store(MethodWriter writer, Globals globals) { writer.writeDebugInfo(location); - if (java.lang.reflect.Modifier.isInterface(setter.owner.clazz.getModifiers())) { - writer.invokeInterface(setter.owner.type, setter.method); - } else { - writer.invokeVirtual(setter.owner.type, setter.method); - } + setter.write(writer); writer.writePop(setter.rtn.sort.size); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LShortcut.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LShortcut.java index 7f97b446382..6d669adc658 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LShortcut.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LShortcut.java @@ -95,11 +95,7 @@ final class LShortcut extends ALink { void load(MethodWriter writer, Globals globals) { writer.writeDebugInfo(location); - if (java.lang.reflect.Modifier.isInterface(getter.owner.clazz.getModifiers())) { - writer.invokeInterface(getter.owner.type, getter.method); - } else { - writer.invokeVirtual(getter.owner.type, getter.method); - } + getter.write(writer); if (!getter.rtn.clazz.equals(getter.handle.type().returnType())) { writer.checkCast(getter.rtn.type); @@ -110,11 +106,7 @@ final class LShortcut extends ALink { void store(MethodWriter writer, Globals globals) { writer.writeDebugInfo(location); - if (java.lang.reflect.Modifier.isInterface(setter.owner.clazz.getModifiers())) { - writer.invokeInterface(setter.owner.type, setter.method); - } else { - writer.invokeVirtual(setter.owner.type, setter.method); - } + setter.write(writer); writer.writePop(setter.rtn.sort.size); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java index a90baac3203..772c4af4c48 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java @@ -206,10 +206,8 @@ public class SEach extends AStatement { Type itr = Definition.getType("Iterator"); org.objectweb.asm.Type methodType = org.objectweb.asm.Type.getMethodType(itr.type, Definition.DEF_TYPE.type); writer.invokeDefCall("iterator", methodType, DefBootstrap.ITERATOR); - } else if (java.lang.reflect.Modifier.isInterface(method.owner.clazz.getModifiers())) { - writer.invokeInterface(method.owner.type, method.method); } else { - writer.invokeVirtual(method.owner.type, method.method); + method.write(writer); } writer.visitVarInsn(iterator.type.type.getOpcode(Opcodes.ISTORE), iterator.getSlot()); diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFunction.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFunction.java index ba078d03dbd..46dc8af5ab9 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFunction.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFunction.java @@ -114,7 +114,7 @@ public class SFunction extends AStatement { org.objectweb.asm.commons.Method method = new org.objectweb.asm.commons.Method(name, MethodType.methodType(rtnType.clazz, paramClasses).toMethodDescriptorString()); - this.method = new Method(name, null, rtnType, paramTypes, method, Modifier.STATIC | Modifier.PRIVATE, null); + this.method = new Method(name, null, false, rtnType, paramTypes, method, Modifier.STATIC | Modifier.PRIVATE, null); } @Override diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.lang.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.lang.txt index ec1a46e633f..859a7b314dd 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.lang.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.lang.txt @@ -50,6 +50,8 @@ class Iterable -> java.lang.Iterable { void forEach(Consumer) Iterator iterator() Spliterator spliterator() + # some adaptations of groovy methods + boolean any*(Predicate) } # Readable: i/o diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.regex.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.regex.txt index 6befc865731..aaea78a7a96 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.regex.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.regex.txt @@ -42,7 +42,7 @@ class Matcher -> java.util.regex.Matcher extends Object { boolean find(int) String group() String group(int) - String namedGroup/group(String) + String namedGroup*(String) int groupCount() boolean hasAnchoringBounds() boolean hasTransparentBounds() diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.txt index cccbf60f1a3..2a41b2c2c13 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.txt @@ -114,8 +114,7 @@ class List -> java.util.List extends Collection,Iterable { def remove(int) void replaceAll(UnaryOperator) def set(int,def) - # TODO: wtf? - int getLength/size() + int getLength*() void sort(Comparator) List subList(int,int) } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java new file mode 100644 index 00000000000..ccc3620b1c9 --- /dev/null +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java @@ -0,0 +1,35 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.painless; + +public class AugmentationTests extends ScriptTestCase { + + @AwaitsFix(bugUrl = "rmuir is working on this") + public void testCapturingReference() { + assertEquals(1, exec("int foo(Supplier t) { return t.get() }" + + "def l = new ArrayList(); l.add(1);" + + "return foo(l::getLength);")); + } + + + public void testIterable_Any() { + assertEquals(true, exec("List l = new ArrayList(); l.add(1); l.any(x -> x == 1)")); + } +} From 42d60f9f28ba159a91e991c806831f375006fdb8 Mon Sep 17 00:00:00 2001 From: Robert Muir Date: Tue, 21 Jun 2016 11:25:43 -0400 Subject: [PATCH 2/4] maps n lists --- .../elasticsearch/painless/Augmentation.java | 290 +++++++++++++++++- .../java/org/elasticsearch/painless/Def.java | 4 +- .../elasticsearch/painless/Definition.java | 4 +- .../elasticsearch/painless/FunctionRef.java | 29 +- .../painless/WriterConstants.java | 2 + .../painless/node/ECapturingFunctionRef.java | 2 +- .../painless/node/EFunctionRef.java | 4 +- .../elasticsearch/painless/node/ELambda.java | 6 +- .../org/elasticsearch/painless/java.lang.txt | 7 + .../org/elasticsearch/painless/java.util.txt | 22 ++ .../painless/AugmentationTests.java | 149 ++++++++- 11 files changed, 491 insertions(+), 28 deletions(-) diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Augmentation.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Augmentation.java index 456642eb8c5..4bca673b4dc 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Augmentation.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Augmentation.java @@ -20,9 +20,14 @@ package org.elasticsearch.painless; import java.util.ArrayList; +import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.TreeMap; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.BiPredicate; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.ObjIntConsumer; @@ -30,15 +35,26 @@ import java.util.function.Predicate; import java.util.function.ToDoubleFunction; import java.util.regex.Matcher; +/** Additional methods added to classes. These must be static methods with receiver as first argument */ public class Augmentation { + + // static methods only! + private Augmentation() {} + + /** Exposes List.size() as getLength(), so that .length shortcut works on lists */ public static int getLength(List receiver) { return receiver.size(); } - + + /** Exposes Matcher.group(String) as namedGroup(String), so it doesn't conflict with group(int) */ public static String namedGroup(Matcher receiver, String name) { return receiver.group(name); } + // some groovy methods on iterable + // see http://docs.groovy-lang.org/latest/html/groovy-jdk/java/lang/Iterable.html + + /** Iterates over the contents of an iterable, and checks whether a predicate is valid for at least one element. */ public static boolean any(Iterable receiver, Predicate predicate) { for (T t : receiver) { if (predicate.test(t)) { @@ -48,6 +64,7 @@ public class Augmentation { return false; } + /** Counts the number of occurrences which satisfy the given predicate from inside this Iterable. */ public static int count(Iterable receiver, Predicate predicate) { int count = 0; for (T t : receiver) { @@ -58,12 +75,20 @@ public class Augmentation { return count; } - public static Iterable each(Iterable receiver, Consumer consumer) { + // instead of covariant overrides for every possibility, we just return receiver as 'def' for now + // that way if someone chains the calls, everything works. + + /** Iterates through an Iterable, passing each item to the given consumer. */ + public static Object each(Iterable receiver, Consumer consumer) { receiver.forEach(consumer); return receiver; } - public static Iterable eachWithIndex(Iterable receiver, ObjIntConsumer consumer) { + /** + * Iterates through an iterable type, passing each item and the item's index + * (a counter starting at zero) to the given consumer. + */ + public static Object eachWithIndex(Iterable receiver, ObjIntConsumer consumer) { int count = 0; for (T t : receiver) { consumer.accept(t, count++); @@ -71,6 +96,9 @@ public class Augmentation { return receiver; } + /** + * Used to determine if the given predicate is valid (i.e. returns true for all items in this iterable). + */ public static boolean every(Iterable receiver, Predicate predicate) { for (T t : receiver) { if (predicate.test(t) == false) { @@ -80,6 +108,10 @@ public class Augmentation { return true; } + /** + * Iterates through the Iterable transforming items using the supplied function and + * collecting any non-null results. + */ public static List findResults(Iterable receiver, Function filter) { List list = new ArrayList<>(); for (T t: receiver) { @@ -91,6 +123,9 @@ public class Augmentation { return list; } + /** + * Sorts all Iterable members into groups determined by the supplied mapping function. + */ public static Map> groupBy(Iterable receiver, Function mapper) { Map> map = new LinkedHashMap<>(); for (T t : receiver) { @@ -105,6 +140,10 @@ public class Augmentation { return map; } + /** + * Concatenates the toString() representation of each item in this Iterable, + * with the given String as a separator between each item. + */ public static String join(Iterable receiver, String separator) { StringBuilder sb = new StringBuilder(); for (T t : receiver) { @@ -116,6 +155,9 @@ public class Augmentation { return sb.toString(); } + /** + * Sums the result of applying a function to each item of an Iterable. + */ public static double sum(Iterable receiver, ToDoubleFunction function) { double sum = 0; for (T t : receiver) { @@ -123,4 +165,246 @@ public class Augmentation { } return sum; } + + // some groovy methods on collection + // see http://docs.groovy-lang.org/latest/html/groovy-jdk/java/util/Collection.html + + /** + * Iterates through this collection transforming each entry into a new value using + * the function, returning a list of transformed values. + */ + public static List collect(Collection receiver, Function function) { + List list = new ArrayList<>(); + for (T t : receiver) { + list.add(function.apply(t)); + } + return list; + } + + /** + * Iterates through this collection transforming each entry into a new value using + * the function, adding the values to the specified collection. + */ + public static Object collect(Collection receiver, Collection collection, Function function) { + for (T t : receiver) { + collection.add(function.apply(t)); + } + return collection; + } + + /** + * Finds the first value matching the predicate, or returns null. + */ + public static T find(Collection receiver, Predicate predicate) { + for (T t : receiver) { + if (predicate.test(t)) { + return t; + } + } + return null; + } + + /** + * Finds all values matching the predicate, returns as a list + */ + public static List findAll(Collection receiver, Predicate predicate) { + List list = new ArrayList<>(); + for (T t : receiver) { + if (predicate.test(t)) { + list.add(t); + } + } + return list; + } + + /** + * Iterates through the collection calling the given function for each item + * but stopping once the first non-null result is found and returning that result. + * If all results are null, null is returned. + */ + public static Object findResult(Collection receiver, Function function) { + return findResult(receiver, null, function); + } + + /** + * Iterates through the collection calling the given function for each item + * but stopping once the first non-null result is found and returning that result. + * If all results are null, defaultResult is returned. + */ + public static Object findResult(Collection receiver, Object defaultResult, Function function) { + for (T t : receiver) { + U value = function.apply(t); + if (value != null) { + return value; + } + } + return defaultResult; + } + + /** + * Splits all items into two collections based on the predicate. + * The first list contains all items which match the closure expression. The second list all those that don't. + */ + public static List> split(Collection receiver, Predicate predicate) { + List matched = new ArrayList<>(); + List unmatched = new ArrayList<>(); + List> result = new ArrayList<>(2); + result.add(matched); + result.add(unmatched); + for (T t : receiver) { + if (predicate.test(t)) { + matched.add(t); + } else { + unmatched.add(t); + } + } + return result; + } + + // some groovy methods on map + // see http://docs.groovy-lang.org/latest/html/groovy-jdk/java/util/Map.html + + /** + * Iterates through this map transforming each entry into a new value using + * the function, returning a list of transformed values. + */ + public static List collect(Map receiver, BiFunction function) { + List list = new ArrayList<>(); + for (Map.Entry kvPair : receiver.entrySet()) { + list.add(function.apply(kvPair.getKey(), kvPair.getValue())); + } + return list; + } + + /** + * Iterates through this map transforming each entry into a new value using + * the function, adding the values to the specified collection. + */ + public static Object collect(Map receiver, Collection collection, BiFunction function) { + for (Map.Entry kvPair : receiver.entrySet()) { + collection.add(function.apply(kvPair.getKey(), kvPair.getValue())); + } + return collection; + } + + /** Counts the number of occurrences which satisfy the given predicate from inside this Map */ + public static int count(Map receiver, BiPredicate predicate) { + int count = 0; + for (Map.Entry kvPair : receiver.entrySet()) { + if (predicate.test(kvPair.getKey(), kvPair.getValue())) { + count++; + } + } + return count; + } + + /** Iterates through a Map, passing each item to the given consumer. */ + public static Object each(Map receiver, BiConsumer consumer) { + receiver.forEach(consumer); + return receiver; + } + + /** + * Used to determine if the given predicate is valid (i.e. returns true for all items in this map). + */ + public static boolean every(Map receiver, BiPredicate predicate) { + for (Map.Entry kvPair : receiver.entrySet()) { + if (predicate.test(kvPair.getKey(), kvPair.getValue()) == false) { + return false; + } + } + return true; + } + + /** + * Finds the first entry matching the predicate, or returns null. + */ + public static Map.Entry find(Map receiver, BiPredicate predicate) { + for (Map.Entry kvPair : receiver.entrySet()) { + if (predicate.test(kvPair.getKey(), kvPair.getValue())) { + return kvPair; + } + } + return null; + } + + /** + * Finds all values matching the predicate, returns as a map. + */ + public static Map findAll(Map receiver, BiPredicate predicate) { + // try to preserve some properties of the receiver (see the groovy javadocs) + final Map map; + if (receiver instanceof TreeMap) { + map = new TreeMap<>(); + } else { + map = new LinkedHashMap<>(); + } + for (Map.Entry kvPair : receiver.entrySet()) { + if (predicate.test(kvPair.getKey(), kvPair.getValue())) { + map.put(kvPair.getKey(), kvPair.getValue()); + } + } + return map; + } + + /** + * Iterates through the map calling the given function for each item + * but stopping once the first non-null result is found and returning that result. + * If all results are null, null is returned. + */ + public static Object findResult(Map receiver, BiFunction function) { + return findResult(receiver, null, function); + } + + /** + * Iterates through the map calling the given function for each item + * but stopping once the first non-null result is found and returning that result. + * If all results are null, defaultResult is returned. + */ + public static Object findResult(Map receiver, Object defaultResult, BiFunction function) { + for (Map.Entry kvPair : receiver.entrySet()) { + T value = function.apply(kvPair.getKey(), kvPair.getValue()); + if (value != null) { + return value; + } + } + return defaultResult; + } + + /** + * Iterates through the map transforming items using the supplied function and + * collecting any non-null results. + */ + public static List findResults(Map receiver, BiFunction filter) { + List list = new ArrayList<>(); + for (Map.Entry kvPair : receiver.entrySet()) { + T result = filter.apply(kvPair.getKey(), kvPair.getValue()); + if (result != null) { + list.add(result); + } + } + return list; + } + + /** + * Sorts all Map members into groups determined by the supplied mapping function. + */ + public static Map> groupBy(Map receiver, BiFunction mapper) { + Map> map = new LinkedHashMap<>(); + for (Map.Entry kvPair : receiver.entrySet()) { + T mapped = mapper.apply(kvPair.getKey(), kvPair.getValue()); + Map results = map.get(mapped); + if (results == null) { + // try to preserve some properties of the receiver (see the groovy javadocs) + if (receiver instanceof TreeMap) { + results = new TreeMap<>(); + } else { + results = new LinkedHashMap<>(); + } + map.put(mapped, results); + } + results.put(kvPair.getKey(), kvPair.getValue()); + } + return map; + } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java index 5461771bca6..cd761d0ad44 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java @@ -350,10 +350,10 @@ public final class Def { } throw new IllegalArgumentException("Unknown call [" + call + "] with [" + arity + "] arguments."); } - ref = new FunctionRef(clazz, interfaceMethod, handle, captures); + ref = new FunctionRef(clazz, interfaceMethod, handle, captures.length); } else { // whitelist lookup - ref = new FunctionRef(clazz, type, call, captures); + ref = new FunctionRef(clazz, type, call, captures.length); } final CallSite callSite; if (ref.needsBridges()) { diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java index 376be44e28b..8aa4da4015f 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java @@ -257,7 +257,7 @@ public final class Definition { final org.objectweb.asm.Type type; if (augmentation) { assert java.lang.reflect.Modifier.isStatic(modifiers); - type = org.objectweb.asm.Type.getType(Augmentation.class); + type = WriterConstants.AUGMENTATION_TYPE; } else { type = owner.type; } @@ -825,7 +825,7 @@ public final class Definition { reflect = implClass.getMethod(name, params); } catch (NoSuchMethodException exception) { throw new IllegalArgumentException("Method [" + name + - "] not found for class [" + owner.clazz.getName() + "]" + + "] not found for class [" + implClass.getName() + "]" + " with arguments " + Arrays.toString(params) + "."); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java index 72676f1c04c..d5e02e12058 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java @@ -53,10 +53,10 @@ public class FunctionRef { * @param expected interface type to implement. * @param type the left hand side of a method reference expression * @param call the right hand side of a method reference expression - * @param captures captured arguments + * @param numCaptures number of captured arguments */ - public FunctionRef(Definition.Type expected, String type, String call, Class... captures) { - this(expected, expected.struct.getFunctionalMethod(), lookup(expected, type, call, captures.length > 0), captures); + public FunctionRef(Definition.Type expected, String type, String call, int numCaptures) { + this(expected, expected.struct.getFunctionalMethod(), lookup(expected, type, call, numCaptures > 0), numCaptures); } /** @@ -64,13 +64,16 @@ public class FunctionRef { * @param expected interface type to implement * @param method functional interface method * @param impl implementation method - * @param captures captured arguments + * @param numCaptures number of captured arguments */ - public FunctionRef(Definition.Type expected, Definition.Method method, Definition.Method impl, Class... captures) { + public FunctionRef(Definition.Type expected, Definition.Method method, Definition.Method impl, int numCaptures) { // e.g. compareTo invokedName = method.name; // e.g. (Object)Comparator - invokedType = MethodType.methodType(expected.clazz, captures); + MethodType implType = impl.getMethodType(); + // only include captured parameters as arguments + invokedType = MethodType.methodType(expected.clazz, + implType.dropParameterTypes(numCaptures, implType.parameterCount())); // e.g. (Object,Object)int interfaceMethodType = method.getMethodType().dropParameterTypes(0, 1); @@ -90,6 +93,9 @@ public class FunctionRef { // owner == null: script class itself ownerIsInterface = false; owner = WriterConstants.CLASS_TYPE.getInternalName(); + } else if (impl.augmentation) { + ownerIsInterface = false; + owner = WriterConstants.AUGMENTATION_TYPE.getInternalName(); } else { ownerIsInterface = impl.owner.clazz.isInterface(); owner = impl.owner.type.getInternalName(); @@ -98,7 +104,7 @@ public class FunctionRef { implMethod = impl.handle; // remove any prepended captured arguments for the 'natural' signature. - samMethodType = adapt(interfaceMethodType, impl.getMethodType().dropParameterTypes(0, captures.length)); + samMethodType = adapt(interfaceMethodType, impl.getMethodType().dropParameterTypes(0, numCaptures)); } /** @@ -106,11 +112,14 @@ public class FunctionRef { *

* This will not set implMethodASM. It is for runtime use only. */ - public FunctionRef(Definition.Type expected, Definition.Method method, MethodHandle impl, Class... captures) { + public FunctionRef(Definition.Type expected, Definition.Method method, MethodHandle impl, int numCaptures) { // e.g. compareTo invokedName = method.name; // e.g. (Object)Comparator - invokedType = MethodType.methodType(expected.clazz, captures); + MethodType implType = impl.type(); + // only include captured parameters as arguments + invokedType = MethodType.methodType(expected.clazz, + implType.dropParameterTypes(numCaptures, implType.parameterCount())); // e.g. (Object,Object)int interfaceMethodType = method.getMethodType().dropParameterTypes(0, 1); @@ -119,7 +128,7 @@ public class FunctionRef { implMethodASM = null; // remove any prepended captured arguments for the 'natural' signature. - samMethodType = adapt(interfaceMethodType, impl.type().dropParameterTypes(0, captures.length)); + samMethodType = adapt(interfaceMethodType, impl.type().dropParameterTypes(0, numCaptures)); } /** diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/WriterConstants.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/WriterConstants.java index d581ba8518a..e2bf804c181 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/WriterConstants.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/WriterConstants.java @@ -72,6 +72,8 @@ public final class WriterConstants { public final static Method CHAR_TO_STRING = getAsmMethod(String.class, "charToString", char.class); public final static Type METHOD_HANDLE_TYPE = Type.getType(MethodHandle.class); + + public static final Type AUGMENTATION_TYPE = Type.getType(Augmentation.class); /** * A Method instance for {@linkplain Pattern#compile}. This isn't available from Definition because we intentionally don't add it there diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ECapturingFunctionRef.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ECapturingFunctionRef.java index aa7f807aff1..7bf8e195d6d 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ECapturingFunctionRef.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ECapturingFunctionRef.java @@ -76,7 +76,7 @@ public class ECapturingFunctionRef extends AExpression implements ILambda { // static case if (captured.type.sort != Definition.Sort.DEF) { try { - ref = new FunctionRef(expected, captured.type.name, call, captured.type.clazz); + ref = new FunctionRef(expected, captured.type.name, call, 1); } catch (IllegalArgumentException e) { throw createError(e); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EFunctionRef.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EFunctionRef.java index 298b84ffe29..380bfd6c43f 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EFunctionRef.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EFunctionRef.java @@ -76,10 +76,10 @@ public class EFunctionRef extends AExpression implements ILambda { throw new IllegalArgumentException("Cannot convert function reference [" + type + "::" + call + "] " + "to [" + expected.name + "], function not found"); } - ref = new FunctionRef(expected, interfaceMethod, implMethod); + ref = new FunctionRef(expected, interfaceMethod, implMethod, 0); } else { // whitelist lookup - ref = new FunctionRef(expected, type, call); + ref = new FunctionRef(expected, type, call, 0); } } catch (IllegalArgumentException e) { throw createError(e); diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ELambda.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ELambda.java index e0ea5e73aeb..0cbb2ed1b33 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ELambda.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/ELambda.java @@ -175,11 +175,7 @@ public class ELambda extends AExpression implements ILambda { } else { defPointer = null; try { - Class captureClasses[] = new Class[captures.size()]; - for (int i = 0; i < captures.size(); i++) { - captureClasses[i] = captures.get(i).type.clazz; - } - ref = new FunctionRef(expected, interfaceMethod, desugared.method, captureClasses); + ref = new FunctionRef(expected, interfaceMethod, desugared.method, captures.size()); } catch (IllegalArgumentException e) { throw createError(e); } diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.lang.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.lang.txt index 859a7b314dd..035dc9ba0c8 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.lang.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.lang.txt @@ -52,6 +52,13 @@ class Iterable -> java.lang.Iterable { Spliterator spliterator() # some adaptations of groovy methods boolean any*(Predicate) + def each*(Consumer) + def eachWithIndex*(ObjIntConsumer) + boolean every*(Predicate) + List findResults*(Function) + Map groupBy*(Function) + String join*(String) + double sum*(ToDoubleFunction) } # Readable: i/o diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.txt index 2a41b2c2c13..66f8f67d869 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/java.util.txt @@ -39,6 +39,15 @@ class Collection -> java.util.Collection extends Iterable { Stream stream() def[] toArray() def[] toArray(def[]) + + # some adaptations of groovy methods + List collect*(Function) + def collect*(Collection,Function) + def find*(Predicate) + List findAll*(Predicate) + def findResult*(Function) + def findResult*(def,Function) + List split*(Predicate) } class Comparator -> java.util.Comparator { @@ -152,6 +161,19 @@ class Map -> java.util.Map { void replaceAll(BiFunction) int size() Collection values() + + # some adaptations of groovy methods + List collect*(BiFunction) + def collect*(Collection,BiFunction) + int count*(BiPredicate) + def each*(BiConsumer) + boolean every*(BiPredicate) + Map.Entry find*(BiPredicate) + Map findAll*(BiPredicate) + def findResult*(BiFunction) + def findResult*(def,BiFunction) + List findResults*(BiFunction) + Map groupBy*(BiFunction) } class Map.Entry -> java.util.Map$Entry { diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java index ccc3620b1c9..c0872dd1994 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java @@ -19,17 +19,160 @@ package org.elasticsearch.painless; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + public class AugmentationTests extends ScriptTestCase { - @AwaitsFix(bugUrl = "rmuir is working on this") + public void testStatic() { + assertEquals(1, exec("ArrayList l = new ArrayList(); l.add(1); return l.getLength();")); + assertEquals(1, exec("ArrayList l = new ArrayList(); l.add(1); return l.length;")); + } + + public void testSubclass() { + assertEquals(1, exec("List l = new ArrayList(); l.add(1); return l.getLength();")); + assertEquals(1, exec("List l = new ArrayList(); l.add(1); return l.length;")); + } + + public void testDef() { + assertEquals(1, exec("def l = new ArrayList(); l.add(1); return l.getLength();")); + assertEquals(1, exec("def l = new ArrayList(); l.add(1); return l.length;")); + } + public void testCapturingReference() { + assertEquals(1, exec("int foo(Supplier t) { return t.get() }" + + "ArrayList l = new ArrayList(); l.add(1);" + + "return foo(l::getLength);")); + assertEquals(1, exec("int foo(Supplier t) { return t.get() }" + + "List l = new ArrayList(); l.add(1);" + + "return foo(l::getLength);")); assertEquals(1, exec("int foo(Supplier t) { return t.get() }" + "def l = new ArrayList(); l.add(1);" + "return foo(l::getLength);")); } - public void testIterable_Any() { - assertEquals(true, exec("List l = new ArrayList(); l.add(1); l.any(x -> x == 1)")); + assertEquals(true, + exec("List l = new ArrayList(); l.add(1); l.any(x -> x == 1)")); + } + + public void testIterable_Each() { + assertEquals(1, + exec("List l = new ArrayList(); l.add(1); List l2 = new ArrayList(); l.each(l2::add); return l2.size()")); + } + + public void testIterable_EachWithIndex() { + assertEquals(0, + exec("List l = new ArrayList(); l.add(2); Map m = new HashMap(); l.eachWithIndex(m::put); return m.get(2)")); + } + + public void testIterable_Every() { + assertEquals(false, exec("List l = new ArrayList(); l.add(1); l.add(2); l.every(x -> x == 1)")); + } + + public void testIterable_FindResults() { + assertEquals(1, + exec("List l = new ArrayList(); l.add(1); l.add(2); l.findResults(x -> x == 1 ? x : null).size()")); + } + + public void testIterable_GroupBy() { + assertEquals(2, + exec("List l = new ArrayList(); l.add(1); l.add(-1); l.groupBy(x -> x < 0 ? 'negative' : 'positive').size()")); + } + + public void testIterable_Join() { + assertEquals("test,ing", + exec("List l = new ArrayList(); l.add('test'); l.add('ing'); l.join(',')")); + } + + public void testIterable_Sum() { + assertEquals(5.0D, + exec("List l = new ArrayList(); l.add(1); l.add(2); l.sum(x -> x + 1)")); + } + + public void testCollection_Collect() { + assertEquals(Arrays.asList(2, 3), + exec("List l = new ArrayList(); l.add(1); l.add(2); l.collect(x -> x + 1)")); + assertEquals(asSet(2, 3), + exec("List l = new ArrayList(); l.add(1); l.add(2); l.collect(new HashSet(), x -> x + 1)")); + } + + public void testCollection_Find() { + assertEquals(2, + exec("List l = new ArrayList(); l.add(1); l.add(2); return l.find(x -> x == 2)")); + } + + public void testCollection_FindAll() { + assertEquals(Arrays.asList(2), + exec("List l = new ArrayList(); l.add(1); l.add(2); return l.findAll(x -> x == 2)")); + } + + public void testCollection_FindResult() { + assertEquals("found", + exec("List l = new ArrayList(); l.add(1); l.add(2); return l.findResult(x -> x > 1 ? 'found' : null)")); + assertEquals("notfound", + exec("List l = new ArrayList(); l.add(1); l.add(2); return l.findResult('notfound', x -> x > 10 ? 'found' : null)")); + } + + public void testCollection_Split() { + assertEquals(Arrays.asList(Arrays.asList(2), Arrays.asList(1)), + exec("List l = new ArrayList(); l.add(1); l.add(2); return l.split(x -> x == 2)")); + } + + public void testMap_Collect() { + assertEquals(Arrays.asList("one1", "two2"), + exec("Map m = new TreeMap(); m.one = 1; m.two = 2; m.collect((key,value) -> key + value)")); + assertEquals(asSet("one1", "two2"), + exec("Map m = new TreeMap(); m.one = 1; m.two = 2; m.collect(new HashSet(), (key,value) -> key + value)")); + } + + public void testMap_Count() { + assertEquals(1, + exec("Map m = new TreeMap(); m.one = 1; m.two = 2; m.count((key,value) -> value == 2)")); + } + + public void testMap_Each() { + assertEquals(2, + exec("Map m = new TreeMap(); m.one = 1; m.two = 2; Map m2 = new TreeMap(); m.each(m2::put); return m2.size()")); + } + + public void testMap_Every() { + assertEquals(false, + exec("Map m = new TreeMap(); m.one = 1; m.two = 2; m.every((key,value) -> value == 2)")); + } + + public void testMap_Find() { + assertEquals("two", + exec("Map m = new TreeMap(); m.one = 1; m.two = 2; return m.find((key,value) -> value == 2).key")); + } + + public void testMap_FindAll() { + assertEquals(Collections.singletonMap("two", 2), + exec("Map m = new TreeMap(); m.one = 1; m.two = 2; return m.findAll((key,value) -> value == 2)")); + } + + public void testMap_FindResult() { + assertEquals("found", + exec("Map m = new TreeMap(); m.one = 1; m.two = 2; return m.findResult((key,value) -> value == 2 ? 'found' : null)")); + assertEquals("notfound", + exec("Map m = new TreeMap(); m.one = 1; m.two = 2; " + + "return m.findResult('notfound', (key,value) -> value == 10 ? 'found' : null)")); + } + + public void testMap_FindResults() { + assertEquals(Arrays.asList("negative", "positive"), + exec("Map m = new TreeMap(); m.a = -1; m.b = 1; " + + "return m.findResults((key,value) -> value < 0 ? 'negative' : 'positive')")); + } + + public void testMap_GroupBy() { + Map> expected = new HashMap<>(); + expected.put("negative", Collections.singletonMap("a", -1)); + expected.put("positive", Collections.singletonMap("b", 1)); + assertEquals(expected, + exec("Map m = new TreeMap(); m.a = -1; m.b = 1; " + + "return m.groupBy((key,value) -> value < 0 ? 'negative' : 'positive')")); } } From f78ef232dcfabb06697ad73b70c463336a2bd66f Mon Sep 17 00:00:00 2001 From: Robert Muir Date: Tue, 21 Jun 2016 12:05:10 -0400 Subject: [PATCH 3/4] fix bogus comment --- .../src/main/java/org/elasticsearch/painless/Definition.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java index 8aa4da4015f..e57aad862aa 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Definition.java @@ -220,7 +220,7 @@ public final class Definition { final Class params[]; final Class returnValue; if (augmentation) { - // virtual/interface method disguised as static + // static method disguised as virtual/interface method params = new Class[1 + arguments.size()]; params[0] = Augmentation.class; for (int i = 0; i < arguments.size(); i++) { From 1b9695a9aa854ecdddbd73af4d4133637c96022d Mon Sep 17 00:00:00 2001 From: Robert Muir Date: Tue, 21 Jun 2016 12:15:59 -0400 Subject: [PATCH 4/4] beef up tests so we ensure you still get good errors in these cases --- .../elasticsearch/painless/FunctionRefTests.java | 14 ++++++++++++++ .../org/elasticsearch/painless/LambdaTests.java | 16 ++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java index 46d2b7c43fd..7136807bd7e 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java @@ -170,10 +170,24 @@ public class FunctionRefTests extends ScriptTestCase { assertTrue(expected.getMessage().contains("Unknown reference")); } + public void testWrongArityNotEnough() { + IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { + exec("List l = new ArrayList(); l.add(2); l.add(1); l.sort(String::isEmpty);"); + }); + assertTrue(expected.getMessage().contains("Unknown reference")); + } + public void testWrongArityDef() { IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { exec("def y = Optional.empty(); return y.orElseGet(String::startsWith);"); }); assertTrue(expected.getMessage().contains("Unknown reference")); } + + public void testWrongArityNotEnoughDef() { + IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { + exec("def l = new ArrayList(); l.add(2); l.add(1); l.sort(String::isEmpty);"); + }); + assertTrue(expected.getMessage().contains("Unknown reference")); + } } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/LambdaTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/LambdaTests.java index 7b2d5e6a935..dbca5243ec2 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/LambdaTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/LambdaTests.java @@ -180,6 +180,22 @@ public class LambdaTests extends ScriptTestCase { assertTrue(expected.getMessage(), expected.getMessage().contains("Incorrect number of parameters")); } + public void testWrongArityNotEnough() { + IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { + exec("List l = new ArrayList(); l.add(1); l.add(1); " + + "return l.stream().mapToInt(() -> 5).sum();"); + }); + assertTrue(expected.getMessage().contains("Incorrect number of parameters")); + } + + public void testWrongArityNotEnoughDef() { + IllegalArgumentException expected = expectScriptThrows(IllegalArgumentException.class, () -> { + exec("def l = new ArrayList(); l.add(1); l.add(1); " + + "return l.stream().mapToInt(() -> 5).sum();"); + }); + assertTrue(expected.getMessage().contains("Incorrect number of parameters")); + } + public void testLambdaInFunction() { assertEquals(5, exec("def foo() { Optional.empty().orElseGet(() -> 5) } return foo();")); }