Update Painless to Allow Augmentation from Any Class (#25360)
Custom whitelists in Painless will need to allow classes to be augmented beyond the currently hard-coded Augmentation class tied to Painless directly. This change allows any class to specify an augmentation on a Painless struct using an appropriate static method. Changes to loading the whitelist have also been created to allow for this specification of a different class for augmentation.
This commit is contained in:
parent
e6e5ae6202
commit
96b62409a8
|
@ -199,14 +199,14 @@ public final class Definition {
|
||||||
public static class Method {
|
public static class Method {
|
||||||
public final String name;
|
public final String name;
|
||||||
public final Struct owner;
|
public final Struct owner;
|
||||||
public final boolean augmentation;
|
public final Class<?> augmentation;
|
||||||
public final Type rtn;
|
public final Type rtn;
|
||||||
public final List<Type> arguments;
|
public final List<Type> arguments;
|
||||||
public final org.objectweb.asm.commons.Method method;
|
public final org.objectweb.asm.commons.Method method;
|
||||||
public final int modifiers;
|
public final int modifiers;
|
||||||
public final MethodHandle handle;
|
public final MethodHandle handle;
|
||||||
|
|
||||||
public Method(String name, Struct owner, boolean augmentation, Type rtn, List<Type> arguments,
|
public Method(String name, Struct owner, Class<?> augmentation, Type rtn, List<Type> arguments,
|
||||||
org.objectweb.asm.commons.Method method, int modifiers, MethodHandle handle) {
|
org.objectweb.asm.commons.Method method, int modifiers, MethodHandle handle) {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.augmentation = augmentation;
|
this.augmentation = augmentation;
|
||||||
|
@ -232,10 +232,10 @@ public final class Definition {
|
||||||
// otherwise compute it
|
// otherwise compute it
|
||||||
final Class<?> params[];
|
final Class<?> params[];
|
||||||
final Class<?> returnValue;
|
final Class<?> returnValue;
|
||||||
if (augmentation) {
|
if (augmentation != null) {
|
||||||
// static method disguised as virtual/interface method
|
// static method disguised as virtual/interface method
|
||||||
params = new Class<?>[1 + arguments.size()];
|
params = new Class<?>[1 + arguments.size()];
|
||||||
params[0] = Augmentation.class;
|
params[0] = augmentation;
|
||||||
for (int i = 0; i < arguments.size(); i++) {
|
for (int i = 0; i < arguments.size(); i++) {
|
||||||
params[i + 1] = arguments.get(i).clazz;
|
params[i + 1] = arguments.get(i).clazz;
|
||||||
}
|
}
|
||||||
|
@ -268,9 +268,9 @@ public final class Definition {
|
||||||
|
|
||||||
public void write(MethodWriter writer) {
|
public void write(MethodWriter writer) {
|
||||||
final org.objectweb.asm.Type type;
|
final org.objectweb.asm.Type type;
|
||||||
if (augmentation) {
|
if (augmentation != null) {
|
||||||
assert java.lang.reflect.Modifier.isStatic(modifiers);
|
assert java.lang.reflect.Modifier.isStatic(modifiers);
|
||||||
type = WriterConstants.AUGMENTATION_TYPE;
|
type = org.objectweb.asm.Type.getType(augmentation);
|
||||||
} else {
|
} else {
|
||||||
type = owner.type;
|
type = owner.type;
|
||||||
}
|
}
|
||||||
|
@ -731,7 +731,7 @@ public final class Definition {
|
||||||
" with arguments " + Arrays.toString(classes) + ".");
|
" with arguments " + Arrays.toString(classes) + ".");
|
||||||
}
|
}
|
||||||
|
|
||||||
final Method constructor = new Method(name, owner, false, returnType, Arrays.asList(args), asm, reflect.getModifiers(), handle);
|
final Method constructor = new Method(name, owner, null, returnType, Arrays.asList(args), asm, reflect.getModifiers(), handle);
|
||||||
|
|
||||||
owner.constructors.put(methodKey, constructor);
|
owner.constructors.put(methodKey, constructor);
|
||||||
}
|
}
|
||||||
|
@ -775,10 +775,14 @@ public final class Definition {
|
||||||
}
|
}
|
||||||
addConstructorInternal(className, "<init>", args);
|
addConstructorInternal(className, "<init>", args);
|
||||||
} else {
|
} else {
|
||||||
if (methodName.indexOf("*") >= 0) {
|
int index = methodName.lastIndexOf(".");
|
||||||
addMethodInternal(className, methodName.substring(0, methodName.length() - 1), true, rtn, args);
|
|
||||||
|
if (index >= 0) {
|
||||||
|
String augmentation = methodName.substring(0, index);
|
||||||
|
methodName = methodName.substring(index + 1);
|
||||||
|
addMethodInternal(className, methodName, augmentation, rtn, args);
|
||||||
} else {
|
} else {
|
||||||
addMethodInternal(className, methodName, false, rtn, args);
|
addMethodInternal(className, methodName, null, rtn, args);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -787,8 +791,7 @@ public final class Definition {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addMethodInternal(String struct, String name, boolean augmentation,
|
private void addMethodInternal(String struct, String name, String augmentation, Type rtn, Type[] args) {
|
||||||
Type rtn, Type[] args) {
|
|
||||||
final Struct owner = structsMap.get(struct);
|
final Struct owner = structsMap.get(struct);
|
||||||
|
|
||||||
if (owner == null) {
|
if (owner == null) {
|
||||||
|
@ -817,14 +820,20 @@ public final class Definition {
|
||||||
final Class<?> implClass;
|
final Class<?> implClass;
|
||||||
final Class<?>[] params;
|
final Class<?>[] params;
|
||||||
|
|
||||||
if (augmentation == false) {
|
if (augmentation == null) {
|
||||||
implClass = owner.clazz;
|
implClass = owner.clazz;
|
||||||
params = new Class<?>[args.length];
|
params = new Class<?>[args.length];
|
||||||
for (int count = 0; count < args.length; ++count) {
|
for (int count = 0; count < args.length; ++count) {
|
||||||
params[count] = args[count].clazz;
|
params[count] = args[count].clazz;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
implClass = Augmentation.class;
|
try {
|
||||||
|
implClass = Class.forName(augmentation);
|
||||||
|
} catch (ClassNotFoundException cnfe) {
|
||||||
|
throw new IllegalArgumentException("Augmentation class [" + augmentation + "]" +
|
||||||
|
" not found for struct [" + struct + "] using method name [" + name + "].", cnfe);
|
||||||
|
}
|
||||||
|
|
||||||
params = new Class<?>[args.length + 1];
|
params = new Class<?>[args.length + 1];
|
||||||
params[0] = owner.clazz;
|
params[0] = owner.clazz;
|
||||||
for (int count = 0; count < args.length; ++count) {
|
for (int count = 0; count < args.length; ++count) {
|
||||||
|
@ -862,9 +871,10 @@ public final class Definition {
|
||||||
}
|
}
|
||||||
|
|
||||||
final int modifiers = reflect.getModifiers();
|
final int modifiers = reflect.getModifiers();
|
||||||
final Method method = new Method(name, owner, augmentation, rtn, Arrays.asList(args), asm, modifiers, handle);
|
final Method method =
|
||||||
|
new Method(name, owner, augmentation == null ? null : implClass, rtn, Arrays.asList(args), asm, modifiers, handle);
|
||||||
|
|
||||||
if (augmentation == false && java.lang.reflect.Modifier.isStatic(modifiers)) {
|
if (augmentation == null && java.lang.reflect.Modifier.isStatic(modifiers)) {
|
||||||
owner.staticMethods.put(methodKey, method);
|
owner.staticMethods.put(methodKey, method);
|
||||||
} else {
|
} else {
|
||||||
owner.methods.put(methodKey, method);
|
owner.methods.put(methodKey, method);
|
||||||
|
@ -966,8 +976,8 @@ public final class Definition {
|
||||||
// TODO: we *have* to remove all these public members and use getter methods to encapsulate!
|
// TODO: we *have* to remove all these public members and use getter methods to encapsulate!
|
||||||
final Class<?> impl;
|
final Class<?> impl;
|
||||||
final Class<?> arguments[];
|
final Class<?> arguments[];
|
||||||
if (method.augmentation) {
|
if (method.augmentation != null) {
|
||||||
impl = Augmentation.class;
|
impl = method.augmentation;
|
||||||
arguments = new Class<?>[method.arguments.size() + 1];
|
arguments = new Class<?>[method.arguments.size() + 1];
|
||||||
arguments[0] = method.owner.clazz;
|
arguments[0] = method.owner.clazz;
|
||||||
for (int i = 0; i < method.arguments.size(); i++) {
|
for (int i = 0; i < method.arguments.size(); i++) {
|
||||||
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.painless;
|
||||||
|
|
||||||
|
public class FeatureTestAugmentation {
|
||||||
|
public static int getTotal(FeatureTest ft) {
|
||||||
|
return ft.getX() + ft.getY();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static int addToTotal(FeatureTest ft, int add) {
|
||||||
|
return getTotal(ft) + add;
|
||||||
|
}
|
||||||
|
|
||||||
|
private FeatureTestAugmentation() {}
|
||||||
|
}
|
|
@ -97,8 +97,8 @@ public class FunctionRef {
|
||||||
// the Painless$Script class can be inferred if owner is null
|
// the Painless$Script class can be inferred if owner is null
|
||||||
if (delegateMethod.owner == null) {
|
if (delegateMethod.owner == null) {
|
||||||
delegateClassName = CLASS_NAME;
|
delegateClassName = CLASS_NAME;
|
||||||
} else if (delegateMethod.augmentation) {
|
} else if (delegateMethod.augmentation != null) {
|
||||||
delegateClassName = Augmentation.class.getName();
|
delegateClassName = delegateMethod.augmentation.getName();
|
||||||
} else {
|
} else {
|
||||||
delegateClassName = delegateMethod.owner.clazz.getName();
|
delegateClassName = delegateMethod.owner.clazz.getName();
|
||||||
}
|
}
|
||||||
|
|
|
@ -135,7 +135,7 @@ public final class SFunction extends AStatement {
|
||||||
|
|
||||||
org.objectweb.asm.commons.Method method =
|
org.objectweb.asm.commons.Method method =
|
||||||
new org.objectweb.asm.commons.Method(name, MethodType.methodType(rtnType.clazz, paramClasses).toMethodDescriptorString());
|
new org.objectweb.asm.commons.Method(name, MethodType.methodType(rtnType.clazz, paramClasses).toMethodDescriptorString());
|
||||||
this.method = new Method(name, null, false, rtnType, paramTypes, method, Modifier.STATIC | Modifier.PRIVATE, null);
|
this.method = new Method(name, null, null, rtnType, paramTypes, method, Modifier.STATIC | Modifier.PRIVATE, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -36,8 +36,8 @@ class CharSequence -> java.lang.CharSequence {
|
||||||
IntStream chars()
|
IntStream chars()
|
||||||
IntStream codePoints()
|
IntStream codePoints()
|
||||||
int length()
|
int length()
|
||||||
String replaceAll*(Pattern,Function)
|
String org.elasticsearch.painless.api.Augmentation.replaceAll(Pattern,Function)
|
||||||
String replaceFirst*(Pattern,Function)
|
String org.elasticsearch.painless.api.Augmentation.replaceFirst(Pattern,Function)
|
||||||
CharSequence subSequence(int,int)
|
CharSequence subSequence(int,int)
|
||||||
String toString()
|
String toString()
|
||||||
}
|
}
|
||||||
|
@ -53,17 +53,17 @@ class Iterable -> java.lang.Iterable {
|
||||||
Iterator iterator()
|
Iterator iterator()
|
||||||
Spliterator spliterator()
|
Spliterator spliterator()
|
||||||
# some adaptations of groovy methods
|
# some adaptations of groovy methods
|
||||||
boolean any*(Predicate)
|
boolean org.elasticsearch.painless.api.Augmentation.any(Predicate)
|
||||||
Collection asCollection*()
|
Collection org.elasticsearch.painless.api.Augmentation.asCollection()
|
||||||
List asList*()
|
List org.elasticsearch.painless.api.Augmentation.asList()
|
||||||
def each*(Consumer)
|
def org.elasticsearch.painless.api.Augmentation.each(Consumer)
|
||||||
def eachWithIndex*(ObjIntConsumer)
|
def org.elasticsearch.painless.api.Augmentation.eachWithIndex(ObjIntConsumer)
|
||||||
boolean every*(Predicate)
|
boolean org.elasticsearch.painless.api.Augmentation.every(Predicate)
|
||||||
List findResults*(Function)
|
List org.elasticsearch.painless.api.Augmentation.findResults(Function)
|
||||||
Map groupBy*(Function)
|
Map org.elasticsearch.painless.api.Augmentation.groupBy(Function)
|
||||||
String join*(String)
|
String org.elasticsearch.painless.api.Augmentation.join(String)
|
||||||
double sum*()
|
double org.elasticsearch.painless.api.Augmentation.sum()
|
||||||
double sum*(ToDoubleFunction)
|
double org.elasticsearch.painless.api.Augmentation.sum(ToDoubleFunction)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Readable: i/o
|
# Readable: i/o
|
||||||
|
@ -756,8 +756,8 @@ class String -> java.lang.String extends CharSequence,Comparable,Object {
|
||||||
boolean contentEquals(CharSequence)
|
boolean contentEquals(CharSequence)
|
||||||
String copyValueOf(char[])
|
String copyValueOf(char[])
|
||||||
String copyValueOf(char[],int,int)
|
String copyValueOf(char[],int,int)
|
||||||
String decodeBase64*()
|
String org.elasticsearch.painless.api.Augmentation.decodeBase64()
|
||||||
String encodeBase64*()
|
String org.elasticsearch.painless.api.Augmentation.encodeBase64()
|
||||||
boolean endsWith(String)
|
boolean endsWith(String)
|
||||||
boolean equalsIgnoreCase(String)
|
boolean equalsIgnoreCase(String)
|
||||||
String format(Locale,String,def[])
|
String format(Locale,String,def[])
|
||||||
|
|
|
@ -42,7 +42,7 @@ class Matcher -> java.util.regex.Matcher extends Object {
|
||||||
boolean find(int)
|
boolean find(int)
|
||||||
String group()
|
String group()
|
||||||
String group(int)
|
String group(int)
|
||||||
String namedGroup*(String)
|
String org.elasticsearch.painless.api.Augmentation.namedGroup(String)
|
||||||
int groupCount()
|
int groupCount()
|
||||||
boolean hasAnchoringBounds()
|
boolean hasAnchoringBounds()
|
||||||
boolean hasTransparentBounds()
|
boolean hasTransparentBounds()
|
||||||
|
|
|
@ -41,13 +41,13 @@ class Collection -> java.util.Collection extends Iterable {
|
||||||
def[] toArray(def[])
|
def[] toArray(def[])
|
||||||
|
|
||||||
# some adaptations of groovy methods
|
# some adaptations of groovy methods
|
||||||
List collect*(Function)
|
List org.elasticsearch.painless.api.Augmentation.collect(Function)
|
||||||
def collect*(Collection,Function)
|
def org.elasticsearch.painless.api.Augmentation.collect(Collection,Function)
|
||||||
def find*(Predicate)
|
def org.elasticsearch.painless.api.Augmentation.find(Predicate)
|
||||||
List findAll*(Predicate)
|
List org.elasticsearch.painless.api.Augmentation.findAll(Predicate)
|
||||||
def findResult*(Function)
|
def org.elasticsearch.painless.api.Augmentation.findResult(Function)
|
||||||
def findResult*(def,Function)
|
def org.elasticsearch.painless.api.Augmentation.findResult(def,Function)
|
||||||
List split*(Predicate)
|
List org.elasticsearch.painless.api.Augmentation.split(Predicate)
|
||||||
}
|
}
|
||||||
|
|
||||||
class Comparator -> java.util.Comparator {
|
class Comparator -> java.util.Comparator {
|
||||||
|
@ -123,7 +123,7 @@ class List -> java.util.List extends Collection,Iterable {
|
||||||
def remove(int)
|
def remove(int)
|
||||||
void replaceAll(UnaryOperator)
|
void replaceAll(UnaryOperator)
|
||||||
def set(int,def)
|
def set(int,def)
|
||||||
int getLength*()
|
int org.elasticsearch.painless.api.Augmentation.getLength()
|
||||||
void sort(Comparator)
|
void sort(Comparator)
|
||||||
List subList(int,int)
|
List subList(int,int)
|
||||||
}
|
}
|
||||||
|
@ -163,17 +163,17 @@ class Map -> java.util.Map {
|
||||||
Collection values()
|
Collection values()
|
||||||
|
|
||||||
# some adaptations of groovy methods
|
# some adaptations of groovy methods
|
||||||
List collect*(BiFunction)
|
List org.elasticsearch.painless.api.Augmentation.collect(BiFunction)
|
||||||
def collect*(Collection,BiFunction)
|
def org.elasticsearch.painless.api.Augmentation.collect(Collection,BiFunction)
|
||||||
int count*(BiPredicate)
|
int org.elasticsearch.painless.api.Augmentation.count(BiPredicate)
|
||||||
def each*(BiConsumer)
|
def org.elasticsearch.painless.api.Augmentation.each(BiConsumer)
|
||||||
boolean every*(BiPredicate)
|
boolean org.elasticsearch.painless.api.Augmentation.every(BiPredicate)
|
||||||
Map.Entry find*(BiPredicate)
|
Map.Entry org.elasticsearch.painless.api.Augmentation.find(BiPredicate)
|
||||||
Map findAll*(BiPredicate)
|
Map org.elasticsearch.painless.api.Augmentation.findAll(BiPredicate)
|
||||||
def findResult*(BiFunction)
|
def org.elasticsearch.painless.api.Augmentation.findResult(BiFunction)
|
||||||
def findResult*(def,BiFunction)
|
def org.elasticsearch.painless.api.Augmentation.findResult(def,BiFunction)
|
||||||
List findResults*(BiFunction)
|
List org.elasticsearch.painless.api.Augmentation.findResults(BiFunction)
|
||||||
Map groupBy*(BiFunction)
|
Map org.elasticsearch.painless.api.Augmentation.groupBy(BiFunction)
|
||||||
}
|
}
|
||||||
|
|
||||||
class Map.Entry -> java.util.Map$Entry {
|
class Map.Entry -> java.util.Map$Entry {
|
||||||
|
|
|
@ -156,6 +156,8 @@ class org.elasticsearch.painless.FeatureTest -> org.elasticsearch.painless.Featu
|
||||||
boolean overloadedStatic(boolean)
|
boolean overloadedStatic(boolean)
|
||||||
Object twoFunctionsOfX(Function,Function)
|
Object twoFunctionsOfX(Function,Function)
|
||||||
void listInput(List)
|
void listInput(List)
|
||||||
|
int org.elasticsearch.painless.FeatureTestAugmentation.getTotal()
|
||||||
|
int org.elasticsearch.painless.FeatureTestAugmentation.addToTotal(int)
|
||||||
}
|
}
|
||||||
|
|
||||||
class org.elasticsearch.search.lookup.FieldLookup -> org.elasticsearch.search.lookup.FieldLookup extends Object {
|
class org.elasticsearch.search.lookup.FieldLookup -> org.elasticsearch.search.lookup.FieldLookup extends Object {
|
||||||
|
|
|
@ -188,4 +188,15 @@ public class AugmentationTests extends ScriptTestCase {
|
||||||
exec("Map m = new TreeMap(); m.a = -1; m.b = 1; " +
|
exec("Map m = new TreeMap(); m.a = -1; m.b = 1; " +
|
||||||
"return m.groupBy((key,value) -> value < 0 ? 'negative' : 'positive')"));
|
"return m.groupBy((key,value) -> value < 0 ? 'negative' : 'positive')"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testFeatureTest() {
|
||||||
|
assertEquals(5, exec("org.elasticsearch.painless.FeatureTest ft = new org.elasticsearch.painless.FeatureTest();" +
|
||||||
|
" ft.setX(3); ft.setY(2); return ft.getTotal()"));
|
||||||
|
assertEquals(5, exec("def ft = new org.elasticsearch.painless.FeatureTest();" +
|
||||||
|
" ft.setX(3); ft.setY(2); return ft.getTotal()"));
|
||||||
|
assertEquals(8, exec("org.elasticsearch.painless.FeatureTest ft = new org.elasticsearch.painless.FeatureTest();" +
|
||||||
|
" ft.setX(3); ft.setY(2); return ft.addToTotal(3)"));
|
||||||
|
assertEquals(8, exec("def ft = new org.elasticsearch.painless.FeatureTest();" +
|
||||||
|
" ft.setX(3); ft.setY(2); return ft.addToTotal(3)"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -164,7 +164,7 @@ public class PainlessDocGenerator {
|
||||||
emitAnchor(stream, method);
|
emitAnchor(stream, method);
|
||||||
stream.print("]]");
|
stream.print("]]");
|
||||||
|
|
||||||
if (false == method.augmentation && Modifier.isStatic(method.modifiers)) {
|
if (null == method.augmentation && Modifier.isStatic(method.modifiers)) {
|
||||||
stream.print("static ");
|
stream.print("static ");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -268,12 +268,12 @@ public class PainlessDocGenerator {
|
||||||
stream.print("link:{");
|
stream.print("link:{");
|
||||||
stream.print(root);
|
stream.print(root);
|
||||||
stream.print("-javadoc}/");
|
stream.print("-javadoc}/");
|
||||||
stream.print((method.augmentation ? Augmentation.class : method.owner.clazz).getName().replace('.', '/'));
|
stream.print((method.augmentation != null ? method.augmentation : method.owner.clazz).getName().replace('.', '/'));
|
||||||
stream.print(".html#");
|
stream.print(".html#");
|
||||||
stream.print(methodName(method));
|
stream.print(methodName(method));
|
||||||
stream.print("%2D");
|
stream.print("%2D");
|
||||||
boolean first = true;
|
boolean first = true;
|
||||||
if (method.augmentation) {
|
if (method.augmentation != null) {
|
||||||
first = false;
|
first = false;
|
||||||
stream.print(method.owner.clazz.getName());
|
stream.print(method.owner.clazz.getName());
|
||||||
}
|
}
|
||||||
|
@ -309,7 +309,7 @@ public class PainlessDocGenerator {
|
||||||
* Pick the javadoc root for a {@link Method}.
|
* Pick the javadoc root for a {@link Method}.
|
||||||
*/
|
*/
|
||||||
private static String javadocRoot(Method method) {
|
private static String javadocRoot(Method method) {
|
||||||
if (method.augmentation) {
|
if (method.augmentation != null) {
|
||||||
return "painless";
|
return "painless";
|
||||||
}
|
}
|
||||||
return javadocRoot(method.owner);
|
return javadocRoot(method.owner);
|
||||||
|
|
Loading…
Reference in New Issue