mirror of https://github.com/apache/druid.git
expression aggregator (#11104)
* add experimental expression aggregator * add test * fix lgtm * fix test * adjust test * use not null constant * array_set_concat docs * add equals and hashcode and tostring * fix it * spelling * do multi-value magic for expression agg, more javadocs, tests * formatting * fix inspection * more better * nullable
This commit is contained in:
parent
49a9c3ffb7
commit
57ff1f9cdb
|
@ -183,9 +183,9 @@ class NullLongExpr extends ConstantExpr<Long>
|
|||
|
||||
class LongArrayExpr extends ConstantExpr<Long[]>
|
||||
{
|
||||
LongArrayExpr(Long[] value)
|
||||
LongArrayExpr(@Nullable Long[] value)
|
||||
{
|
||||
super(ExprType.LONG_ARRAY, Preconditions.checkNotNull(value, "value"));
|
||||
super(ExprType.LONG_ARRAY, value);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -320,9 +320,9 @@ class NullDoubleExpr extends ConstantExpr<Double>
|
|||
|
||||
class DoubleArrayExpr extends ConstantExpr<Double[]>
|
||||
{
|
||||
DoubleArrayExpr(Double[] value)
|
||||
DoubleArrayExpr(@Nullable Double[] value)
|
||||
{
|
||||
super(ExprType.DOUBLE_ARRAY, Preconditions.checkNotNull(value, "value"));
|
||||
super(ExprType.DOUBLE_ARRAY, value);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -426,9 +426,9 @@ class StringExpr extends ConstantExpr<String>
|
|||
|
||||
class StringArrayExpr extends ConstantExpr<String[]>
|
||||
{
|
||||
StringArrayExpr(String[] value)
|
||||
StringArrayExpr(@Nullable String[] value)
|
||||
{
|
||||
super(ExprType.STRING_ARRAY, Preconditions.checkNotNull(value, "value"));
|
||||
super(ExprType.STRING_ARRAY, value);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -23,15 +23,357 @@ import com.google.common.primitives.Doubles;
|
|||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.common.guava.GuavaUtils;
|
||||
import org.apache.druid.java.util.common.IAE;
|
||||
import org.apache.druid.java.util.common.ISE;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.java.util.common.UOE;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Generic result holder for evaluated {@link Expr} containing the value and {@link ExprType} of the value to allow
|
||||
*/
|
||||
public abstract class ExprEval<T>
|
||||
{
|
||||
private static final int NULL_LENGTH = -1;
|
||||
|
||||
/**
|
||||
* Deserialize an expression stored in a bytebuffer, e.g. for an agg.
|
||||
*
|
||||
* This should be refactored to be consolidated with some of the standard type handling of aggregators probably
|
||||
*/
|
||||
public static ExprEval deserialize(ByteBuffer buffer, int position)
|
||||
{
|
||||
// | expression type (byte) | expression bytes |
|
||||
ExprType type = ExprType.fromByte(buffer.get(position));
|
||||
int offset = position + 1;
|
||||
switch (type) {
|
||||
case LONG:
|
||||
// | expression type (byte) | is null (byte) | long bytes |
|
||||
if (buffer.get(offset++) == NullHandling.IS_NOT_NULL_BYTE) {
|
||||
return of(buffer.getLong(offset));
|
||||
}
|
||||
return ofLong(null);
|
||||
case DOUBLE:
|
||||
// | expression type (byte) | is null (byte) | double bytes |
|
||||
if (buffer.get(offset++) == NullHandling.IS_NOT_NULL_BYTE) {
|
||||
return of(buffer.getDouble(offset));
|
||||
}
|
||||
return ofDouble(null);
|
||||
case STRING:
|
||||
// | expression type (byte) | string length (int) | string bytes |
|
||||
final int length = buffer.getInt(offset);
|
||||
if (length < 0) {
|
||||
return of(null);
|
||||
}
|
||||
final byte[] stringBytes = new byte[length];
|
||||
final int oldPosition = buffer.position();
|
||||
buffer.position(offset + Integer.BYTES);
|
||||
buffer.get(stringBytes, 0, length);
|
||||
buffer.position(oldPosition);
|
||||
return of(StringUtils.fromUtf8(stringBytes));
|
||||
case LONG_ARRAY:
|
||||
// | expression type (byte) | array length (int) | array bytes |
|
||||
final int longArrayLength = buffer.getInt(offset);
|
||||
offset += Integer.BYTES;
|
||||
if (longArrayLength < 0) {
|
||||
return ofLongArray(null);
|
||||
}
|
||||
final Long[] longs = new Long[longArrayLength];
|
||||
for (int i = 0; i < longArrayLength; i++) {
|
||||
final byte isNull = buffer.get(offset);
|
||||
offset += Byte.BYTES;
|
||||
if (isNull == NullHandling.IS_NOT_NULL_BYTE) {
|
||||
// | is null (byte) | long bytes |
|
||||
longs[i] = buffer.getLong(offset);
|
||||
offset += Long.BYTES;
|
||||
} else {
|
||||
// | is null (byte) |
|
||||
longs[i] = null;
|
||||
}
|
||||
}
|
||||
return ofLongArray(longs);
|
||||
case DOUBLE_ARRAY:
|
||||
// | expression type (byte) | array length (int) | array bytes |
|
||||
final int doubleArrayLength = buffer.getInt(offset);
|
||||
offset += Integer.BYTES;
|
||||
if (doubleArrayLength < 0) {
|
||||
return ofDoubleArray(null);
|
||||
}
|
||||
final Double[] doubles = new Double[doubleArrayLength];
|
||||
for (int i = 0; i < doubleArrayLength; i++) {
|
||||
final byte isNull = buffer.get(offset);
|
||||
offset += Byte.BYTES;
|
||||
if (isNull == NullHandling.IS_NOT_NULL_BYTE) {
|
||||
// | is null (byte) | double bytes |
|
||||
doubles[i] = buffer.getDouble(offset);
|
||||
offset += Double.BYTES;
|
||||
} else {
|
||||
// | is null (byte) |
|
||||
doubles[i] = null;
|
||||
}
|
||||
}
|
||||
return ofDoubleArray(doubles);
|
||||
case STRING_ARRAY:
|
||||
// | expression type (byte) | array length (int) | array bytes |
|
||||
final int stringArrayLength = buffer.getInt(offset);
|
||||
offset += Integer.BYTES;
|
||||
if (stringArrayLength < 0) {
|
||||
return ofStringArray(null);
|
||||
}
|
||||
final String[] stringArray = new String[stringArrayLength];
|
||||
for (int i = 0; i < stringArrayLength; i++) {
|
||||
final int stringElementLength = buffer.getInt(offset);
|
||||
offset += Integer.BYTES;
|
||||
if (stringElementLength < 0) {
|
||||
// | string length (int) |
|
||||
stringArray[i] = null;
|
||||
} else {
|
||||
// | string length (int) | string bytes |
|
||||
final byte[] stringElementBytes = new byte[stringElementLength];
|
||||
final int oldPosition2 = buffer.position();
|
||||
buffer.position(offset);
|
||||
buffer.get(stringElementBytes, 0, stringElementLength);
|
||||
buffer.position(oldPosition2);
|
||||
stringArray[i] = StringUtils.fromUtf8(stringElementBytes);
|
||||
offset += stringElementLength;
|
||||
}
|
||||
}
|
||||
return ofStringArray(stringArray);
|
||||
default:
|
||||
throw new UOE("how can this be?");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Write an expression result to a bytebuffer, throwing an {@link ISE} if the data exceeds a maximum size. Primitive
|
||||
* numeric types are not validated to be lower than max size, so it is expected to be at least 10 bytes. Callers
|
||||
* of this method should enforce this themselves (instead of doing it here, which might be done every row)
|
||||
*
|
||||
* This should be refactored to be consolidated with some of the standard type handling of aggregators probably
|
||||
*/
|
||||
public static void serialize(ByteBuffer buffer, int position, ExprEval<?> eval, int maxSizeBytes)
|
||||
{
|
||||
int offset = position;
|
||||
buffer.put(offset++, eval.type().getId());
|
||||
switch (eval.type()) {
|
||||
case LONG:
|
||||
if (eval.isNumericNull()) {
|
||||
buffer.put(offset, NullHandling.IS_NULL_BYTE);
|
||||
} else {
|
||||
buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
|
||||
buffer.putLong(offset, eval.asLong());
|
||||
}
|
||||
break;
|
||||
case DOUBLE:
|
||||
if (eval.isNumericNull()) {
|
||||
buffer.put(offset, NullHandling.IS_NULL_BYTE);
|
||||
} else {
|
||||
buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
|
||||
buffer.putDouble(offset, eval.asDouble());
|
||||
}
|
||||
break;
|
||||
case STRING:
|
||||
final byte[] stringBytes = StringUtils.toUtf8Nullable(eval.asString());
|
||||
if (stringBytes != null) {
|
||||
// | expression type (byte) | string length (int) | string bytes |
|
||||
checkMaxBytes(eval.type(), 1 + Integer.BYTES + stringBytes.length, maxSizeBytes);
|
||||
buffer.putInt(offset, stringBytes.length);
|
||||
offset += Integer.BYTES;
|
||||
final int oldPosition = buffer.position();
|
||||
buffer.position(offset);
|
||||
buffer.put(stringBytes, 0, stringBytes.length);
|
||||
buffer.position(oldPosition);
|
||||
} else {
|
||||
checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
|
||||
buffer.putInt(offset, NULL_LENGTH);
|
||||
}
|
||||
break;
|
||||
case LONG_ARRAY:
|
||||
Long[] longs = eval.asLongArray();
|
||||
if (longs == null) {
|
||||
// | expression type (byte) | array length (int) |
|
||||
checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
|
||||
buffer.putInt(offset, NULL_LENGTH);
|
||||
} else {
|
||||
// | expression type (byte) | array length (int) | array bytes |
|
||||
final int sizeBytes = 1 + Integer.BYTES + (Long.BYTES * longs.length);
|
||||
checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
|
||||
buffer.putInt(offset, longs.length);
|
||||
offset += Integer.BYTES;
|
||||
for (Long aLong : longs) {
|
||||
if (aLong != null) {
|
||||
buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
|
||||
offset++;
|
||||
buffer.putLong(offset, aLong);
|
||||
offset += Long.BYTES;
|
||||
} else {
|
||||
buffer.put(offset++, NullHandling.IS_NULL_BYTE);
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case DOUBLE_ARRAY:
|
||||
Double[] doubles = eval.asDoubleArray();
|
||||
if (doubles == null) {
|
||||
// | expression type (byte) | array length (int) |
|
||||
checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
|
||||
buffer.putInt(offset, NULL_LENGTH);
|
||||
} else {
|
||||
// | expression type (byte) | array length (int) | array bytes |
|
||||
final int sizeBytes = 1 + Integer.BYTES + (Double.BYTES * doubles.length);
|
||||
checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
|
||||
buffer.putInt(offset, doubles.length);
|
||||
offset += Integer.BYTES;
|
||||
|
||||
for (Double aDouble : doubles) {
|
||||
if (aDouble != null) {
|
||||
buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
|
||||
offset++;
|
||||
buffer.putDouble(offset, aDouble);
|
||||
offset += Long.BYTES;
|
||||
} else {
|
||||
buffer.put(offset++, NullHandling.IS_NULL_BYTE);
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case STRING_ARRAY:
|
||||
String[] strings = eval.asStringArray();
|
||||
if (strings == null) {
|
||||
// | expression type (byte) | array length (int) |
|
||||
checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
|
||||
buffer.putInt(offset, NULL_LENGTH);
|
||||
} else {
|
||||
// | expression type (byte) | array length (int) | array bytes |
|
||||
buffer.putInt(offset, strings.length);
|
||||
offset += Integer.BYTES;
|
||||
int sizeBytes = 1 + Integer.BYTES;
|
||||
for (String string : strings) {
|
||||
if (string == null) {
|
||||
// | string length (int) |
|
||||
sizeBytes += Integer.BYTES;
|
||||
checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
|
||||
buffer.putInt(offset, NULL_LENGTH);
|
||||
offset += Integer.BYTES;
|
||||
} else {
|
||||
// | string length (int) | string bytes |
|
||||
final byte[] stringElementBytes = StringUtils.toUtf8(string);
|
||||
sizeBytes += Integer.BYTES + stringElementBytes.length;
|
||||
checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
|
||||
buffer.putInt(offset, stringElementBytes.length);
|
||||
offset += Integer.BYTES;
|
||||
final int oldPosition = buffer.position();
|
||||
buffer.position(offset);
|
||||
buffer.put(stringElementBytes, 0, stringElementBytes.length);
|
||||
buffer.position(oldPosition);
|
||||
offset += stringElementBytes.length;
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new UOE("how can this be?");
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkMaxBytes(ExprType type, int sizeBytes, int maxSizeBytes)
|
||||
{
|
||||
if (sizeBytes > maxSizeBytes) {
|
||||
throw new ISE("Unable to serialize [%s], size [%s] is larger than max [%s]", type, sizeBytes, maxSizeBytes);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a List to an appropriate array type, optionally doing some conversion to make multi-valued strings
|
||||
* consistent across selector types, which are not consistent in treatment of null, [], and [null].
|
||||
*
|
||||
* If homogenizeMultiValueStrings is true, null and [] will be converted to [null], otherwise they will retain
|
||||
*/
|
||||
@Nullable
|
||||
public static Object coerceListToArray(@Nullable List<?> val, boolean homogenizeMultiValueStrings)
|
||||
{
|
||||
// if value is not null and has at least 1 element, conversion is unambigous regardless of the selector
|
||||
if (val != null && val.size() > 0) {
|
||||
Class<?> coercedType = null;
|
||||
|
||||
for (Object elem : val) {
|
||||
if (elem != null) {
|
||||
coercedType = convertType(coercedType, elem.getClass());
|
||||
}
|
||||
}
|
||||
|
||||
if (coercedType == Long.class || coercedType == Integer.class) {
|
||||
return val.stream().map(x -> x != null ? ((Number) x).longValue() : null).toArray(Long[]::new);
|
||||
}
|
||||
if (coercedType == Float.class || coercedType == Double.class) {
|
||||
return val.stream().map(x -> x != null ? ((Number) x).doubleValue() : null).toArray(Double[]::new);
|
||||
}
|
||||
// default to string
|
||||
return val.stream().map(x -> x != null ? x.toString() : null).toArray(String[]::new);
|
||||
}
|
||||
if (homogenizeMultiValueStrings) {
|
||||
return new String[]{null};
|
||||
} else {
|
||||
if (val != null) {
|
||||
return val.toArray();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the common type to use between 2 types, useful for choosing the appropriate type for an array given a set
|
||||
* of objects with unknown type, following rules similar to Java, our own native Expr, and SQL implicit type
|
||||
* conversions. This is used to assist in preparing native java objects for {@link Expr.ObjectBinding} which will
|
||||
* later be wrapped in {@link ExprEval} when evaluating {@link IdentifierExpr}.
|
||||
*
|
||||
* If any type is string, then the result will be string because everything can be converted to a string, but a string
|
||||
* cannot be converted to everything.
|
||||
*
|
||||
* For numbers, integer is the most restrictive type, only chosen if both types are integers. Longs win over integers,
|
||||
* floats over longs and integers, and doubles win over everything.
|
||||
*/
|
||||
private static Class convertType(@Nullable Class existing, Class next)
|
||||
{
|
||||
if (Number.class.isAssignableFrom(next) || next == String.class) {
|
||||
if (existing == null) {
|
||||
return next;
|
||||
}
|
||||
// string wins everything
|
||||
if (existing == String.class) {
|
||||
return existing;
|
||||
}
|
||||
if (next == String.class) {
|
||||
return next;
|
||||
}
|
||||
// all numbers win over Integer
|
||||
if (existing == Integer.class) {
|
||||
return next;
|
||||
}
|
||||
if (existing == Float.class) {
|
||||
// doubles win over floats
|
||||
if (next == Double.class) {
|
||||
return next;
|
||||
}
|
||||
return existing;
|
||||
}
|
||||
if (existing == Long.class) {
|
||||
if (next == Integer.class) {
|
||||
// long beats int
|
||||
return existing;
|
||||
}
|
||||
// double and float win over longs
|
||||
return next;
|
||||
}
|
||||
// otherwise double
|
||||
return Double.class;
|
||||
}
|
||||
throw new UOE("Invalid array expression type: %s", next);
|
||||
}
|
||||
|
||||
public static ExprEval ofLong(@Nullable Number longValue)
|
||||
{
|
||||
return new LongExprEval(longValue);
|
||||
|
@ -118,6 +460,13 @@ public abstract class ExprEval<T>
|
|||
return new StringArrayExprEval((String[]) val);
|
||||
}
|
||||
|
||||
if (val instanceof List) {
|
||||
// do not convert empty lists to arrays with a single null element here, because that should have been done
|
||||
// by the selectors preparing their ObjectBindings if necessary. If we get to this point it was legitimately
|
||||
// empty
|
||||
return bestEffortOf(coerceListToArray((List<?>) val, false));
|
||||
}
|
||||
|
||||
return new StringExprEval(val == null ? null : String.valueOf(val));
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
|
||||
package org.apache.druid.math.expr;
|
||||
|
||||
import it.unimi.dsi.fastutil.bytes.Byte2ObjectArrayMap;
|
||||
import it.unimi.dsi.fastutil.bytes.Byte2ObjectMap;
|
||||
import org.apache.druid.java.util.common.ISE;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
|
||||
|
@ -29,13 +31,32 @@ import javax.annotation.Nullable;
|
|||
*/
|
||||
public enum ExprType
|
||||
{
|
||||
DOUBLE,
|
||||
LONG,
|
||||
STRING,
|
||||
DOUBLE_ARRAY,
|
||||
LONG_ARRAY,
|
||||
STRING_ARRAY;
|
||||
DOUBLE((byte) 0x01),
|
||||
LONG((byte) 0x02),
|
||||
STRING((byte) 0x03),
|
||||
DOUBLE_ARRAY((byte) 0x04),
|
||||
LONG_ARRAY((byte) 0x05),
|
||||
STRING_ARRAY((byte) 0x06);
|
||||
|
||||
private static final Byte2ObjectMap<ExprType> TYPE_BYTES = new Byte2ObjectArrayMap<>(ExprType.values().length);
|
||||
|
||||
static {
|
||||
for (ExprType type : ExprType.values()) {
|
||||
TYPE_BYTES.put(type.getId(), type);
|
||||
}
|
||||
}
|
||||
|
||||
final byte id;
|
||||
|
||||
ExprType(byte id)
|
||||
{
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public byte getId()
|
||||
{
|
||||
return id;
|
||||
}
|
||||
|
||||
public boolean isNumeric()
|
||||
{
|
||||
|
@ -47,6 +68,11 @@ public enum ExprType
|
|||
return isScalar(this);
|
||||
}
|
||||
|
||||
public static ExprType fromByte(byte id)
|
||||
{
|
||||
return TYPE_BYTES.get(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* The expression system does not distinguish between {@link ValueType#FLOAT} and {@link ValueType#DOUBLE}, and
|
||||
* cannot currently handle {@link ValueType#COMPLEX} inputs. This method will convert {@link ValueType#FLOAT} to
|
||||
|
@ -177,5 +203,4 @@ public enum ExprType
|
|||
}
|
||||
return elementType;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -42,6 +42,7 @@ import java.util.ArrayList;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
@ -384,13 +385,13 @@ public interface Function
|
|||
@Override
|
||||
public Set<Expr> getScalarInputs(List<Expr> args)
|
||||
{
|
||||
return ImmutableSet.of(args.get(1));
|
||||
return ImmutableSet.of(getScalarArgument(args));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<Expr> getArrayInputs(List<Expr> args)
|
||||
{
|
||||
return ImmutableSet.of(args.get(0));
|
||||
return ImmutableSet.of(getArrayArgument(args));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -402,14 +403,24 @@ public interface Function
|
|||
@Override
|
||||
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
|
||||
{
|
||||
final ExprEval arrayExpr = args.get(0).eval(bindings);
|
||||
final ExprEval scalarExpr = args.get(1).eval(bindings);
|
||||
final ExprEval arrayExpr = getArrayArgument(args).eval(bindings);
|
||||
final ExprEval scalarExpr = getScalarArgument(args).eval(bindings);
|
||||
if (arrayExpr.asArray() == null) {
|
||||
return ExprEval.of(null);
|
||||
}
|
||||
return doApply(arrayExpr, scalarExpr);
|
||||
}
|
||||
|
||||
Expr getScalarArgument(List<Expr> args)
|
||||
{
|
||||
return args.get(1);
|
||||
}
|
||||
|
||||
Expr getArrayArgument(List<Expr> args)
|
||||
{
|
||||
return args.get(0);
|
||||
}
|
||||
|
||||
abstract ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr);
|
||||
}
|
||||
|
||||
|
@ -450,8 +461,11 @@ public interface Function
|
|||
final ExprEval arrayExpr1 = args.get(0).eval(bindings);
|
||||
final ExprEval arrayExpr2 = args.get(1).eval(bindings);
|
||||
|
||||
if (arrayExpr1.asArray() == null || arrayExpr2.asArray() == null) {
|
||||
return ExprEval.of(null);
|
||||
if (arrayExpr1.asArray() == null) {
|
||||
return arrayExpr1;
|
||||
}
|
||||
if (arrayExpr2.asArray() == null) {
|
||||
return arrayExpr2;
|
||||
}
|
||||
|
||||
return doApply(arrayExpr1, arrayExpr2);
|
||||
|
@ -460,6 +474,118 @@ public interface Function
|
|||
abstract ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr);
|
||||
}
|
||||
|
||||
/**
|
||||
* Scaffolding for a 2 argument {@link Function} which accepts one array and one scalar input and adds the scalar
|
||||
* input to the array in some way.
|
||||
*/
|
||||
abstract class ArrayAddElementFunction extends ArrayScalarFunction
|
||||
{
|
||||
@Override
|
||||
public boolean hasArrayOutput()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
|
||||
{
|
||||
ExprType arrayType = getArrayArgument(args).getOutputType(inspector);
|
||||
return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
|
||||
}
|
||||
|
||||
@Override
|
||||
ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
|
||||
{
|
||||
switch (arrayExpr.type()) {
|
||||
case STRING:
|
||||
case STRING_ARRAY:
|
||||
return ExprEval.ofStringArray(add(arrayExpr.asStringArray(), scalarExpr.asString()).toArray(String[]::new));
|
||||
case LONG:
|
||||
case LONG_ARRAY:
|
||||
return ExprEval.ofLongArray(
|
||||
add(
|
||||
arrayExpr.asLongArray(),
|
||||
scalarExpr.isNumericNull() ? null : scalarExpr.asLong()
|
||||
).toArray(Long[]::new)
|
||||
);
|
||||
case DOUBLE:
|
||||
case DOUBLE_ARRAY:
|
||||
return ExprEval.ofDoubleArray(
|
||||
add(
|
||||
arrayExpr.asDoubleArray(),
|
||||
scalarExpr.isNumericNull() ? null : scalarExpr.asDouble()
|
||||
).toArray(Double[]::new)
|
||||
);
|
||||
}
|
||||
|
||||
throw new RE("Unable to add to unknown array type %s", arrayExpr.type());
|
||||
}
|
||||
|
||||
abstract <T> Stream<T> add(T[] array, @Nullable T val);
|
||||
}
|
||||
|
||||
/**
|
||||
* Base scaffolding for functions which accept 2 array arguments and combine them in some way
|
||||
*/
|
||||
abstract class ArraysMergeFunction extends ArraysFunction
|
||||
{
|
||||
@Override
|
||||
public Set<Expr> getArrayInputs(List<Expr> args)
|
||||
{
|
||||
return ImmutableSet.copyOf(args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasArrayOutput()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
|
||||
{
|
||||
ExprType arrayType = args.get(0).getOutputType(inspector);
|
||||
return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
|
||||
}
|
||||
|
||||
@Override
|
||||
ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr)
|
||||
{
|
||||
final Object[] array1 = lhsExpr.asArray();
|
||||
final Object[] array2 = rhsExpr.asArray();
|
||||
|
||||
if (array1 == null) {
|
||||
return ExprEval.of(null);
|
||||
}
|
||||
if (array2 == null) {
|
||||
return lhsExpr;
|
||||
}
|
||||
|
||||
switch (lhsExpr.type()) {
|
||||
case STRING:
|
||||
case STRING_ARRAY:
|
||||
return ExprEval.ofStringArray(
|
||||
merge(lhsExpr.asStringArray(), rhsExpr.asStringArray()).toArray(String[]::new)
|
||||
);
|
||||
case LONG:
|
||||
case LONG_ARRAY:
|
||||
return ExprEval.ofLongArray(
|
||||
merge(lhsExpr.asLongArray(), rhsExpr.asLongArray()).toArray(Long[]::new)
|
||||
);
|
||||
case DOUBLE:
|
||||
case DOUBLE_ARRAY:
|
||||
return ExprEval.ofDoubleArray(
|
||||
merge(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray()).toArray(Double[]::new)
|
||||
);
|
||||
}
|
||||
throw new RE("Unable to concatenate to unknown type %s", lhsExpr.type());
|
||||
}
|
||||
|
||||
abstract <T> Stream<T> merge(T[] array1, T[] array2);
|
||||
}
|
||||
|
||||
abstract class ReduceFunction implements Function
|
||||
{
|
||||
private final DoubleBinaryOperator doubleReducer;
|
||||
|
@ -3168,7 +3294,7 @@ public interface Function
|
|||
}
|
||||
}
|
||||
|
||||
class ArrayAppendFunction extends ArrayScalarFunction
|
||||
class ArrayAppendFunction extends ArrayAddElementFunction
|
||||
{
|
||||
@Override
|
||||
public String name()
|
||||
|
@ -3177,48 +3303,7 @@ public interface Function
|
|||
}
|
||||
|
||||
@Override
|
||||
public boolean hasArrayOutput()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
|
||||
{
|
||||
ExprType arrayType = args.get(0).getOutputType(inspector);
|
||||
return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
|
||||
}
|
||||
|
||||
@Override
|
||||
ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
|
||||
{
|
||||
switch (arrayExpr.type()) {
|
||||
case STRING:
|
||||
case STRING_ARRAY:
|
||||
return ExprEval.ofStringArray(this.append(arrayExpr.asStringArray(), scalarExpr.asString()).toArray(String[]::new));
|
||||
case LONG:
|
||||
case LONG_ARRAY:
|
||||
return ExprEval.ofLongArray(
|
||||
this.append(
|
||||
arrayExpr.asLongArray(),
|
||||
scalarExpr.isNumericNull() ? null : scalarExpr.asLong()).toArray(Long[]::new
|
||||
)
|
||||
);
|
||||
case DOUBLE:
|
||||
case DOUBLE_ARRAY:
|
||||
return ExprEval.ofDoubleArray(
|
||||
this.append(
|
||||
arrayExpr.asDoubleArray(),
|
||||
scalarExpr.isNumericNull() ? null : scalarExpr.asDouble()).toArray(Double[]::new
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
throw new RE("Unable to append to unknown type %s", arrayExpr.type());
|
||||
}
|
||||
|
||||
private <T> Stream<T> append(T[] array, T val)
|
||||
<T> Stream<T> add(T[] array, @Nullable T val)
|
||||
{
|
||||
List<T> l = new ArrayList<>(Arrays.asList(array));
|
||||
l.add(val);
|
||||
|
@ -3226,7 +3311,36 @@ public interface Function
|
|||
}
|
||||
}
|
||||
|
||||
class ArrayConcatFunction extends ArraysFunction
|
||||
class ArrayPrependFunction extends ArrayAddElementFunction
|
||||
{
|
||||
@Override
|
||||
public String name()
|
||||
{
|
||||
return "array_prepend";
|
||||
}
|
||||
|
||||
@Override
|
||||
Expr getScalarArgument(List<Expr> args)
|
||||
{
|
||||
return args.get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
Expr getArrayArgument(List<Expr> args)
|
||||
{
|
||||
return args.get(1);
|
||||
}
|
||||
|
||||
@Override
|
||||
<T> Stream<T> add(T[] array, @Nullable T val)
|
||||
{
|
||||
List<T> l = new ArrayList<>(Arrays.asList(array));
|
||||
l.add(0, val);
|
||||
return l.stream();
|
||||
}
|
||||
}
|
||||
|
||||
class ArrayConcatFunction extends ArraysMergeFunction
|
||||
{
|
||||
@Override
|
||||
public String name()
|
||||
|
@ -3235,59 +3349,7 @@ public interface Function
|
|||
}
|
||||
|
||||
@Override
|
||||
public Set<Expr> getArrayInputs(List<Expr> args)
|
||||
{
|
||||
return ImmutableSet.copyOf(args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasArrayOutput()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
|
||||
{
|
||||
ExprType arrayType = args.get(0).getOutputType(inspector);
|
||||
return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
|
||||
}
|
||||
|
||||
@Override
|
||||
ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr)
|
||||
{
|
||||
final Object[] array1 = lhsExpr.asArray();
|
||||
final Object[] array2 = rhsExpr.asArray();
|
||||
|
||||
if (array1 == null) {
|
||||
return ExprEval.of(null);
|
||||
}
|
||||
if (array2 == null) {
|
||||
return lhsExpr;
|
||||
}
|
||||
|
||||
switch (lhsExpr.type()) {
|
||||
case STRING:
|
||||
case STRING_ARRAY:
|
||||
return ExprEval.ofStringArray(
|
||||
cat(lhsExpr.asStringArray(), rhsExpr.asStringArray()).toArray(String[]::new)
|
||||
);
|
||||
case LONG:
|
||||
case LONG_ARRAY:
|
||||
return ExprEval.ofLongArray(
|
||||
cat(lhsExpr.asLongArray(), rhsExpr.asLongArray()).toArray(Long[]::new)
|
||||
);
|
||||
case DOUBLE:
|
||||
case DOUBLE_ARRAY:
|
||||
return ExprEval.ofDoubleArray(
|
||||
cat(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray()).toArray(Double[]::new)
|
||||
);
|
||||
}
|
||||
throw new RE("Unable to concatenate to unknown type %s", lhsExpr.type());
|
||||
}
|
||||
|
||||
private <T> Stream<T> cat(T[] array1, T[] array2)
|
||||
<T> Stream<T> merge(T[] array1, T[] array2)
|
||||
{
|
||||
List<T> l = new ArrayList<>(Arrays.asList(array1));
|
||||
l.addAll(Arrays.asList(array2));
|
||||
|
@ -3295,6 +3357,40 @@ public interface Function
|
|||
}
|
||||
}
|
||||
|
||||
class ArraySetAddFunction extends ArrayAddElementFunction
|
||||
{
|
||||
@Override
|
||||
public String name()
|
||||
{
|
||||
return "array_set_add";
|
||||
}
|
||||
|
||||
@Override
|
||||
<T> Stream<T> add(T[] array, @Nullable T val)
|
||||
{
|
||||
Set<T> l = new HashSet<>(Arrays.asList(array));
|
||||
l.add(val);
|
||||
return l.stream();
|
||||
}
|
||||
}
|
||||
|
||||
class ArraySetAddAllFunction extends ArraysMergeFunction
|
||||
{
|
||||
@Override
|
||||
public String name()
|
||||
{
|
||||
return "array_set_add_all";
|
||||
}
|
||||
|
||||
@Override
|
||||
<T> Stream<T> merge(T[] array1, T[] array2)
|
||||
{
|
||||
Set<T> l = new HashSet<>(Arrays.asList(array1));
|
||||
l.addAll(Arrays.asList(array2));
|
||||
return l.stream();
|
||||
}
|
||||
}
|
||||
|
||||
class ArrayContainsFunction extends ArraysFunction
|
||||
{
|
||||
@Override
|
||||
|
@ -3438,93 +3534,4 @@ public interface Function
|
|||
throw new RE("Unable to slice to unknown type %s", expr.type());
|
||||
}
|
||||
}
|
||||
|
||||
class ArrayPrependFunction implements Function
|
||||
{
|
||||
@Override
|
||||
public String name()
|
||||
{
|
||||
return "array_prepend";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validateArguments(List<Expr> args)
|
||||
{
|
||||
if (args.size() != 2) {
|
||||
throw new IAE("Function[%s] needs 2 arguments", name());
|
||||
}
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
|
||||
{
|
||||
ExprType arrayType = args.get(1).getOutputType(inspector);
|
||||
return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<Expr> getScalarInputs(List<Expr> args)
|
||||
{
|
||||
return ImmutableSet.of(args.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<Expr> getArrayInputs(List<Expr> args)
|
||||
{
|
||||
return ImmutableSet.of(args.get(1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasArrayInputs()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasArrayOutput()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
|
||||
{
|
||||
final ExprEval scalarExpr = args.get(0).eval(bindings);
|
||||
final ExprEval arrayExpr = args.get(1).eval(bindings);
|
||||
if (arrayExpr.asArray() == null) {
|
||||
return ExprEval.of(null);
|
||||
}
|
||||
switch (arrayExpr.type()) {
|
||||
case STRING:
|
||||
case STRING_ARRAY:
|
||||
return ExprEval.ofStringArray(this.prepend(scalarExpr.asString(), arrayExpr.asStringArray()).toArray(String[]::new));
|
||||
case LONG:
|
||||
case LONG_ARRAY:
|
||||
return ExprEval.ofLongArray(
|
||||
this.prepend(
|
||||
scalarExpr.isNumericNull() ? null : scalarExpr.asLong(),
|
||||
arrayExpr.asLongArray()).toArray(Long[]::new
|
||||
)
|
||||
);
|
||||
case DOUBLE:
|
||||
case DOUBLE_ARRAY:
|
||||
return ExprEval.ofDoubleArray(
|
||||
this.prepend(
|
||||
scalarExpr.isNumericNull() ? null : scalarExpr.asDouble(),
|
||||
arrayExpr.asDoubleArray()).toArray(Double[]::new
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
throw new RE("Unable to prepend to unknown type %s", arrayExpr.type());
|
||||
}
|
||||
|
||||
private <T> Stream<T> prepend(T val, T[] array)
|
||||
{
|
||||
List<T> l = new ArrayList<>(Arrays.asList(array));
|
||||
l.add(0, val);
|
||||
return l.stream();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ package org.apache.druid.math.expr;
|
|||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Supplier;
|
||||
import com.google.common.base.Suppliers;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Sets;
|
||||
|
@ -35,12 +36,14 @@ import org.apache.druid.java.util.common.logger.Logger;
|
|||
import org.apache.druid.math.expr.antlr.ExprLexer;
|
||||
import org.apache.druid.math.expr.antlr.ExprParser;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.UnaryOperator;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class Parser
|
||||
|
@ -96,17 +99,33 @@ public class Parser
|
|||
}
|
||||
|
||||
/**
|
||||
* Parse a string into a flattened {@link Expr}. There is some overhead to this, and these objects are all immutable,
|
||||
* so re-use instead of re-creating whenever possible.
|
||||
* Create a memoized lazy supplier to parse a string into a flattened {@link Expr}. There is some overhead to this,
|
||||
* and these objects are all immutable, so this assists in the goal of re-using instead of re-creating whenever
|
||||
* possible.
|
||||
*
|
||||
* Lazy form of {@link #parse(String, ExprMacroTable)}
|
||||
*
|
||||
* @param in expression to parse
|
||||
* @param macroTable additional extensions to expression language
|
||||
*/
|
||||
public static Supplier<Expr> lazyParse(@Nullable String in, ExprMacroTable macroTable)
|
||||
{
|
||||
return Suppliers.memoize(() -> in == null ? null : Parser.parse(in, macroTable));
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a string into a flattened {@link Expr}. There is some overhead to this, and these objects are all immutable,
|
||||
* so re-use instead of re-creating whenever possible.
|
||||
*
|
||||
* @param in expression to parse
|
||||
* @param macroTable additional extensions to expression language
|
||||
* @return
|
||||
*/
|
||||
public static Expr parse(String in, ExprMacroTable macroTable)
|
||||
{
|
||||
return parse(in, macroTable, true);
|
||||
}
|
||||
|
||||
|
||||
@VisibleForTesting
|
||||
public static Expr parse(String in, ExprMacroTable macroTable, boolean withFlatten)
|
||||
{
|
||||
|
@ -164,47 +183,43 @@ public class Parser
|
|||
/**
|
||||
* Applies a transformation to an {@link Expr} given a list of known (or uknown) multi-value input columns that are
|
||||
* used in a scalar manner, walking the {@link Expr} tree and lifting array variables into the {@link LambdaExpr} of
|
||||
* {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr}
|
||||
* @param expr expression to visit and rewrite
|
||||
* @param bindingsToApply
|
||||
* @return
|
||||
* {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr} as necessary.
|
||||
*
|
||||
* This function applies a transformation for "map" style uses, such as column selectors, where the supplied
|
||||
* expression will be transformed to return an array of results instead of the scalar result (or appropriately
|
||||
* rewritten into existing apply expressions to produce correct results when referenced from a scalar context).
|
||||
*
|
||||
* This function and {@link #foldUnappliedBindings(Expr, Expr.BindingAnalysis, List, String)} exist to handle
|
||||
* "multi-valued" string dimensions, which exist in a superposition of both single and multi-valued during realtime
|
||||
* ingestion, until they are written to a segment and become locked into either single or multi-valued. This also
|
||||
* means that multi-valued-ness can vary for a column from segment to segment, so this family of transformation
|
||||
* functions exist so that multi-valued strings can be expressed in either and array or scalar context, which is
|
||||
* important because the writer of the query might not actually know if the column is definitively always single or
|
||||
* multi-valued (and it might in fact not be).
|
||||
*
|
||||
* @see #foldUnappliedBindings(Expr, Expr.BindingAnalysis, List, String)
|
||||
*/
|
||||
public static Expr applyUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List<String> bindingsToApply)
|
||||
public static Expr applyUnappliedBindings(
|
||||
Expr expr,
|
||||
Expr.BindingAnalysis bindingAnalysis,
|
||||
List<String> bindingsToApply
|
||||
)
|
||||
{
|
||||
if (bindingsToApply.isEmpty()) {
|
||||
// nothing to do, expression is fine as is
|
||||
return expr;
|
||||
}
|
||||
// filter the list of bindings to those which are used in this expression
|
||||
List<String> unappliedBindingsInExpression = bindingsToApply.stream()
|
||||
.filter(x -> bindingAnalysis.getRequiredBindings().contains(x))
|
||||
.collect(Collectors.toList());
|
||||
List<String> unappliedBindingsInExpression =
|
||||
bindingsToApply.stream()
|
||||
.filter(x -> bindingAnalysis.getRequiredBindings().contains(x))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten
|
||||
Expr newExpr = expr.visit(
|
||||
childExpr -> {
|
||||
if (childExpr instanceof ApplyFunctionExpr) {
|
||||
// try to lift unapplied arguments into the apply function lambda
|
||||
return liftApplyLambda((ApplyFunctionExpr) childExpr, unappliedBindingsInExpression);
|
||||
} else if (childExpr instanceof FunctionExpr) {
|
||||
// check array function arguments for unapplied identifiers to transform if necessary
|
||||
FunctionExpr fnExpr = (FunctionExpr) childExpr;
|
||||
Set<Expr> arrayInputs = fnExpr.function.getArrayInputs(fnExpr.args);
|
||||
List<Expr> newArgs = new ArrayList<>();
|
||||
for (Expr arg : fnExpr.args) {
|
||||
if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) {
|
||||
Expr newArg = applyUnappliedBindings(arg, bindingAnalysis, unappliedBindingsInExpression);
|
||||
newArgs.add(newArg);
|
||||
} else {
|
||||
newArgs.add(arg);
|
||||
}
|
||||
}
|
||||
|
||||
FunctionExpr newFnExpr = new FunctionExpr(fnExpr.function, fnExpr.function.name(), newArgs);
|
||||
return newFnExpr;
|
||||
}
|
||||
return childExpr;
|
||||
}
|
||||
Expr newExpr = rewriteUnappliedSubExpressions(
|
||||
expr,
|
||||
unappliedBindingsInExpression,
|
||||
(arg) -> applyUnappliedBindings(arg, bindingAnalysis, bindingsToApply)
|
||||
);
|
||||
|
||||
Expr.BindingAnalysis newExprBindings = newExpr.analyzeInputs();
|
||||
|
@ -221,9 +236,123 @@ public class Parser
|
|||
return applyUnapplied(newExpr, remainingUnappliedBindings);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Applies a transformation to an {@link Expr} given a list of known (or uknown) multi-value input columns that are
|
||||
* used in a scalar manner, walking the {@link Expr} tree and lifting array variables into the {@link LambdaExpr} of
|
||||
* {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr} as necessary.
|
||||
*
|
||||
* This function applies a transformation for "fold" style uses, such as aggregators, where the supplied
|
||||
* expression will be transformed to accumulate the result of applying the expression to each value of the unapplied
|
||||
* input (or appropriately rewritten into existing apply expressions to produce correct results when referenced from
|
||||
* a scalar context). This rewriting assumes that there exists some accumulator variable, which is re-used as the
|
||||
* accumulator for this fold rewrite, so that evaluating each expression can be accumulated into the larger external
|
||||
* fold operation that an aggregator might be performing.
|
||||
*
|
||||
* This function and {@link #applyUnappliedBindings(Expr, Expr.BindingAnalysis, List)} exist to handle
|
||||
* "multi-valued" string dimensions, which exist in a superposition of both single and multi-valued during realtime
|
||||
* ingestion, until they are written to a segment and become locked into either single or multi-valued. This also
|
||||
* means that multi-valued-ness can vary for a column from segment to segment, so this family of transformation
|
||||
* functions exist so that multi-valued strings can be expressed in either and array or scalar context, which is
|
||||
* important because the writer of the query might not actually know if the column is definitively always single or
|
||||
* multi-valued (and it might in fact not be).
|
||||
*
|
||||
* @see #applyUnappliedBindings(Expr, Expr.BindingAnalysis, List)
|
||||
*/
|
||||
public static Expr foldUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List<String> bindingsToApply, String accumulatorId)
|
||||
{
|
||||
if (bindingsToApply.isEmpty()) {
|
||||
// nothing to do, expression is fine as is
|
||||
return expr;
|
||||
}
|
||||
|
||||
// filter the list of bindings to those which are used in this expression
|
||||
List<String> unappliedBindingsInExpression =
|
||||
bindingsToApply.stream()
|
||||
.filter(x -> bindingAnalysis.getRequiredBindings().contains(x))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Expr newExpr = rewriteUnappliedSubExpressions(
|
||||
expr,
|
||||
unappliedBindingsInExpression,
|
||||
(arg) -> foldUnappliedBindings(arg, bindingAnalysis, bindingsToApply, accumulatorId)
|
||||
);
|
||||
|
||||
Expr.BindingAnalysis newExprBindings = newExpr.analyzeInputs();
|
||||
final Set<String> expectedArrays = newExprBindings.getArrayVariables();
|
||||
|
||||
List<String> remainingUnappliedBindings =
|
||||
unappliedBindingsInExpression.stream().filter(x -> !expectedArrays.contains(x)).collect(Collectors.toList());
|
||||
|
||||
// if lifting the lambdas got rid of all missing bindings, return the transformed expression
|
||||
if (remainingUnappliedBindings.isEmpty()) {
|
||||
return newExpr;
|
||||
}
|
||||
|
||||
return foldUnapplied(newExpr, remainingUnappliedBindings, accumulatorId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten to "lift"
|
||||
* the identifier variables and transform the function.
|
||||
*
|
||||
* For example:
|
||||
* if "y" is unapplied:
|
||||
* map((x) -> x + y, x) => cartesian_map((x,y) -> x + y, x, y)
|
||||
*
|
||||
* @see #liftApplyLambda(ApplyFunctionExpr, List)
|
||||
*
|
||||
* Array functions on expressions using unapplied identifiers might also need transformed, so we recursively call the
|
||||
* unapplied binding transformation function (supplied to this method) on that expression to ensure proper
|
||||
* transformation and rewrite of these array expressions.
|
||||
*
|
||||
* For example:
|
||||
* if "y" is unapplied:
|
||||
* array_length(filter((x) -> x > y, x))
|
||||
*/
|
||||
private static Expr rewriteUnappliedSubExpressions(
|
||||
Expr expr,
|
||||
List<String> unappliedBindingsInExpression,
|
||||
UnaryOperator<Expr> applyUnappliedFn
|
||||
)
|
||||
{
|
||||
// any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten
|
||||
return expr.visit(
|
||||
childExpr -> {
|
||||
if (childExpr instanceof ApplyFunctionExpr) {
|
||||
// try to lift unapplied arguments into the apply function lambda
|
||||
return liftApplyLambda((ApplyFunctionExpr) childExpr, unappliedBindingsInExpression);
|
||||
} else if (childExpr instanceof FunctionExpr) {
|
||||
// check array function arguments for unapplied identifiers to transform if necessary
|
||||
FunctionExpr fnExpr = (FunctionExpr) childExpr;
|
||||
Set<Expr> arrayInputs = fnExpr.function.getArrayInputs(fnExpr.args);
|
||||
List<Expr> newArgs = new ArrayList<>();
|
||||
for (Expr arg : fnExpr.args) {
|
||||
if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) {
|
||||
Expr newArg = applyUnappliedFn.apply(arg);
|
||||
newArgs.add(newArg);
|
||||
} else {
|
||||
newArgs.add(arg);
|
||||
}
|
||||
}
|
||||
|
||||
FunctionExpr newFnExpr = new FunctionExpr(fnExpr.function, fnExpr.function.name(), newArgs);
|
||||
return newFnExpr;
|
||||
}
|
||||
return childExpr;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* translate an {@link Expr} into an {@link ApplyFunctionExpr} for {@link ApplyFunction.MapFunction} or
|
||||
* {@link ApplyFunction.CartesianMapFunction} if there are multiple unbound arguments to be applied
|
||||
* {@link ApplyFunction.CartesianMapFunction} if there are multiple unbound arguments to be applied.
|
||||
*
|
||||
* For example:
|
||||
* if "x" is unapplied:
|
||||
* x + y => map((x) -> x + y, x)
|
||||
* if "x" and "y" are unapplied:
|
||||
* x + y => cartesian_map((x, y) -> x + y, x, y)
|
||||
*/
|
||||
private static Expr applyUnapplied(Expr expr, List<String> unappliedBindings)
|
||||
{
|
||||
|
@ -274,6 +403,72 @@ public class Parser
|
|||
return magic;
|
||||
}
|
||||
|
||||
/**
|
||||
* translate an {@link Expr} into an {@link ApplyFunctionExpr} for {@link ApplyFunction.FoldFunction} or
|
||||
* {@link ApplyFunction.CartesianFoldFunction} if there are multiple unbound arguments to be applied.
|
||||
*
|
||||
* This assumes a known {@link IdentifierExpr} is an "accumulator", which is re-used as the accumulator variable and
|
||||
* input for the translated fold.
|
||||
*
|
||||
* For example given an accumulator "__acc":
|
||||
* if "x" is unapplied:
|
||||
* __acc + x => fold((x, __acc) -> x + __acc, x, __acc)
|
||||
* if "x" and "y" are unapplied:
|
||||
* __acc + x + y => cartesian_fold((x, y, __acc) -> __acc + x + y, x, y, __acc)
|
||||
*
|
||||
*/
|
||||
private static Expr foldUnapplied(Expr expr, List<String> unappliedBindings, String accumulatorId)
|
||||
{
|
||||
|
||||
// filter to get list of IdentifierExpr that are backed by the unapplied bindings
|
||||
final List<IdentifierExpr> args = expr.analyzeInputs()
|
||||
.getFreeVariables()
|
||||
.stream()
|
||||
.filter(x -> unappliedBindings.contains(x.getBinding()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
final List<IdentifierExpr> lambdaArgs = new ArrayList<>();
|
||||
|
||||
// construct lambda args from list of args to apply. Identifiers in a lambda body have artificial 'binding' values
|
||||
// that is the same as the 'identifier', because the bindings are supplied by the wrapping apply function
|
||||
// replacements are done by binding rather than identifier because repeats of the same input should not result
|
||||
// in a cartesian product
|
||||
final Map<String, IdentifierExpr> toReplace = new HashMap<>();
|
||||
for (IdentifierExpr applyFnArg : args) {
|
||||
if (!toReplace.containsKey(applyFnArg.getBinding())) {
|
||||
IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getBinding());
|
||||
lambdaArgs.add(lambdaRewrite);
|
||||
toReplace.put(applyFnArg.getBinding(), lambdaRewrite);
|
||||
}
|
||||
}
|
||||
|
||||
lambdaArgs.add(new IdentifierExpr(accumulatorId));
|
||||
|
||||
// rewrite identifiers in the expression which will become the lambda body, so they match the lambda identifiers we
|
||||
// are constructing
|
||||
Expr newExpr = expr.visit(childExpr -> {
|
||||
if (childExpr instanceof IdentifierExpr) {
|
||||
if (toReplace.containsKey(((IdentifierExpr) childExpr).getBinding())) {
|
||||
return toReplace.get(((IdentifierExpr) childExpr).getBinding());
|
||||
}
|
||||
}
|
||||
return childExpr;
|
||||
});
|
||||
|
||||
|
||||
// wrap an expression in either fold or cartesian_fold to apply any unapplied identifiers
|
||||
final LambdaExpr lambdaExpr = new LambdaExpr(lambdaArgs, newExpr);
|
||||
final ApplyFunction fn;
|
||||
if (lambdaArgs.size() == 2) {
|
||||
fn = new ApplyFunction.FoldFunction();
|
||||
} else {
|
||||
fn = new ApplyFunction.CartesianFoldFunction();
|
||||
}
|
||||
|
||||
final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(lambdaArgs));
|
||||
return magic;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs partial lifting of free identifiers of the lambda expression of an {@link ApplyFunctionExpr}, constrained
|
||||
* by a list of "unapplied" identifiers, and translating them into arguments of a new {@link LambdaExpr} and
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.math.expr;
|
||||
|
||||
import com.google.common.collect.Maps;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Simple map backed object binding
|
||||
*/
|
||||
public class SettableObjectBinding implements Expr.ObjectBinding
|
||||
{
|
||||
private final Map<String, Object> bindings;
|
||||
|
||||
public SettableObjectBinding()
|
||||
{
|
||||
this.bindings = new HashMap<>();
|
||||
}
|
||||
|
||||
public SettableObjectBinding(int expectedSize)
|
||||
{
|
||||
this.bindings = Maps.newHashMapWithExpectedSize(expectedSize);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Object get(String name)
|
||||
{
|
||||
return bindings.get(name);
|
||||
}
|
||||
|
||||
public SettableObjectBinding withBinding(String name, @Nullable Object value)
|
||||
{
|
||||
bindings.put(name, value);
|
||||
return this;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,220 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.math.expr;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.apache.druid.java.util.common.ISE;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.testing.InitializedNullHandlingTest;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class ExprEvalTest extends InitializedNullHandlingTest
|
||||
{
|
||||
private static int MAX_SIZE_BYTES = 1 << 13;
|
||||
|
||||
@Rule
|
||||
public ExpectedException expectedException = ExpectedException.none();
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(1 << 16);
|
||||
|
||||
@Test
|
||||
public void testStringSerde()
|
||||
{
|
||||
assertExpr(0, "hello");
|
||||
assertExpr(1234, "hello");
|
||||
assertExpr(0, ExprEval.bestEffortOf(null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringSerdeTooBig()
|
||||
{
|
||||
expectedException.expect(ISE.class);
|
||||
expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.STRING, 16, 10));
|
||||
assertExpr(0, ExprEval.of("hello world"), 10);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testLongSerde()
|
||||
{
|
||||
assertExpr(0, 1L);
|
||||
assertExpr(1234, 1L);
|
||||
assertExpr(1234, ExprEval.ofLong(null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleSerde()
|
||||
{
|
||||
assertExpr(0, 1.123);
|
||||
assertExpr(1234, 1.123);
|
||||
assertExpr(1234, ExprEval.ofDouble(null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringArraySerde()
|
||||
{
|
||||
assertExpr(0, new String[] {"hello", "hi", "hey"});
|
||||
assertExpr(1024, new String[] {"hello", null, "hi", "hey"});
|
||||
assertExpr(2048, new String[] {});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringArraySerdeToBig()
|
||||
{
|
||||
expectedException.expect(ISE.class);
|
||||
expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.STRING_ARRAY, 14, 10));
|
||||
assertExpr(0, ExprEval.ofStringArray(new String[] {"hello", "hi", "hey"}), 10);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLongArraySerde()
|
||||
{
|
||||
assertExpr(0, new Long[] {1L, 2L, 3L});
|
||||
assertExpr(1234, new Long[] {1L, 2L, null, 3L});
|
||||
assertExpr(1234, new Long[] {});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLongArraySerdeTooBig()
|
||||
{
|
||||
expectedException.expect(ISE.class);
|
||||
expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.LONG_ARRAY, 29, 10));
|
||||
assertExpr(0, ExprEval.ofLongArray(new Long[] {1L, 2L, 3L}), 10);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleArraySerde()
|
||||
{
|
||||
assertExpr(0, new Double[] {1.1, 2.2, 3.3});
|
||||
assertExpr(1234, new Double[] {1.1, 2.2, null, 3.3});
|
||||
assertExpr(1234, new Double[] {});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleArraySerdeTooBig()
|
||||
{
|
||||
expectedException.expect(ISE.class);
|
||||
expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.DOUBLE_ARRAY, 29, 10));
|
||||
assertExpr(0, ExprEval.ofDoubleArray(new Double[] {1.1, 2.2, 3.3}), 10);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_coerceListToArray()
|
||||
{
|
||||
Assert.assertNull(ExprEval.coerceListToArray(null, false));
|
||||
Assert.assertArrayEquals(new Object[0], (Object[]) ExprEval.coerceListToArray(ImmutableList.of(), false));
|
||||
Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(null, true));
|
||||
Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(ImmutableList.of(), true));
|
||||
|
||||
List<Long> longList = ImmutableList.of(1L, 2L, 3L);
|
||||
Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(longList, false));
|
||||
|
||||
List<Integer> intList = ImmutableList.of(1, 2, 3);
|
||||
Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(intList, false));
|
||||
|
||||
List<Float> floatList = ImmutableList.of(1.0f, 2.0f, 3.0f);
|
||||
Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(floatList, false));
|
||||
|
||||
List<Double> doubleList = ImmutableList.of(1.0, 2.0, 3.0);
|
||||
Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(doubleList, false));
|
||||
|
||||
List<String> stringList = ImmutableList.of("a", "b", "c");
|
||||
Assert.assertArrayEquals(new String[]{"a", "b", "c"}, (String[]) ExprEval.coerceListToArray(stringList, false));
|
||||
|
||||
List<String> withNulls = new ArrayList<>();
|
||||
withNulls.add("a");
|
||||
withNulls.add(null);
|
||||
withNulls.add("c");
|
||||
Assert.assertArrayEquals(new String[]{"a", null, "c"}, (String[]) ExprEval.coerceListToArray(withNulls, false));
|
||||
|
||||
List<Long> withNumberNulls = new ArrayList<>();
|
||||
withNumberNulls.add(1L);
|
||||
withNumberNulls.add(null);
|
||||
withNumberNulls.add(3L);
|
||||
|
||||
Assert.assertArrayEquals(new Long[]{1L, null, 3L}, (Long[]) ExprEval.coerceListToArray(withNumberNulls, false));
|
||||
|
||||
List<Object> withStringMix = ImmutableList.of(1L, "b", 3L);
|
||||
Assert.assertArrayEquals(
|
||||
new String[]{"1", "b", "3"},
|
||||
(String[]) ExprEval.coerceListToArray(withStringMix, false)
|
||||
);
|
||||
|
||||
List<Number> withIntsAndLongs = ImmutableList.of(1, 2L, 3);
|
||||
Assert.assertArrayEquals(
|
||||
new Long[]{1L, 2L, 3L},
|
||||
(Long[]) ExprEval.coerceListToArray(withIntsAndLongs, false)
|
||||
);
|
||||
|
||||
List<Number> withFloatsAndLongs = ImmutableList.of(1, 2L, 3.0f);
|
||||
Assert.assertArrayEquals(
|
||||
new Double[]{1.0, 2.0, 3.0},
|
||||
(Double[]) ExprEval.coerceListToArray(withFloatsAndLongs, false)
|
||||
);
|
||||
|
||||
List<Number> withDoublesAndLongs = ImmutableList.of(1, 2L, 3.0);
|
||||
Assert.assertArrayEquals(
|
||||
new Double[]{1.0, 2.0, 3.0},
|
||||
(Double[]) ExprEval.coerceListToArray(withDoublesAndLongs, false)
|
||||
);
|
||||
|
||||
List<Number> withFloatsAndDoubles = ImmutableList.of(1L, 2.0f, 3.0);
|
||||
Assert.assertArrayEquals(
|
||||
new Double[]{1.0, 2.0, 3.0},
|
||||
(Double[]) ExprEval.coerceListToArray(withFloatsAndDoubles, false)
|
||||
);
|
||||
|
||||
List<String> withAllNulls = new ArrayList<>();
|
||||
withAllNulls.add(null);
|
||||
withAllNulls.add(null);
|
||||
withAllNulls.add(null);
|
||||
Assert.assertArrayEquals(
|
||||
new String[]{null, null, null},
|
||||
(String[]) ExprEval.coerceListToArray(withAllNulls, false)
|
||||
);
|
||||
}
|
||||
|
||||
private void assertExpr(int position, Object expected)
|
||||
{
|
||||
assertExpr(position, ExprEval.bestEffortOf(expected));
|
||||
}
|
||||
|
||||
private void assertExpr(int position, ExprEval expected)
|
||||
{
|
||||
assertExpr(position, expected, MAX_SIZE_BYTES);
|
||||
}
|
||||
|
||||
private void assertExpr(int position, ExprEval expected, int maxSizeBytes)
|
||||
{
|
||||
ExprEval.serialize(buffer, position, expected, maxSizeBytes);
|
||||
if (ExprType.isArray(expected.type())) {
|
||||
Assert.assertArrayEquals(expected.asArray(), ExprEval.deserialize(buffer, position).asArray());
|
||||
} else {
|
||||
Assert.assertEquals(expected.value(), ExprEval.deserialize(buffer, position).value());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -290,6 +290,27 @@ public class FunctionTest extends InitializedNullHandlingTest
|
|||
assertArrayExpr("array_concat(0, 1)", new Long[]{0L, 1L});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testArraySetAdd()
|
||||
{
|
||||
assertArrayExpr("array_set_add([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
|
||||
assertArrayExpr("array_set_add([1, 2, 3], 'bar')", new Long[]{null, 1L, 2L, 3L});
|
||||
assertArrayExpr("array_set_add([1, 2, 2], 1)", new Long[]{1L, 2L});
|
||||
assertArrayExpr("array_set_add([], 1)", new String[]{"1"});
|
||||
assertArrayExpr("array_set_add(<LONG>[], 1)", new Long[]{1L});
|
||||
assertArrayExpr("array_set_add(<LONG>[], null)", new Long[]{null});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testArraySetAddAll()
|
||||
{
|
||||
assertArrayExpr("array_set_add_all([1, 2, 3], [2, 4, 6])", new Long[]{1L, 2L, 3L, 4L, 6L});
|
||||
assertArrayExpr("array_set_add_all([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
|
||||
assertArrayExpr("array_set_add_all(0, [1, 2, 3])", new Long[]{0L, 1L, 2L, 3L});
|
||||
assertArrayExpr("array_set_add_all(map(y -> y * 3, b), [1, 2, 3])", new Long[]{1L, 2L, 3L, 6L, 9L, 12L, 15L});
|
||||
assertArrayExpr("array_set_add_all(0, 1)", new Long[]{0L, 1L});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testArrayToString()
|
||||
{
|
||||
|
|
|
@ -528,6 +528,48 @@ public class ParserTest extends InitializedNullHandlingTest
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFoldUnapplied()
|
||||
{
|
||||
validateFoldUnapplied("x + __acc", "(+ x __acc)", "(+ x __acc)", ImmutableList.of(), "__acc");
|
||||
validateFoldUnapplied("x + __acc", "(+ x __acc)", "(+ x __acc)", ImmutableList.of("z"), "__acc");
|
||||
validateFoldUnapplied(
|
||||
"x + __acc",
|
||||
"(+ x __acc)",
|
||||
"(fold ([x, __acc] -> (+ x __acc)), [x, __acc])",
|
||||
ImmutableList.of("x"),
|
||||
"__acc"
|
||||
);
|
||||
validateFoldUnapplied(
|
||||
"x + y + __acc",
|
||||
"(+ (+ x y) __acc)",
|
||||
"(cartesian_fold ([x, y, __acc] -> (+ (+ x y) __acc)), [x, y, __acc])",
|
||||
ImmutableList.of("x", "y"),
|
||||
"__acc"
|
||||
);
|
||||
validateFoldUnapplied(
|
||||
"__acc + z + fold((x, acc) -> acc + x + y, x, 0)",
|
||||
"(+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))",
|
||||
"(fold ([z, __acc] -> (+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))), [z, __acc])",
|
||||
ImmutableList.of("z"),
|
||||
"__acc"
|
||||
);
|
||||
validateFoldUnapplied(
|
||||
"__acc + z + fold((x, acc) -> acc + x + y, x, 0)",
|
||||
"(+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))",
|
||||
"(fold ([z, __acc] -> (+ (+ __acc z) (cartesian_fold ([x, y, acc] -> (+ (+ acc x) y)), [x, y, 0]))), [z, __acc])",
|
||||
ImmutableList.of("y", "z"),
|
||||
"__acc"
|
||||
);
|
||||
validateFoldUnapplied(
|
||||
"__acc + fold((x, acc) -> x + y + acc, x, __acc)",
|
||||
"(+ __acc (fold ([x, acc] -> (+ (+ x y) acc)), [x, __acc]))",
|
||||
"(+ __acc (cartesian_fold ([x, y, acc] -> (+ (+ x y) acc)), [x, y, __acc]))",
|
||||
ImmutableList.of("y"),
|
||||
"__acc"
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUniquify()
|
||||
{
|
||||
|
@ -666,6 +708,33 @@ public class ParserTest extends InitializedNullHandlingTest
|
|||
Assert.assertEquals(transformed.stringify(), transformedRoundTrip.stringify());
|
||||
}
|
||||
|
||||
private void validateFoldUnapplied(
|
||||
String expression,
|
||||
String unapplied,
|
||||
String applied,
|
||||
List<String> identifiers,
|
||||
String accumulator
|
||||
)
|
||||
{
|
||||
final Expr parsed = Parser.parse(expression, ExprMacroTable.nil());
|
||||
Expr.BindingAnalysis deets = parsed.analyzeInputs();
|
||||
Parser.validateExpr(parsed, deets);
|
||||
final Expr transformed = Parser.foldUnappliedBindings(parsed, deets, identifiers, accumulator);
|
||||
Assert.assertEquals(expression, unapplied, parsed.toString());
|
||||
Assert.assertEquals(applied, applied, transformed.toString());
|
||||
|
||||
final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false);
|
||||
final Expr parsedRoundTrip = Parser.parse(parsedNoFlatten.stringify(), ExprMacroTable.nil());
|
||||
Expr.BindingAnalysis roundTripDeets = parsedRoundTrip.analyzeInputs();
|
||||
Parser.validateExpr(parsedRoundTrip, roundTripDeets);
|
||||
final Expr transformedRoundTrip = Parser.foldUnappliedBindings(parsedRoundTrip, roundTripDeets, identifiers, accumulator);
|
||||
Assert.assertEquals(expression, unapplied, parsedRoundTrip.toString());
|
||||
Assert.assertEquals(applied, applied, transformedRoundTrip.toString());
|
||||
|
||||
Assert.assertEquals(parsed.stringify(), parsedRoundTrip.stringify());
|
||||
Assert.assertEquals(transformed.stringify(), transformedRoundTrip.stringify());
|
||||
}
|
||||
|
||||
private void validateConstantExpression(String expression, Object expected)
|
||||
{
|
||||
Expr parsed = Parser.parse(expression, ExprMacroTable.nil());
|
||||
|
|
|
@ -415,29 +415,6 @@ public class VectorExprSanityTest extends InitializedNullHandlingTest
|
|||
.toArray(String[][]::new);
|
||||
}
|
||||
|
||||
static class SettableObjectBinding implements Expr.ObjectBinding
|
||||
{
|
||||
private final Map<String, Object> bindings;
|
||||
|
||||
SettableObjectBinding()
|
||||
{
|
||||
this.bindings = new HashMap<>();
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Object get(String name)
|
||||
{
|
||||
return bindings.get(name);
|
||||
}
|
||||
|
||||
public SettableObjectBinding withBinding(String name, @Nullable Object value)
|
||||
{
|
||||
bindings.put(name, value);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
static class SettableVectorInputBinding implements Expr.VectorInputBinding
|
||||
{
|
||||
private final Map<String, boolean[]> nulls;
|
||||
|
|
|
@ -177,14 +177,17 @@ See javadoc of java.lang.Math for detailed explanation for each function.
|
|||
| array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `-1` or `null` if `druid.generic.useDefaultValueForNull=false`if no matching elements exist in the array. |
|
||||
| array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `-1` or `null` if `druid.generic.useDefaultValueForNull=false` if no matching elements exist in the array. |
|
||||
| array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array |
|
||||
| array_append(arr1,expr) | appends expr to arr, the resulting array type determined by the type of the first array |
|
||||
| array_append(arr,expr) | appends expr to arr, the resulting array type determined by the type of the first array |
|
||||
| array_concat(arr1,arr2) | concatenates 2 arrays, the resulting array type determined by the type of the first array |
|
||||
| array_set_add(arr,expr) | adds expr to arr and converts the array to a new array composed of the unique set of elements. The resulting array type determined by the type of the array |
|
||||
| array_set_add_all(arr1,arr2) | combines the unique set of elements of 2 arrays, the resulting array type determined by the type of the first array |
|
||||
| array_slice(arr,start,end) | return the subarray of arr from the 0 based index start(inclusive) to end(exclusive), or `null`, if start is less than 0, greater than length of arr or less than end|
|
||||
| array_to_string(arr,str) | joins all elements of arr by the delimiter specified by str |
|
||||
| string_to_array(str1,str2) | splits str1 into an array on the delimiter specified by str2 |
|
||||
|
||||
|
||||
## Apply functions
|
||||
Apply functions allow for special 'lambda' expressions to be defined and applied to array inputs to enable free-form transformations.
|
||||
|
||||
| function | description |
|
||||
| --- | --- |
|
||||
|
@ -197,6 +200,26 @@ See javadoc of java.lang.Math for detailed explanation for each function.
|
|||
| all(lambda,arr) | returns 1 if all elements in the array matches the lambda expression, else 0 |
|
||||
|
||||
|
||||
### Lambda expressions syntax
|
||||
Lambda expressions are a sort of function definition, where new identifiers can be defined and passed as input to the expression body
|
||||
```
|
||||
(identifier1 ...) -> expr
|
||||
```
|
||||
e.g.
|
||||
```
|
||||
(x, y) -> x + y
|
||||
```
|
||||
The identifier arguments of a lambda expression correspond to the elements of the array it is being applied to. For example:
|
||||
```
|
||||
map((x) -> x + 1, some_multi_value_column)
|
||||
```
|
||||
will map each element of `some_multi_value_column` to the identifier `x` so that the lambda expression body can be evaluated for each `x`. The scoping rules are that lambda arguments will override identifiers which are defined externally from the lambda expression body. Using the same example:
|
||||
|
||||
```
|
||||
map((x) -> x + 1, x)
|
||||
```
|
||||
in this case, the `x` when evaluating `x + 1` is the lambda argument, thus an element of the multi-valued column `x`, rather than the column `x` itself.
|
||||
|
||||
## Reduction functions
|
||||
|
||||
Reduction functions operate on zero or more expressions and return a single expression. If no expressions are passed as
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
|||
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FloatMinAggregatorFactory;
|
||||
|
@ -120,7 +121,8 @@ public class AggregatorsModule extends SimpleModule
|
|||
@JsonSubTypes.Type(name = "floatAny", value = FloatAnyAggregatorFactory.class),
|
||||
@JsonSubTypes.Type(name = "doubleAny", value = DoubleAnyAggregatorFactory.class),
|
||||
@JsonSubTypes.Type(name = "stringAny", value = StringAnyAggregatorFactory.class),
|
||||
@JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class)
|
||||
@JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class),
|
||||
@JsonSubTypes.Type(name = "expression", value = ExpressionLambdaAggregatorFactory.class)
|
||||
})
|
||||
public interface AggregatorFactoryMixin
|
||||
{
|
||||
|
|
|
@ -137,6 +137,9 @@ public class AggregatorUtil
|
|||
// GROUPING aggregator
|
||||
public static final byte GROUPING_CACHE_TYPE_ID = 0x46;
|
||||
|
||||
// expression lambda aggregator
|
||||
public static final byte EXPRESSION_LAMBDA_CACHE_TYPE_ID = 0x47;
|
||||
|
||||
|
||||
/**
|
||||
* returns the list of dependent postAggregators that should be calculated in order to calculate given postAgg
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.query.aggregation;
|
||||
|
||||
import org.apache.druid.math.expr.Expr;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
public class ExpressionLambdaAggregator implements Aggregator
|
||||
{
|
||||
private final Expr lambda;
|
||||
private final ExpressionLambdaAggregatorInputBindings bindings;
|
||||
|
||||
public ExpressionLambdaAggregator(Expr lambda, ExpressionLambdaAggregatorInputBindings bindings)
|
||||
{
|
||||
this.lambda = lambda;
|
||||
this.bindings = bindings;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void aggregate()
|
||||
{
|
||||
bindings.accumulate(lambda.eval(bindings));
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Object get()
|
||||
{
|
||||
return bindings.getAccumulator().value();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getFloat()
|
||||
{
|
||||
return (float) bindings.getAccumulator().asDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getLong()
|
||||
{
|
||||
return bindings.getAccumulator().asLong();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getDouble()
|
||||
{
|
||||
return bindings.getAccumulator().asDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isNull()
|
||||
{
|
||||
return bindings.getAccumulator().isNumericNull();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close()
|
||||
{
|
||||
// nothing to close
|
||||
}
|
||||
}
|
|
@ -0,0 +1,516 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.query.aggregation;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JacksonInject;
|
||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Supplier;
|
||||
import com.google.common.base.Suppliers;
|
||||
import com.google.common.collect.Iterables;
|
||||
import org.apache.druid.java.util.common.HumanReadableBytes;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.java.util.common.guava.Comparators;
|
||||
import org.apache.druid.math.expr.Expr;
|
||||
import org.apache.druid.math.expr.ExprEval;
|
||||
import org.apache.druid.math.expr.ExprMacroTable;
|
||||
import org.apache.druid.math.expr.ExprType;
|
||||
import org.apache.druid.math.expr.Parser;
|
||||
import org.apache.druid.math.expr.SettableObjectBinding;
|
||||
import org.apache.druid.query.cache.CacheKeyBuilder;
|
||||
import org.apache.druid.query.expression.ExprUtils;
|
||||
import org.apache.druid.segment.ColumnInspector;
|
||||
import org.apache.druid.segment.ColumnSelectorFactory;
|
||||
import org.apache.druid.segment.column.ColumnCapabilities;
|
||||
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
import org.apache.druid.segment.virtual.ExpressionPlan;
|
||||
import org.apache.druid.segment.virtual.ExpressionPlanner;
|
||||
import org.apache.druid.segment.virtual.ExpressionSelectors;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
|
||||
{
|
||||
private static final String FINALIZE_IDENTIFIER = "o";
|
||||
private static final String COMPARE_O1 = "o1";
|
||||
private static final String COMPARE_O2 = "o2";
|
||||
private static final String DEFAULT_ACCUMULATOR_ID = "__acc";
|
||||
|
||||
// minimum permitted agg size is 10 bytes so it is at least large enough to hold primitive numerics (long, double)
|
||||
// | expression type byte | is_null byte | primitive value (8 bytes) |
|
||||
private static final int MIN_SIZE_BYTES = 10;
|
||||
private static final HumanReadableBytes DEFAULT_MAX_SIZE_BYTES = new HumanReadableBytes(1L << 10);
|
||||
|
||||
private final String name;
|
||||
@Nullable
|
||||
private final Set<String> fields;
|
||||
private final String accumulatorId;
|
||||
private final String foldExpressionString;
|
||||
private final String initialValueExpressionString;
|
||||
private final String initialCombineValueExpressionString;
|
||||
|
||||
private final String combineExpressionString;
|
||||
@Nullable
|
||||
private final String compareExpressionString;
|
||||
@Nullable
|
||||
private final String finalizeExpressionString;
|
||||
|
||||
private final ExprMacroTable macroTable;
|
||||
private final Supplier<ExprEval<?>> initialValue;
|
||||
private final Supplier<ExprEval<?>> initialCombineValue;
|
||||
private final Supplier<Expr> foldExpression;
|
||||
private final Supplier<Expr> combineExpression;
|
||||
private final Supplier<Expr> compareExpression;
|
||||
private final Supplier<Expr> finalizeExpression;
|
||||
private final HumanReadableBytes maxSizeBytes;
|
||||
|
||||
private final Supplier<SettableObjectBinding> compareBindings =
|
||||
Suppliers.memoize(() -> new SettableObjectBinding(2));
|
||||
private final Supplier<SettableObjectBinding> combineBindings =
|
||||
Suppliers.memoize(() -> new SettableObjectBinding(2));
|
||||
private final Supplier<SettableObjectBinding> finalizeBindings =
|
||||
Suppliers.memoize(() -> new SettableObjectBinding(1));
|
||||
|
||||
@JsonCreator
|
||||
public ExpressionLambdaAggregatorFactory(
|
||||
@JsonProperty("name") String name,
|
||||
@JsonProperty("fields") @Nullable final Set<String> fields,
|
||||
@JsonProperty("accumulatorIdentifier") @Nullable final String accumulatorIdentifier,
|
||||
@JsonProperty("initialValue") final String initialValue,
|
||||
@JsonProperty("initialCombineValue") @Nullable final String initialCombineValue,
|
||||
@JsonProperty("fold") final String foldExpression,
|
||||
@JsonProperty("combine") @Nullable final String combineExpression,
|
||||
@JsonProperty("compare") @Nullable final String compareExpression,
|
||||
@JsonProperty("finalize") @Nullable final String finalizeExpression,
|
||||
@JsonProperty("maxSizeBytes") @Nullable final HumanReadableBytes maxSizeBytes,
|
||||
@JacksonInject ExprMacroTable macroTable
|
||||
)
|
||||
{
|
||||
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
|
||||
|
||||
this.name = name;
|
||||
this.fields = fields;
|
||||
this.accumulatorId = accumulatorIdentifier != null ? accumulatorIdentifier : DEFAULT_ACCUMULATOR_ID;
|
||||
|
||||
this.initialValueExpressionString = initialValue;
|
||||
this.initialCombineValueExpressionString = initialCombineValue == null ? initialValue : initialCombineValue;
|
||||
this.foldExpressionString = foldExpression;
|
||||
if (combineExpression != null) {
|
||||
this.combineExpressionString = combineExpression;
|
||||
} else {
|
||||
// if the combine expression is null, allow single input aggregator expressions to be rewritten to replace the
|
||||
// field with the aggregator name. Fields is null for the combining/merging aggregator, but the expression should
|
||||
// already be set with the rewritten value at that point
|
||||
Preconditions.checkArgument(
|
||||
fields != null && fields.size() == 1,
|
||||
"Must have a single input field if no combine expression is supplied"
|
||||
);
|
||||
this.combineExpressionString = StringUtils.replace(foldExpression, Iterables.getOnlyElement(fields), name);
|
||||
}
|
||||
this.compareExpressionString = compareExpression;
|
||||
this.finalizeExpressionString = finalizeExpression;
|
||||
this.macroTable = macroTable;
|
||||
|
||||
this.initialValue = Suppliers.memoize(() -> {
|
||||
Expr parsed = Parser.parse(initialValue, macroTable);
|
||||
Preconditions.checkArgument(parsed.isLiteral(), "initial value must be constant");
|
||||
return parsed.eval(ExprUtils.nilBindings());
|
||||
});
|
||||
this.initialCombineValue = Suppliers.memoize(() -> {
|
||||
Expr parsed = Parser.parse(this.initialCombineValueExpressionString, macroTable);
|
||||
Preconditions.checkArgument(parsed.isLiteral(), "initial combining value must be constant");
|
||||
return parsed.eval(ExprUtils.nilBindings());
|
||||
});
|
||||
this.foldExpression = Parser.lazyParse(foldExpressionString, macroTable);
|
||||
this.combineExpression = Parser.lazyParse(combineExpressionString, macroTable);
|
||||
this.compareExpression = Parser.lazyParse(compareExpressionString, macroTable);
|
||||
this.finalizeExpression = Parser.lazyParse(finalizeExpressionString, macroTable);
|
||||
this.maxSizeBytes = maxSizeBytes != null ? maxSizeBytes : DEFAULT_MAX_SIZE_BYTES;
|
||||
Preconditions.checkArgument(this.maxSizeBytes.getBytesInInt() >= MIN_SIZE_BYTES);
|
||||
}
|
||||
|
||||
@JsonProperty
|
||||
@Override
|
||||
public String getName()
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
@JsonProperty
|
||||
@Nullable
|
||||
public Set<String> getFields()
|
||||
{
|
||||
return fields;
|
||||
}
|
||||
|
||||
@JsonProperty
|
||||
@Nullable
|
||||
public String getAccumulatorIdentifier()
|
||||
{
|
||||
return accumulatorId;
|
||||
}
|
||||
|
||||
@JsonProperty("initialValue")
|
||||
public String getInitialValueExpressionString()
|
||||
{
|
||||
return initialValueExpressionString;
|
||||
}
|
||||
|
||||
@JsonProperty("initialCombineValue")
|
||||
public String getInitialCombineValueExpressionString()
|
||||
{
|
||||
return initialCombineValueExpressionString;
|
||||
}
|
||||
|
||||
@JsonProperty("fold")
|
||||
public String getFoldExpressionString()
|
||||
{
|
||||
return foldExpressionString;
|
||||
}
|
||||
|
||||
@JsonProperty("combine")
|
||||
public String getCombineExpressionString()
|
||||
{
|
||||
return combineExpressionString;
|
||||
}
|
||||
|
||||
@JsonProperty("compare")
|
||||
@Nullable
|
||||
public String getCompareExpressionString()
|
||||
{
|
||||
return compareExpressionString;
|
||||
}
|
||||
|
||||
@JsonProperty("finalize")
|
||||
@Nullable
|
||||
public String getFinalizeExpressionString()
|
||||
{
|
||||
return finalizeExpressionString;
|
||||
}
|
||||
|
||||
@JsonProperty("maxSizeBytes")
|
||||
public HumanReadableBytes getMaxSizeBytes()
|
||||
{
|
||||
return maxSizeBytes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] getCacheKey()
|
||||
{
|
||||
return new CacheKeyBuilder(AggregatorUtil.EXPRESSION_LAMBDA_CACHE_TYPE_ID)
|
||||
.appendStrings(fields)
|
||||
.appendString(initialValueExpressionString)
|
||||
.appendString(initialCombineValueExpressionString)
|
||||
.appendString(foldExpressionString)
|
||||
.appendString(combineExpressionString)
|
||||
.appendString(compareExpressionString)
|
||||
.appendString(finalizeExpressionString)
|
||||
.appendInt(maxSizeBytes.getBytesInInt())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Aggregator factorize(ColumnSelectorFactory metricFactory)
|
||||
{
|
||||
FactorizePlan thePlan = new FactorizePlan(metricFactory);
|
||||
return new ExpressionLambdaAggregator(
|
||||
thePlan.getExpression(),
|
||||
thePlan.getBindings()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
|
||||
{
|
||||
FactorizePlan thePlan = new FactorizePlan(metricFactory);
|
||||
return new ExpressionLambdaBufferAggregator(
|
||||
thePlan.getExpression(),
|
||||
thePlan.getInitialValue(),
|
||||
thePlan.getBindings(),
|
||||
maxSizeBytes.getBytesInInt()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Comparator getComparator()
|
||||
{
|
||||
Expr compareExpr = compareExpression.get();
|
||||
if (compareExpr != null) {
|
||||
return (o1, o2) ->
|
||||
compareExpr.eval(compareBindings.get().withBinding(COMPARE_O1, o1).withBinding(COMPARE_O2, o2)).asInt();
|
||||
}
|
||||
switch (initialValue.get().type()) {
|
||||
case LONG:
|
||||
return LongSumAggregator.COMPARATOR;
|
||||
case DOUBLE:
|
||||
return DoubleSumAggregator.COMPARATOR;
|
||||
default:
|
||||
return Comparators.naturalNullsFirst();
|
||||
}
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Object combine(@Nullable Object lhs, @Nullable Object rhs)
|
||||
{
|
||||
// arbitrarily assign lhs and rhs to accumulator and aggregator name inputs to re-use combine function
|
||||
return combineExpression.get().eval(
|
||||
combineBindings.get().withBinding(accumulatorId, lhs).withBinding(name, rhs)
|
||||
).value();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object deserialize(Object object)
|
||||
{
|
||||
return object;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Object finalizeComputation(@Nullable Object object)
|
||||
{
|
||||
Expr finalizeExpr;
|
||||
finalizeExpr = finalizeExpression.get();
|
||||
if (finalizeExpr != null) {
|
||||
return finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, object)).value();
|
||||
}
|
||||
return object;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> requiredFields()
|
||||
{
|
||||
if (fields == null) {
|
||||
return combineExpression.get().analyzeInputs().getRequiredBindingsList();
|
||||
}
|
||||
return foldExpression.get().analyzeInputs().getRequiredBindingsList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AggregatorFactory getCombiningFactory()
|
||||
{
|
||||
return new ExpressionLambdaAggregatorFactory(
|
||||
name,
|
||||
null,
|
||||
accumulatorId,
|
||||
initialValueExpressionString,
|
||||
initialCombineValueExpressionString,
|
||||
foldExpressionString,
|
||||
combineExpressionString,
|
||||
compareExpressionString,
|
||||
finalizeExpressionString,
|
||||
maxSizeBytes,
|
||||
macroTable
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregatorFactory> getRequiredColumns()
|
||||
{
|
||||
return Collections.singletonList(
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
name,
|
||||
fields,
|
||||
accumulatorId,
|
||||
initialValueExpressionString,
|
||||
initialCombineValueExpressionString,
|
||||
foldExpressionString,
|
||||
combineExpressionString,
|
||||
compareExpressionString,
|
||||
finalizeExpressionString,
|
||||
maxSizeBytes,
|
||||
macroTable
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ValueType getType()
|
||||
{
|
||||
if (fields == null) {
|
||||
return ExprType.toValueType(initialCombineValue.get().type());
|
||||
}
|
||||
return ExprType.toValueType(initialValue.get().type());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ValueType getFinalizedType()
|
||||
{
|
||||
Expr finalizeExpr = finalizeExpression.get();
|
||||
ExprEval<?> initialVal = initialCombineValue.get();
|
||||
if (finalizeExpr != null) {
|
||||
return ExprType.toValueType(
|
||||
finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, initialVal)).type()
|
||||
);
|
||||
}
|
||||
return ExprType.toValueType(initialVal.type());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxIntermediateSize()
|
||||
{
|
||||
// numeric expressions are either longs or doubles, with strings or arrays max size is unknown
|
||||
// for numeric arguments, the first 2 bytes are used for expression type byte and is_null byte
|
||||
return getType().isNumeric() ? 2 + Long.BYTES : maxSizeBytes.getBytesInInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o)
|
||||
{
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
ExpressionLambdaAggregatorFactory that = (ExpressionLambdaAggregatorFactory) o;
|
||||
return maxSizeBytes.equals(that.maxSizeBytes)
|
||||
&& name.equals(that.name)
|
||||
&& Objects.equals(fields, that.fields)
|
||||
&& accumulatorId.equals(that.accumulatorId)
|
||||
&& foldExpressionString.equals(that.foldExpressionString)
|
||||
&& initialValueExpressionString.equals(that.initialValueExpressionString)
|
||||
&& initialCombineValueExpressionString.equals(that.initialCombineValueExpressionString)
|
||||
&& combineExpressionString.equals(that.combineExpressionString)
|
||||
&& Objects.equals(compareExpressionString, that.compareExpressionString)
|
||||
&& Objects.equals(finalizeExpressionString, that.finalizeExpressionString);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode()
|
||||
{
|
||||
return Objects.hash(
|
||||
name,
|
||||
fields,
|
||||
accumulatorId,
|
||||
foldExpressionString,
|
||||
initialValueExpressionString,
|
||||
initialCombineValueExpressionString,
|
||||
combineExpressionString,
|
||||
compareExpressionString,
|
||||
finalizeExpressionString,
|
||||
maxSizeBytes
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString()
|
||||
{
|
||||
return "ExpressionLambdaAggregatorFactory{" +
|
||||
"name='" + name + '\'' +
|
||||
", fields=" + fields +
|
||||
", accumulatorId='" + accumulatorId + '\'' +
|
||||
", foldExpressionString='" + foldExpressionString + '\'' +
|
||||
", initialValueExpressionString='" + initialValueExpressionString + '\'' +
|
||||
", initialCombineValueExpressionString='" + initialCombineValueExpressionString + '\'' +
|
||||
", combineExpressionString='" + combineExpressionString + '\'' +
|
||||
", compareExpressionString='" + compareExpressionString + '\'' +
|
||||
", finalizeExpressionString='" + finalizeExpressionString + '\'' +
|
||||
", maxSizeBytes=" + maxSizeBytes +
|
||||
'}';
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine how to factorize the aggregator
|
||||
*/
|
||||
private class FactorizePlan
|
||||
{
|
||||
private final ExpressionPlan plan;
|
||||
|
||||
private final ExprEval<?> seed;
|
||||
private final ExpressionLambdaAggregatorInputBindings bindings;
|
||||
|
||||
FactorizePlan(ColumnSelectorFactory metricFactory)
|
||||
{
|
||||
final List<String> columns;
|
||||
|
||||
if (fields != null) {
|
||||
// if fields are set, we are accumulating from raw inputs, use fold expression
|
||||
plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory), foldExpression.get());
|
||||
seed = initialValue.get();
|
||||
columns = plan.getAnalysis().getRequiredBindingsList();
|
||||
} else {
|
||||
// else we are merging intermediary results, use combine expression
|
||||
plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory), combineExpression.get());
|
||||
seed = initialCombineValue.get();
|
||||
columns = plan.getAnalysis().getRequiredBindingsList();
|
||||
}
|
||||
|
||||
bindings = new ExpressionLambdaAggregatorInputBindings(
|
||||
ExpressionSelectors.createBindings(metricFactory, columns),
|
||||
accumulatorId,
|
||||
seed
|
||||
);
|
||||
}
|
||||
|
||||
public Expr getExpression()
|
||||
{
|
||||
if (fields == null) {
|
||||
return plan.getExpression();
|
||||
}
|
||||
// for fold expressions, check to see if it needs transformation due to scalar use of multi-valued or unknown
|
||||
// inputs
|
||||
return plan.getAppliedFoldExpression(accumulatorId);
|
||||
}
|
||||
|
||||
public ExprEval<?> getInitialValue()
|
||||
{
|
||||
return seed;
|
||||
}
|
||||
|
||||
public ExpressionLambdaAggregatorInputBindings getBindings()
|
||||
{
|
||||
return bindings;
|
||||
}
|
||||
|
||||
private ColumnInspector inspectorWithAccumulator(ColumnInspector inspector)
|
||||
{
|
||||
return new ColumnInspector()
|
||||
{
|
||||
@Nullable
|
||||
@Override
|
||||
public ColumnCapabilities getColumnCapabilities(String column)
|
||||
{
|
||||
if (accumulatorId.equals(column)) {
|
||||
return ColumnCapabilitiesImpl.createDefault().setType(ExprType.toValueType(initialValue.get().type()));
|
||||
}
|
||||
return inspector.getColumnCapabilities(column);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public ExprType getType(String name)
|
||||
{
|
||||
if (accumulatorId.equals(name)) {
|
||||
return initialValue.get().type();
|
||||
}
|
||||
return inspector.getType(name);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.query.aggregation;
|
||||
|
||||
import org.apache.druid.math.expr.Expr;
|
||||
import org.apache.druid.math.expr.ExprEval;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* Special {@link Expr.ObjectBinding} for use with {@link ExpressionLambdaAggregatorFactory}.
|
||||
* This value binding holds a value for a special 'accumulator' variable, in addition to the 'normal' bindings to the
|
||||
* underlying selector inputs for other identifiers, which allows for easy forward feeding of the results of an
|
||||
* expression evaluation to use in the bindings of the next evaluation.
|
||||
*/
|
||||
public class ExpressionLambdaAggregatorInputBindings implements Expr.ObjectBinding
|
||||
{
|
||||
private final Expr.ObjectBinding inputBindings;
|
||||
private final String accumlatorIdentifier;
|
||||
private ExprEval<?> accumulator;
|
||||
|
||||
public ExpressionLambdaAggregatorInputBindings(
|
||||
Expr.ObjectBinding inputBindings,
|
||||
String accumulatorIdentifier,
|
||||
ExprEval<?> initialValue
|
||||
)
|
||||
{
|
||||
this.accumlatorIdentifier = accumulatorIdentifier;
|
||||
this.inputBindings = inputBindings;
|
||||
this.accumulator = initialValue;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Object get(String name)
|
||||
{
|
||||
if (accumlatorIdentifier.equals(name)) {
|
||||
return accumulator.value();
|
||||
}
|
||||
return inputBindings.get(name);
|
||||
}
|
||||
|
||||
public void accumulate(ExprEval<?> eval)
|
||||
{
|
||||
accumulator = eval;
|
||||
}
|
||||
|
||||
public ExprEval<?> getAccumulator()
|
||||
{
|
||||
return accumulator;
|
||||
}
|
||||
|
||||
public void setAccumulator(ExprEval<?> acc)
|
||||
{
|
||||
this.accumulator = acc;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.query.aggregation;
|
||||
|
||||
import org.apache.druid.math.expr.Expr;
|
||||
import org.apache.druid.math.expr.ExprEval;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
public class ExpressionLambdaBufferAggregator implements BufferAggregator
|
||||
{
|
||||
private final Expr lambda;
|
||||
private final ExprEval<?> initialValue;
|
||||
private final ExpressionLambdaAggregatorInputBindings bindings;
|
||||
private final int maxSizeBytes;
|
||||
|
||||
public ExpressionLambdaBufferAggregator(
|
||||
Expr lambda,
|
||||
ExprEval<?> initialValue,
|
||||
ExpressionLambdaAggregatorInputBindings bindings,
|
||||
int maxSizeBytes
|
||||
)
|
||||
{
|
||||
this.lambda = lambda;
|
||||
this.initialValue = initialValue;
|
||||
this.bindings = bindings;
|
||||
this.maxSizeBytes = maxSizeBytes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void init(ByteBuffer buf, int position)
|
||||
{
|
||||
ExprEval.serialize(buf, position, initialValue, maxSizeBytes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void aggregate(ByteBuffer buf, int position)
|
||||
{
|
||||
ExprEval<?> acc = ExprEval.deserialize(buf, position);
|
||||
bindings.setAccumulator(acc);
|
||||
ExprEval<?> newAcc = lambda.eval(bindings);
|
||||
ExprEval.serialize(buf, position, newAcc, maxSizeBytes);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Object get(ByteBuffer buf, int position)
|
||||
{
|
||||
return ExprEval.deserialize(buf, position).value();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getFloat(ByteBuffer buf, int position)
|
||||
{
|
||||
return (float) ExprEval.deserialize(buf, position).asDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getDouble(ByteBuffer buf, int position)
|
||||
{
|
||||
return ExprEval.deserialize(buf, position).asDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getLong(ByteBuffer buf, int position)
|
||||
{
|
||||
return ExprEval.deserialize(buf, position).asLong();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close()
|
||||
{
|
||||
// nothing to close
|
||||
}
|
||||
}
|
|
@ -23,7 +23,6 @@ package org.apache.druid.query.aggregation;
|
|||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Supplier;
|
||||
import com.google.common.base.Suppliers;
|
||||
import org.apache.druid.math.expr.Expr;
|
||||
import org.apache.druid.math.expr.ExprMacroTable;
|
||||
import org.apache.druid.math.expr.Parser;
|
||||
|
@ -72,7 +71,7 @@ public abstract class SimpleDoubleAggregatorFactory extends NullableNumericAggre
|
|||
this.fieldName = fieldName;
|
||||
this.expression = expression;
|
||||
this.storeDoubleAsFloat = ColumnHolder.storeDoubleAsFloat();
|
||||
this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable));
|
||||
this.fieldExpression = Parser.lazyParse(expression, macroTable);
|
||||
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
|
||||
Preconditions.checkArgument(
|
||||
fieldName == null ^ expression == null,
|
||||
|
|
|
@ -23,7 +23,6 @@ package org.apache.druid.query.aggregation;
|
|||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Supplier;
|
||||
import com.google.common.base.Suppliers;
|
||||
import org.apache.druid.math.expr.Expr;
|
||||
import org.apache.druid.math.expr.ExprMacroTable;
|
||||
import org.apache.druid.math.expr.Parser;
|
||||
|
@ -63,7 +62,7 @@ public abstract class SimpleFloatAggregatorFactory extends NullableNumericAggreg
|
|||
this.name = name;
|
||||
this.fieldName = fieldName;
|
||||
this.expression = expression;
|
||||
this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable));
|
||||
this.fieldExpression = Parser.lazyParse(expression, macroTable);
|
||||
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
|
||||
Preconditions.checkArgument(
|
||||
fieldName == null ^ expression == null,
|
||||
|
|
|
@ -23,7 +23,6 @@ package org.apache.druid.query.aggregation;
|
|||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Supplier;
|
||||
import com.google.common.base.Suppliers;
|
||||
import org.apache.druid.math.expr.Expr;
|
||||
import org.apache.druid.math.expr.ExprMacroTable;
|
||||
import org.apache.druid.math.expr.Parser;
|
||||
|
@ -69,7 +68,7 @@ public abstract class SimpleLongAggregatorFactory extends NullableNumericAggrega
|
|||
this.name = name;
|
||||
this.fieldName = fieldName;
|
||||
this.expression = expression;
|
||||
this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable));
|
||||
this.fieldExpression = Parser.lazyParse(expression, macroTable);
|
||||
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
|
||||
Preconditions.checkArgument(
|
||||
fieldName == null ^ expression == null,
|
||||
|
|
|
@ -92,7 +92,7 @@ public class ExpressionPostAggregator implements PostAggregator
|
|||
ordering,
|
||||
macroTable,
|
||||
ImmutableMap.of(),
|
||||
Suppliers.memoize(() -> Parser.parse(expression, macroTable))
|
||||
Parser.lazyParse(expression, macroTable)
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ import com.fasterxml.jackson.annotation.JsonInclude;
|
|||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Supplier;
|
||||
import com.google.common.base.Suppliers;
|
||||
import com.google.common.collect.RangeSet;
|
||||
import org.apache.druid.math.expr.Expr;
|
||||
import org.apache.druid.math.expr.ExprMacroTable;
|
||||
|
@ -53,7 +52,7 @@ public class ExpressionDimFilter extends AbstractOptimizableDimFilter implements
|
|||
{
|
||||
this.expression = expression;
|
||||
this.filterTuning = filterTuning;
|
||||
this.parsed = Suppliers.memoize(() -> Parser.parse(expression, macroTable));
|
||||
this.parsed = Parser.lazyParse(expression, macroTable);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
|
|
|
@ -26,6 +26,7 @@ import com.google.common.base.Preconditions;
|
|||
import com.google.common.base.Suppliers;
|
||||
import org.apache.druid.data.input.Row;
|
||||
import org.apache.druid.math.expr.Expr;
|
||||
import org.apache.druid.math.expr.ExprEval;
|
||||
import org.apache.druid.math.expr.ExprMacroTable;
|
||||
import org.apache.druid.math.expr.Parser;
|
||||
import org.apache.druid.segment.column.ColumnHolder;
|
||||
|
@ -106,7 +107,7 @@ public class ExpressionTransform implements Transform
|
|||
} else {
|
||||
Object raw = row.getRaw(column);
|
||||
if (raw instanceof List) {
|
||||
return ExpressionSelectors.coerceListToArray((List) raw);
|
||||
return ExprEval.coerceListToArray((List) raw, true);
|
||||
}
|
||||
return raw;
|
||||
}
|
||||
|
|
|
@ -122,6 +122,14 @@ public class ExpressionPlan
|
|||
return expression;
|
||||
}
|
||||
|
||||
public Expr getAppliedFoldExpression(String accumulatorId)
|
||||
{
|
||||
if (is(Trait.NEEDS_APPLIED)) {
|
||||
return Parser.foldUnappliedBindings(expression, analysis, unappliedInputs, accumulatorId);
|
||||
}
|
||||
return expression;
|
||||
}
|
||||
|
||||
public Expr.BindingAnalysis getAnalysis()
|
||||
{
|
||||
return analysis;
|
||||
|
|
|
@ -24,7 +24,6 @@ import com.google.common.base.Preconditions;
|
|||
import com.google.common.base.Supplier;
|
||||
import com.google.common.collect.Iterables;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.java.util.common.UOE;
|
||||
import org.apache.druid.math.expr.Expr;
|
||||
import org.apache.druid.math.expr.ExprEval;
|
||||
import org.apache.druid.math.expr.Parser;
|
||||
|
@ -242,13 +241,26 @@ public class ExpressionSelectors
|
|||
* provides the set of identifiers which need a binding (list of required columns), and context of whether or not they
|
||||
* are used as array or scalar inputs
|
||||
*/
|
||||
private static Expr.ObjectBinding createBindings(
|
||||
public static Expr.ObjectBinding createBindings(
|
||||
Expr.BindingAnalysis bindingAnalysis,
|
||||
ColumnSelectorFactory columnSelectorFactory
|
||||
)
|
||||
{
|
||||
final Map<String, Supplier<Object>> suppliers = new HashMap<>();
|
||||
final List<String> columns = bindingAnalysis.getRequiredBindingsList();
|
||||
return createBindings(columnSelectorFactory, columns);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.BindingAnalysis} which
|
||||
* provides the set of identifiers which need a binding (list of required columns), and context of whether or not they
|
||||
* are used as array or scalar inputs
|
||||
*/
|
||||
public static Expr.ObjectBinding createBindings(
|
||||
ColumnSelectorFactory columnSelectorFactory,
|
||||
List<String> columns
|
||||
)
|
||||
{
|
||||
final Map<String, Supplier<Object>> suppliers = new HashMap<>();
|
||||
for (String columnName : columns) {
|
||||
final ColumnCapabilities columnCapabilities = columnSelectorFactory.getColumnCapabilities(columnName);
|
||||
final ValueType nativeType = columnCapabilities != null ? columnCapabilities.getType() : null;
|
||||
|
@ -269,8 +281,8 @@ public class ExpressionSelectors
|
|||
columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(columnName, columnName)),
|
||||
multiVal
|
||||
);
|
||||
} else if (nativeType == null) {
|
||||
// Unknown ValueType. Try making an Object selector and see if that gives us anything useful.
|
||||
} else if (nativeType == null || ValueType.isArray(nativeType)) {
|
||||
// Unknown ValueType or array type. Try making an Object selector and see if that gives us anything useful.
|
||||
supplier = supplierFromObjectSelector(columnSelectorFactory.makeColumnValueSelector(columnName));
|
||||
} else {
|
||||
// Unhandleable ValueType (COMPLEX).
|
||||
|
@ -370,10 +382,10 @@ public class ExpressionSelectors
|
|||
// Might be Numbers and Strings. Use a selector that double-checks.
|
||||
return () -> {
|
||||
final Object val = selector.getObject();
|
||||
if (val instanceof Number || val instanceof String) {
|
||||
if (val instanceof Number || val instanceof String || (val != null && val.getClass().isArray())) {
|
||||
return val;
|
||||
} else if (val instanceof List) {
|
||||
return coerceListToArray((List) val);
|
||||
return ExprEval.coerceListToArray((List) val, true);
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
|
@ -382,7 +394,7 @@ public class ExpressionSelectors
|
|||
return () -> {
|
||||
final Object val = selector.getObject();
|
||||
if (val != null) {
|
||||
return coerceListToArray((List) val);
|
||||
return ExprEval.coerceListToArray((List) val, true);
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
@ -392,70 +404,6 @@ public class ExpressionSelectors
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Selectors are not consistent in treatment of null, [], and [null], so coerce [] to [null]
|
||||
*/
|
||||
public static Object coerceListToArray(@Nullable List<?> val)
|
||||
{
|
||||
if (val != null && val.size() > 0) {
|
||||
Class coercedType = null;
|
||||
|
||||
for (Object elem : val) {
|
||||
if (elem != null) {
|
||||
coercedType = convertType(coercedType, elem.getClass());
|
||||
}
|
||||
}
|
||||
|
||||
if (coercedType == Long.class || coercedType == Integer.class) {
|
||||
return val.stream().map(x -> x != null ? ((Number) x).longValue() : null).toArray(Long[]::new);
|
||||
}
|
||||
if (coercedType == Float.class || coercedType == Double.class) {
|
||||
return val.stream().map(x -> x != null ? ((Number) x).doubleValue() : null).toArray(Double[]::new);
|
||||
}
|
||||
// default to string
|
||||
return val.stream().map(x -> x != null ? x.toString() : null).toArray(String[]::new);
|
||||
}
|
||||
return new String[]{null};
|
||||
}
|
||||
|
||||
private static Class convertType(@Nullable Class existing, Class next)
|
||||
{
|
||||
if (Number.class.isAssignableFrom(next) || next == String.class) {
|
||||
if (existing == null) {
|
||||
return next;
|
||||
}
|
||||
// string wins everything
|
||||
if (existing == String.class) {
|
||||
return existing;
|
||||
}
|
||||
if (next == String.class) {
|
||||
return next;
|
||||
}
|
||||
// all numbers win over Integer
|
||||
if (existing == Integer.class) {
|
||||
return next;
|
||||
}
|
||||
if (existing == Float.class) {
|
||||
// doubles win over floats
|
||||
if (next == Double.class) {
|
||||
return next;
|
||||
}
|
||||
return existing;
|
||||
}
|
||||
if (existing == Long.class) {
|
||||
if (next == Integer.class) {
|
||||
// long beats int
|
||||
return existing;
|
||||
}
|
||||
// double and float win over longs
|
||||
return next;
|
||||
}
|
||||
// otherwise double
|
||||
return Double.class;
|
||||
}
|
||||
throw new UOE("Invalid array expression type: %s", next);
|
||||
}
|
||||
|
||||
/**
|
||||
* Coerces {@link ExprEval} value back to selector friendly {@link List} if the evaluated expression result is an
|
||||
* array type
|
||||
|
|
|
@ -72,7 +72,7 @@ public class ExpressionVirtualColumn implements VirtualColumn
|
|||
this.name = Preconditions.checkNotNull(name, "name");
|
||||
this.expression = Preconditions.checkNotNull(expression, "expression");
|
||||
this.outputType = outputType;
|
||||
this.parsedExpression = Suppliers.memoize(() -> Parser.parse(expression, macroTable));
|
||||
this.parsedExpression = Parser.lazyParse(expression, macroTable);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -0,0 +1,570 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.query.aggregation;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import nl.jqno.equalsverifier.EqualsVerifier;
|
||||
import org.apache.druid.java.util.common.HumanReadableBytes;
|
||||
import org.apache.druid.java.util.common.granularity.Granularities;
|
||||
import org.apache.druid.query.Druids;
|
||||
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
|
||||
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
|
||||
import org.apache.druid.query.expression.TestExprMacroTable;
|
||||
import org.apache.druid.query.timeseries.TimeseriesQuery;
|
||||
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
|
||||
import org.apache.druid.segment.TestHelper;
|
||||
import org.apache.druid.segment.column.RowSignature;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
import org.apache.druid.testing.InitializedNullHandlingTest;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandlingTest
|
||||
{
|
||||
private static ObjectMapper MAPPER = TestHelper.makeJsonMapper();
|
||||
|
||||
@Rule
|
||||
public ExpectedException expectedException = ExpectedException.none();
|
||||
|
||||
@Test
|
||||
public void testSerde() throws IOException
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
"customAccumulator",
|
||||
"0.0",
|
||||
"10.0",
|
||||
"customAccumulator + some_column + some_other_column",
|
||||
"customAccumulator + expr_agg_name",
|
||||
"if (o1 > o2, if (o1 == o2, 0, 1), -1)",
|
||||
"o + 100",
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(agg, MAPPER.readValue(MAPPER.writeValueAsBytes(agg), ExpressionLambdaAggregatorFactory.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEqualsAndHashCode()
|
||||
{
|
||||
EqualsVerifier.forClass(ExpressionLambdaAggregatorFactory.class)
|
||||
.usingGetClass()
|
||||
.withIgnoredFields(
|
||||
"macroTable",
|
||||
"initialValue",
|
||||
"initialCombineValue",
|
||||
"foldExpression",
|
||||
"combineExpression",
|
||||
"compareExpression",
|
||||
"finalizeExpression",
|
||||
"compareBindings",
|
||||
"combineBindings",
|
||||
"finalizeBindings"
|
||||
)
|
||||
.verify();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInitialValueMustBeConstant()
|
||||
{
|
||||
expectedException.expect(IllegalArgumentException.class);
|
||||
expectedException.expectMessage("initial value must be constant");
|
||||
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"x + y",
|
||||
null,
|
||||
"__acc + some_column + some_other_column",
|
||||
"__acc + expr_agg_name",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
agg.getType();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInitialCombineValueMustBeConstant()
|
||||
{
|
||||
expectedException.expect(IllegalArgumentException.class);
|
||||
expectedException.expectMessage("initial combining value must be constant");
|
||||
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0.0",
|
||||
"x + y",
|
||||
"__acc + some_column + some_other_column",
|
||||
"__acc + expr_agg_name",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
agg.getFinalizedType();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSingleInputCombineExpressionIsOptional()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("x"),
|
||||
null,
|
||||
"0",
|
||||
null,
|
||||
"__acc + x",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(1L, agg.combine(0L, 1L));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFinalizeCanDo()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("x"),
|
||||
null,
|
||||
"0",
|
||||
null,
|
||||
"__acc + x",
|
||||
null,
|
||||
null,
|
||||
"o + 100",
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(100L, agg.finalizeComputation(0L));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFinalizeCanDoArrays()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("x"),
|
||||
null,
|
||||
"0",
|
||||
null,
|
||||
"array_set_add(__acc, x)",
|
||||
"array_set_add_all(__acc, expr_agg_name)",
|
||||
null,
|
||||
"array_to_string(o, ',')",
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals("a,b,c", agg.finalizeComputation(new String[]{"a", "b", "c"}));
|
||||
Assert.assertEquals("a,b,c", agg.finalizeComputation(ImmutableList.of("a", "b", "c")));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringType()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"''",
|
||||
"''",
|
||||
"concat(__acc, some_column, some_other_column)",
|
||||
"concat(__acc, expr_agg_name)",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(ValueType.STRING, agg.getType());
|
||||
Assert.assertEquals(ValueType.STRING, agg.getCombiningFactory().getType());
|
||||
Assert.assertEquals(ValueType.STRING, agg.getFinalizedType());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLongType()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0",
|
||||
null,
|
||||
"__acc + some_column + some_other_column",
|
||||
"__acc + expr_agg_name",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(ValueType.LONG, agg.getType());
|
||||
Assert.assertEquals(ValueType.LONG, agg.getCombiningFactory().getType());
|
||||
Assert.assertEquals(ValueType.LONG, agg.getFinalizedType());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleType()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0.0",
|
||||
null,
|
||||
"__acc + some_column + some_other_column",
|
||||
"__acc + expr_agg_name",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(ValueType.DOUBLE, agg.getType());
|
||||
Assert.assertEquals(ValueType.DOUBLE, agg.getCombiningFactory().getType());
|
||||
Assert.assertEquals(ValueType.DOUBLE, agg.getFinalizedType());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringArrayType()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"''",
|
||||
"<STRING>[]",
|
||||
"concat(__acc, some_column, some_other_column)",
|
||||
"array_set_add(__acc, expr_agg_name)",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(ValueType.STRING, agg.getType());
|
||||
Assert.assertEquals(ValueType.STRING_ARRAY, agg.getCombiningFactory().getType());
|
||||
Assert.assertEquals(ValueType.STRING_ARRAY, agg.getFinalizedType());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringArrayTypeFinalized()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"''",
|
||||
"<STRING>[]",
|
||||
"concat(__acc, some_column, some_other_column)",
|
||||
"array_set_add(__acc, expr_agg_name)",
|
||||
null,
|
||||
"array_to_string(o, ';')",
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(ValueType.STRING, agg.getType());
|
||||
Assert.assertEquals(ValueType.STRING_ARRAY, agg.getCombiningFactory().getType());
|
||||
Assert.assertEquals(ValueType.STRING, agg.getFinalizedType());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLongArrayType()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0",
|
||||
"<LONG>[]",
|
||||
"__acc + some_column + some_other_column",
|
||||
"array_set_add(__acc, expr_agg_name)",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(ValueType.LONG, agg.getType());
|
||||
Assert.assertEquals(ValueType.LONG_ARRAY, agg.getCombiningFactory().getType());
|
||||
Assert.assertEquals(ValueType.LONG_ARRAY, agg.getFinalizedType());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLongArrayTypeFinalized()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0",
|
||||
"<LONG>[]",
|
||||
"__acc + some_column + some_other_column",
|
||||
"array_set_add(__acc, expr_agg_name)",
|
||||
null,
|
||||
"array_to_string(o, ';')",
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(ValueType.LONG, agg.getType());
|
||||
Assert.assertEquals(ValueType.LONG_ARRAY, agg.getCombiningFactory().getType());
|
||||
Assert.assertEquals(ValueType.STRING, agg.getFinalizedType());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleArrayType()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0.0",
|
||||
"<DOUBLE>[]",
|
||||
"__acc + some_column + some_other_column",
|
||||
"array_set_add(__acc, expr_agg_name)",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(ValueType.DOUBLE, agg.getType());
|
||||
Assert.assertEquals(ValueType.DOUBLE_ARRAY, agg.getCombiningFactory().getType());
|
||||
Assert.assertEquals(ValueType.DOUBLE_ARRAY, agg.getFinalizedType());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleArrayTypeFinalized()
|
||||
{
|
||||
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
|
||||
"expr_agg_name",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0.0",
|
||||
"<DOUBLE>[]",
|
||||
"__acc + some_column + some_other_column",
|
||||
"array_set_add(__acc, expr_agg_name)",
|
||||
null,
|
||||
"array_to_string(o, ';')",
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
);
|
||||
|
||||
Assert.assertEquals(ValueType.DOUBLE, agg.getType());
|
||||
Assert.assertEquals(ValueType.DOUBLE_ARRAY, agg.getCombiningFactory().getType());
|
||||
Assert.assertEquals(ValueType.STRING, agg.getFinalizedType());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResultArraySignature()
|
||||
{
|
||||
final TimeseriesQuery query =
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource("dummy")
|
||||
.intervals("2000/3000")
|
||||
.granularity(Granularities.HOUR)
|
||||
.aggregators(
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"string_expr",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"''",
|
||||
"''",
|
||||
"concat(__acc, some_column, some_other_column)",
|
||||
"concat(__acc, string_expr)",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"double_expr",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0.0",
|
||||
null,
|
||||
"__acc + some_column + some_other_column",
|
||||
"__acc + double_expr",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"long_expr",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0",
|
||||
null,
|
||||
"__acc + some_column + some_other_column",
|
||||
"__acc + long_expr",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"string_array_expr",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"<STRING>[]",
|
||||
"<STRING>[]",
|
||||
"array_set_add(__acc, concat(some_column, some_other_column))",
|
||||
"array_set_add_all(__acc, string_array_expr)",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"double_array_expr",
|
||||
ImmutableSet.of("some_column", "some_other_column_expr"),
|
||||
null,
|
||||
"0.0",
|
||||
"<DOUBLE>[]",
|
||||
"__acc + some_column + some_other_column",
|
||||
"array_set_add(__acc, double_array)",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"long_array_expr",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0",
|
||||
"<LONG>[]",
|
||||
"__acc + some_column + some_other_column",
|
||||
"array_set_add(__acc, long_array_expr)",
|
||||
null,
|
||||
null,
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"string_array_expr_finalized",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"''",
|
||||
"<STRING>[]",
|
||||
"concat(__acc, some_column, some_other_column)",
|
||||
"array_set_add(__acc, string_array_expr)",
|
||||
null,
|
||||
"array_to_string(o, ';')",
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"double_array_expr_finalized",
|
||||
ImmutableSet.of("some_column", "some_other_column_expr"),
|
||||
null,
|
||||
"0.0",
|
||||
"<DOUBLE>[]",
|
||||
"__acc + some_column + some_other_column",
|
||||
"array_set_add(__acc, double_array)",
|
||||
null,
|
||||
"array_to_string(o, ';')",
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"long_array_expr_finalized",
|
||||
ImmutableSet.of("some_column", "some_other_column"),
|
||||
null,
|
||||
"0",
|
||||
"<LONG>[]",
|
||||
"__acc + some_column + some_other_column",
|
||||
"array_set_add(__acc, long_array_expr)",
|
||||
null,
|
||||
"fold((x, acc) -> x + acc, o, 0)",
|
||||
new HumanReadableBytes(2048),
|
||||
TestExprMacroTable.INSTANCE
|
||||
)
|
||||
)
|
||||
.postAggregators(
|
||||
new FieldAccessPostAggregator("string-array-expr-access", "string_array_expr_finalized"),
|
||||
new FinalizingFieldAccessPostAggregator("string-array-expr-finalize", "string_array_expr_finalized"),
|
||||
new FieldAccessPostAggregator("double-array-expr-access", "double_array_expr_finalized"),
|
||||
new FinalizingFieldAccessPostAggregator("double-array-expr-finalize", "double_array_expr_finalized"),
|
||||
new FieldAccessPostAggregator("long-array-expr-access", "long_array_expr_finalized"),
|
||||
new FinalizingFieldAccessPostAggregator("long-array-expr-finalize", "long_array_expr_finalized")
|
||||
)
|
||||
.build();
|
||||
|
||||
Assert.assertEquals(
|
||||
RowSignature.builder()
|
||||
.addTimeColumn()
|
||||
.add("string_expr", ValueType.STRING)
|
||||
.add("double_expr", ValueType.DOUBLE)
|
||||
.add("long_expr", ValueType.LONG)
|
||||
.add("string_array_expr", ValueType.STRING_ARRAY)
|
||||
// type does not equal finalized type. (combining factory type does equal finalized type,
|
||||
// but this signature doesn't use combining factory)
|
||||
.add("double_array_expr", null)
|
||||
// type does not equal finalized type. (combining factory type does equal finalized type,
|
||||
// but this signature doesn't use combining factory)
|
||||
.add("long_array_expr", null)
|
||||
// string because fold type equals finalized type, even though merge type is array
|
||||
.add("string_array_expr_finalized", ValueType.STRING)
|
||||
// type does not equal finalized type. (combining factory type does equal finalized type,
|
||||
// but this signature doesn't use combining factory)
|
||||
.add("double_array_expr_finalized", null)
|
||||
// long because fold type equals finalized type, even though merge type is array
|
||||
.add("long_array_expr_finalized", ValueType.LONG)
|
||||
// fold type is string
|
||||
.add("string-array-expr-access", ValueType.STRING)
|
||||
// finalized type is string
|
||||
.add("string-array-expr-finalize", ValueType.STRING)
|
||||
// double because fold type is double
|
||||
.add("double-array-expr-access", ValueType.DOUBLE)
|
||||
// string because finalize type is string
|
||||
.add("double-array-expr-finalize", ValueType.STRING)
|
||||
// long because fold type is long
|
||||
.add("long-array-expr-access", ValueType.LONG)
|
||||
// finalized type is long
|
||||
.add("long-array-expr-finalize", ValueType.LONG)
|
||||
.build(),
|
||||
new TimeseriesQueryQueryToolChest().resultArraySignature(query)
|
||||
);
|
||||
}
|
||||
}
|
|
@ -25,6 +25,7 @@ import com.google.common.base.Supplier;
|
|||
import com.google.common.base.Suppliers;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Iterables;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Ordering;
|
||||
|
@ -67,6 +68,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
|
|||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FloatSumAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.JavaScriptAggregatorFactory;
|
||||
|
@ -11138,6 +11140,704 @@ public class GroupByQueryRunnerTest extends InitializedNullHandlingTest
|
|||
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGroupByWithExpressionAggregator()
|
||||
{
|
||||
// expression agg not yet vectorized
|
||||
cannotVectorize();
|
||||
GroupByQuery query = makeQueryBuilder()
|
||||
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
|
||||
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
|
||||
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
|
||||
.setAggregatorSpecs(
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"rows",
|
||||
Collections.emptySet(),
|
||||
null,
|
||||
"0",
|
||||
null,
|
||||
"__acc + 1",
|
||||
"__acc + rows",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"idx",
|
||||
ImmutableSet.of("index"),
|
||||
null,
|
||||
"0.0",
|
||||
null,
|
||||
"__acc + index",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
)
|
||||
)
|
||||
.setGranularity(QueryRunnerTestHelper.DAY_GRAN)
|
||||
.build();
|
||||
|
||||
List<ResultRow> expectedResults = Arrays.asList(
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"automotive",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
135.88510131835938d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"business",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
118.57034
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"entertainment",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
158.747224
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"health",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
120.134704
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"mezzanine",
|
||||
"rows",
|
||||
3L,
|
||||
"idx",
|
||||
2871.8866900000003d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"news",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
121.58358d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"premium",
|
||||
"rows",
|
||||
3L,
|
||||
"idx",
|
||||
2900.798647d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"technology",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
78.622547d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"travel",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
119.922742d
|
||||
),
|
||||
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"automotive",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
147.42593d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"business",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
112.987027d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"entertainment",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
166.016049d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"health",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
113.446008d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"mezzanine",
|
||||
"rows",
|
||||
3L,
|
||||
"idx",
|
||||
2448.830613d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"news",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
114.290141d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"premium",
|
||||
"rows",
|
||||
3L,
|
||||
"idx",
|
||||
2506.415148d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"technology",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
97.387433d
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"travel",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
126.411364d
|
||||
)
|
||||
);
|
||||
|
||||
Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
|
||||
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGroupByWithExpressionAggregatorWithArrays()
|
||||
{
|
||||
// expression agg not yet vectorized
|
||||
cannotVectorize();
|
||||
|
||||
// array types don't work with group by v1
|
||||
if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) {
|
||||
expectedException.expect(IllegalStateException.class);
|
||||
expectedException.expectMessage("Unable to handle type[STRING_ARRAY] for AggregatorFactory[class org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory]");
|
||||
}
|
||||
|
||||
GroupByQuery query = makeQueryBuilder()
|
||||
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
|
||||
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
|
||||
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
|
||||
.setAggregatorSpecs(
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"rows",
|
||||
Collections.emptySet(),
|
||||
null,
|
||||
"0",
|
||||
null,
|
||||
"__acc + 1",
|
||||
"__acc + rows",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"idx",
|
||||
ImmutableSet.of("index"),
|
||||
null,
|
||||
"0.0",
|
||||
null,
|
||||
"__acc + index",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"array_agg_distinct",
|
||||
ImmutableSet.of(QueryRunnerTestHelper.MARKET_DIMENSION),
|
||||
"acc",
|
||||
"[]",
|
||||
null,
|
||||
"array_set_add(acc, market)",
|
||||
"array_set_add_all(acc, array_agg_distinct)",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
)
|
||||
)
|
||||
.setGranularity(QueryRunnerTestHelper.DAY_GRAN)
|
||||
.build();
|
||||
|
||||
List<ResultRow> expectedResults = Arrays.asList(
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"automotive",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
135.88510131835938d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"business",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
118.57034,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"entertainment",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
158.747224,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"health",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
120.134704,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"mezzanine",
|
||||
"rows",
|
||||
3L,
|
||||
"idx",
|
||||
2871.8866900000003d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"upfront", "spot", "total_market"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"news",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
121.58358d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"premium",
|
||||
"rows",
|
||||
3L,
|
||||
"idx",
|
||||
2900.798647d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"upfront", "spot", "total_market"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"technology",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
78.622547d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"travel",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
119.922742d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"automotive",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
147.42593d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"business",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
112.987027d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"entertainment",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
166.016049d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"health",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
113.446008d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"mezzanine",
|
||||
"rows",
|
||||
3L,
|
||||
"idx",
|
||||
2448.830613d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"upfront", "spot", "total_market"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"news",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
114.290141d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"premium",
|
||||
"rows",
|
||||
3L,
|
||||
"idx",
|
||||
2506.415148d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"upfront", "spot", "total_market"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"technology",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
97.387433d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"travel",
|
||||
"rows",
|
||||
1L,
|
||||
"idx",
|
||||
126.411364d,
|
||||
"array_agg_distinct",
|
||||
new String[] {"spot"}
|
||||
)
|
||||
);
|
||||
|
||||
Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
|
||||
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGroupByExpressionAggregatorArrayMultiValue()
|
||||
{
|
||||
// expression agg not yet vectorized
|
||||
cannotVectorize();
|
||||
|
||||
// array types don't work with group by v1
|
||||
if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) {
|
||||
expectedException.expect(IllegalStateException.class);
|
||||
expectedException.expectMessage("Unable to handle type[STRING_ARRAY] for AggregatorFactory[class org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory]");
|
||||
}
|
||||
|
||||
GroupByQuery query = makeQueryBuilder()
|
||||
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
|
||||
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
|
||||
.setDimensions(new DefaultDimensionSpec("quality", "alias"))
|
||||
.setAggregatorSpecs(
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"array_agg_distinct",
|
||||
ImmutableSet.of(QueryRunnerTestHelper.PLACEMENTISH_DIMENSION),
|
||||
"acc",
|
||||
"[]",
|
||||
null,
|
||||
"array_set_add(acc, placementish)",
|
||||
"array_set_add_all(acc, array_agg_distinct)",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
)
|
||||
)
|
||||
.setGranularity(QueryRunnerTestHelper.DAY_GRAN)
|
||||
.build();
|
||||
|
||||
List<ResultRow> expectedResults = Arrays.asList(
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"automotive",
|
||||
"array_agg_distinct",
|
||||
new String[] {"a", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"business",
|
||||
"array_agg_distinct",
|
||||
new String[] {"b", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"entertainment",
|
||||
"array_agg_distinct",
|
||||
new String[] {"e", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"health",
|
||||
"array_agg_distinct",
|
||||
new String[] {"h", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"mezzanine",
|
||||
"array_agg_distinct",
|
||||
new String[] {"m", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"news",
|
||||
"array_agg_distinct",
|
||||
new String[] {"n", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"premium",
|
||||
"array_agg_distinct",
|
||||
new String[] {"p", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"technology",
|
||||
"array_agg_distinct",
|
||||
new String[] {"t", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-01",
|
||||
"alias",
|
||||
"travel",
|
||||
"array_agg_distinct",
|
||||
new String[] {"t", "preferred"}
|
||||
),
|
||||
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"automotive",
|
||||
"array_agg_distinct",
|
||||
new String[] {"a", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"business",
|
||||
"array_agg_distinct",
|
||||
new String[] {"b", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"entertainment",
|
||||
"array_agg_distinct",
|
||||
new String[] {"e", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"health",
|
||||
"array_agg_distinct",
|
||||
new String[] {"h", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"mezzanine",
|
||||
"array_agg_distinct",
|
||||
new String[] {"m", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"news",
|
||||
"array_agg_distinct",
|
||||
new String[] {"n", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"premium",
|
||||
"array_agg_distinct",
|
||||
new String[] {"p", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"technology",
|
||||
"array_agg_distinct",
|
||||
new String[] {"t", "preferred"}
|
||||
),
|
||||
makeRow(
|
||||
query,
|
||||
"2011-04-02",
|
||||
"alias",
|
||||
"travel",
|
||||
"array_agg_distinct",
|
||||
new String[] {"t", "preferred"}
|
||||
)
|
||||
);
|
||||
|
||||
Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
|
||||
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
|
||||
}
|
||||
|
||||
private static ResultRow makeRow(final GroupByQuery query, final String timestamp, final Object... vals)
|
||||
{
|
||||
return GroupByQueryRunnerTestHelper.createExpectedRow(query, timestamp, vals);
|
||||
|
|
|
@ -303,32 +303,4 @@ public class GroupByTimeseriesQueryRunnerTest extends TimeseriesQueryRunnerTest
|
|||
// Skip this test because the timeseries test expects a day that doesn't have a filter match to be filled in,
|
||||
// but group by just doesn't return a value if the filter doesn't match.
|
||||
}
|
||||
|
||||
@Override
|
||||
public void testTimeseriesWithTimestampResultFieldContextForArrayResponse()
|
||||
{
|
||||
// Cannot vectorize with an expression virtual column
|
||||
if (!vectorize) {
|
||||
super.testTimeseriesWithTimestampResultFieldContextForArrayResponse();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void testTimeseriesWithTimestampResultFieldContextForMapResponse()
|
||||
{
|
||||
// Cannot vectorize with an expression virtual column
|
||||
if (!vectorize) {
|
||||
super.testTimeseriesWithTimestampResultFieldContextForMapResponse();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@Test
|
||||
public void testTimeseriesWithPostAggregatorReferencingTimestampResultField()
|
||||
{
|
||||
// Cannot vectorize with an expression virtual column
|
||||
if (!vectorize) {
|
||||
super.testTimeseriesWithPostAggregatorReferencingTimestampResultField();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ package org.apache.druid.query.timeseries;
|
|||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Iterables;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.primitives.Doubles;
|
||||
|
@ -44,6 +45,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
|
|||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory;
|
||||
|
@ -2935,6 +2937,104 @@ public class TimeseriesQueryRunnerTest extends InitializedNullHandlingTest
|
|||
assertExpectedResults(expectedResults, results);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTimeseriesWithExpressionAggregator()
|
||||
{
|
||||
// expression agg cannot vectorize
|
||||
cannotVectorize();
|
||||
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(QueryRunnerTestHelper.DATA_SOURCE)
|
||||
.granularity(QueryRunnerTestHelper.DAY_GRAN)
|
||||
.intervals(QueryRunnerTestHelper.FIRST_TO_THIRD)
|
||||
.aggregators(
|
||||
Arrays.asList(
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"diy_count",
|
||||
ImmutableSet.of(),
|
||||
null,
|
||||
"0",
|
||||
null,
|
||||
"__acc + 1",
|
||||
"__acc + diy_count",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"diy_sum",
|
||||
ImmutableSet.of("index"),
|
||||
null,
|
||||
"0.0",
|
||||
null,
|
||||
"__acc + index",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"diy_decomposed_sum",
|
||||
ImmutableSet.of("index"),
|
||||
null,
|
||||
"0.0",
|
||||
"<DOUBLE>[]",
|
||||
"__acc + index",
|
||||
"array_concat(__acc, diy_decomposed_sum)",
|
||||
null,
|
||||
"fold((x, acc) -> x + acc, o, 0.0)",
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"array_agg_distinct",
|
||||
ImmutableSet.of(QueryRunnerTestHelper.MARKET_DIMENSION),
|
||||
"acc",
|
||||
"[]",
|
||||
null,
|
||||
"array_set_add(acc, market)",
|
||||
"array_set_add_all(acc, array_agg_distinct)",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
)
|
||||
)
|
||||
)
|
||||
.descending(descending)
|
||||
.context(makeContext())
|
||||
.build();
|
||||
|
||||
List<Result<TimeseriesResultValue>> expectedResults = Arrays.asList(
|
||||
new Result<>(
|
||||
DateTimes.of("2011-04-01"),
|
||||
new TimeseriesResultValue(
|
||||
ImmutableMap.of(
|
||||
"diy_count", 13L,
|
||||
"diy_sum", 6626.151569,
|
||||
"diy_decomposed_sum", 6626.151569,
|
||||
"array_agg_distinct", new String[] {"upfront", "spot", "total_market"}
|
||||
)
|
||||
)
|
||||
),
|
||||
new Result<>(
|
||||
DateTimes.of("2011-04-02"),
|
||||
new TimeseriesResultValue(
|
||||
ImmutableMap.of(
|
||||
"diy_count", 13L,
|
||||
"diy_sum", 5833.209718,
|
||||
"diy_decomposed_sum", 5833.209718,
|
||||
"array_agg_distinct", new String[] {"upfront", "spot", "total_market"}
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
Iterable<Result<TimeseriesResultValue>> results = runner.run(QueryPlus.wrap(query)).toList();
|
||||
assertExpectedResults(expectedResults, results);
|
||||
}
|
||||
|
||||
private Map<String, Object> makeContext()
|
||||
{
|
||||
return makeContext(ImmutableMap.of());
|
||||
|
|
|
@ -22,6 +22,7 @@ package org.apache.druid.query.topn;
|
|||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Iterables;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
@ -52,6 +53,7 @@ import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
|||
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.FloatMinAggregatorFactory;
|
||||
|
@ -5964,6 +5966,110 @@ public class TopNQueryRunnerTest extends InitializedNullHandlingTest
|
|||
assertExpectedResults(expectedResults, query);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExpressionAggregator()
|
||||
{
|
||||
// sorted by array length of array_agg_distinct
|
||||
TopNQuery query = new TopNQueryBuilder()
|
||||
.dataSource(QueryRunnerTestHelper.DATA_SOURCE)
|
||||
.granularity(QueryRunnerTestHelper.ALL_GRAN)
|
||||
.dimension(QueryRunnerTestHelper.MARKET_DIMENSION)
|
||||
.metric("array_agg_distinct")
|
||||
.threshold(4)
|
||||
.intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
|
||||
.aggregators(
|
||||
Lists.newArrayList(
|
||||
Arrays.asList(
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"diy_count",
|
||||
Collections.emptySet(),
|
||||
null,
|
||||
"0",
|
||||
null,
|
||||
"__acc + 1",
|
||||
"__acc + diy_count",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"diy_sum",
|
||||
ImmutableSet.of("index"),
|
||||
null,
|
||||
"0.0",
|
||||
null,
|
||||
"__acc + index",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"diy_decomposed_sum",
|
||||
ImmutableSet.of("index"),
|
||||
null,
|
||||
"0.0",
|
||||
"<DOUBLE>[]",
|
||||
"__acc + index",
|
||||
"array_concat(__acc, diy_decomposed_sum)",
|
||||
null,
|
||||
"fold((x, acc) -> x + acc, o, 0.0)",
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
),
|
||||
new ExpressionLambdaAggregatorFactory(
|
||||
"array_agg_distinct",
|
||||
ImmutableSet.of(QueryRunnerTestHelper.QUALITY_DIMENSION),
|
||||
"acc",
|
||||
"[]",
|
||||
null,
|
||||
"array_set_add(acc, quality)",
|
||||
"array_set_add_all(acc, array_agg_distinct)",
|
||||
"if(array_length(o1) > array_length(o2), 1, if (array_length(o1) == array_length(o2), 0, -1))",
|
||||
null,
|
||||
null,
|
||||
TestExprMacroTable.INSTANCE
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
.build();
|
||||
|
||||
List<Result<TopNResultValue>> expectedResults = Collections.singletonList(
|
||||
new Result<>(
|
||||
DateTimes.of("2011-01-12T00:00:00.000Z"),
|
||||
new TopNResultValue(
|
||||
Arrays.<Map<String, Object>>asList(
|
||||
ImmutableMap.<String, Object>builder()
|
||||
.put(QueryRunnerTestHelper.MARKET_DIMENSION, "spot")
|
||||
.put("diy_count", 837L)
|
||||
.put("diy_sum", 95606.57232284546D)
|
||||
.put("diy_decomposed_sum", 95606.57232284546D)
|
||||
.put("array_agg_distinct", new String[]{"mezzanine", "news", "premium", "business", "entertainment", "health", "technology", "automotive", "travel"})
|
||||
.build(),
|
||||
ImmutableMap.<String, Object>builder()
|
||||
.put(QueryRunnerTestHelper.MARKET_DIMENSION, "total_market")
|
||||
.put("diy_count", 186L)
|
||||
.put("diy_sum", 215679.82879638672D)
|
||||
.put("diy_decomposed_sum", 215679.82879638672D)
|
||||
.put("array_agg_distinct", new String[]{"mezzanine", "premium"})
|
||||
.build(),
|
||||
ImmutableMap.<String, Object>builder()
|
||||
.put(QueryRunnerTestHelper.MARKET_DIMENSION, "upfront")
|
||||
.put("diy_count", 186L)
|
||||
.put("diy_sum", 192046.1060180664D)
|
||||
.put("diy_decomposed_sum", 192046.1060180664D)
|
||||
.put("array_agg_distinct", new String[]{"mezzanine", "premium"})
|
||||
.build()
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
assertExpectedResults(expectedResults, query);
|
||||
}
|
||||
|
||||
private static Map<String, Object> makeRowWithNulls(
|
||||
String dimName,
|
||||
@Nullable Object dimValue,
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.apache.druid.guice.GuiceAnnotationIntrospector;
|
|||
import org.apache.druid.jackson.DefaultObjectMapper;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.java.util.common.guava.Sequence;
|
||||
import org.apache.druid.math.expr.ExprEval;
|
||||
import org.apache.druid.math.expr.ExprMacroTable;
|
||||
import org.apache.druid.query.Result;
|
||||
import org.apache.druid.query.expression.TestExprMacroTable;
|
||||
|
@ -352,7 +353,9 @@ public class TestHelper
|
|||
final Object expectedValue = expectedMap.get(key);
|
||||
final Object actualValue = actualMap.get(key);
|
||||
|
||||
if (expectedValue instanceof Float || expectedValue instanceof Double) {
|
||||
if (expectedValue != null && expectedValue.getClass().isArray()) {
|
||||
Assert.assertArrayEquals((Object[]) expectedValue, (Object[]) actualValue);
|
||||
} else if (expectedValue instanceof Float || expectedValue instanceof Double) {
|
||||
Assert.assertEquals(
|
||||
StringUtils.format("%s: key[%s]", msg, key),
|
||||
((Number) expectedValue).doubleValue(),
|
||||
|
@ -382,7 +385,23 @@ public class TestHelper
|
|||
final Object expectedValue = expected.get(i);
|
||||
final Object actualValue = actual.get(i);
|
||||
|
||||
if (expectedValue instanceof Float || expectedValue instanceof Double) {
|
||||
|
||||
if (expectedValue != null && expectedValue.getClass().isArray()) {
|
||||
// spilled results will materialize into lists, coerce them back to arrays if we expected arrays
|
||||
if (actualValue instanceof List) {
|
||||
Assert.assertEquals(
|
||||
message,
|
||||
(Object[]) expectedValue,
|
||||
(Object[]) ExprEval.coerceListToArray((List) actualValue, true)
|
||||
);
|
||||
} else {
|
||||
Assert.assertArrayEquals(
|
||||
message,
|
||||
(Object[]) expectedValue,
|
||||
(Object[]) actualValue
|
||||
);
|
||||
}
|
||||
} else if (expectedValue instanceof Float || expectedValue instanceof Double) {
|
||||
Assert.assertEquals(
|
||||
message,
|
||||
((Number) expectedValue).doubleValue(),
|
||||
|
|
|
@ -243,77 +243,6 @@ public class ExpressionSelectorsTest extends InitializedNullHandlingTest
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_coerceListToArray()
|
||||
{
|
||||
List<Long> longList = ImmutableList.of(1L, 2L, 3L);
|
||||
Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExpressionSelectors.coerceListToArray(longList));
|
||||
|
||||
List<Integer> intList = ImmutableList.of(1, 2, 3);
|
||||
Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExpressionSelectors.coerceListToArray(intList));
|
||||
|
||||
List<Float> floatList = ImmutableList.of(1.0f, 2.0f, 3.0f);
|
||||
Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExpressionSelectors.coerceListToArray(floatList));
|
||||
|
||||
List<Double> doubleList = ImmutableList.of(1.0, 2.0, 3.0);
|
||||
Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExpressionSelectors.coerceListToArray(doubleList));
|
||||
|
||||
List<String> stringList = ImmutableList.of("a", "b", "c");
|
||||
Assert.assertArrayEquals(new String[]{"a", "b", "c"}, (String[]) ExpressionSelectors.coerceListToArray(stringList));
|
||||
|
||||
List<String> withNulls = new ArrayList<>();
|
||||
withNulls.add("a");
|
||||
withNulls.add(null);
|
||||
withNulls.add("c");
|
||||
Assert.assertArrayEquals(new String[]{"a", null, "c"}, (String[]) ExpressionSelectors.coerceListToArray(withNulls));
|
||||
|
||||
List<Long> withNumberNulls = new ArrayList<>();
|
||||
withNumberNulls.add(1L);
|
||||
withNumberNulls.add(null);
|
||||
withNumberNulls.add(3L);
|
||||
|
||||
Assert.assertArrayEquals(new Long[]{1L, null, 3L}, (Long[]) ExpressionSelectors.coerceListToArray(withNumberNulls));
|
||||
|
||||
List<Object> withStringMix = ImmutableList.of(1L, "b", 3L);
|
||||
Assert.assertArrayEquals(
|
||||
new String[]{"1", "b", "3"},
|
||||
(String[]) ExpressionSelectors.coerceListToArray(withStringMix)
|
||||
);
|
||||
|
||||
List<Number> withIntsAndLongs = ImmutableList.of(1, 2L, 3);
|
||||
Assert.assertArrayEquals(
|
||||
new Long[]{1L, 2L, 3L},
|
||||
(Long[]) ExpressionSelectors.coerceListToArray(withIntsAndLongs)
|
||||
);
|
||||
|
||||
List<Number> withFloatsAndLongs = ImmutableList.of(1, 2L, 3.0f);
|
||||
Assert.assertArrayEquals(
|
||||
new Double[]{1.0, 2.0, 3.0},
|
||||
(Double[]) ExpressionSelectors.coerceListToArray(withFloatsAndLongs)
|
||||
);
|
||||
|
||||
List<Number> withDoublesAndLongs = ImmutableList.of(1, 2L, 3.0);
|
||||
Assert.assertArrayEquals(
|
||||
new Double[]{1.0, 2.0, 3.0},
|
||||
(Double[]) ExpressionSelectors.coerceListToArray(withDoublesAndLongs)
|
||||
);
|
||||
|
||||
List<Number> withFloatsAndDoubles = ImmutableList.of(1L, 2.0f, 3.0);
|
||||
Assert.assertArrayEquals(
|
||||
new Double[]{1.0, 2.0, 3.0},
|
||||
(Double[]) ExpressionSelectors.coerceListToArray(withFloatsAndDoubles)
|
||||
);
|
||||
|
||||
List<String> withAllNulls = new ArrayList<>();
|
||||
withAllNulls.add(null);
|
||||
withAllNulls.add(null);
|
||||
withAllNulls.add(null);
|
||||
Assert.assertArrayEquals(
|
||||
new String[]{null, null, null},
|
||||
(String[]) ExpressionSelectors.coerceListToArray(withAllNulls)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_coerceEvalToSelectorObject()
|
||||
{
|
||||
|
|
|
@ -1100,6 +1100,8 @@ arr1
|
|||
arr2
|
||||
array_append
|
||||
array_concat
|
||||
array_set_add
|
||||
array_set_add_all
|
||||
array_contains
|
||||
array_length
|
||||
array_offset
|
||||
|
|
Loading…
Reference in New Issue