expression aggregator (#11104)

* add experimental expression aggregator

* add test

* fix lgtm

* fix test

* adjust test

* use not null constant

* array_set_concat docs

* add equals and hashcode and tostring

* fix it

* spelling

* do multi-value magic for expression agg, more javadocs, tests

* formatting

* fix inspection

* more better

* nullable
This commit is contained in:
Clint Wylie 2021-04-22 18:30:16 -07:00 committed by GitHub
parent 49a9c3ffb7
commit 57ff1f9cdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 3511 additions and 450 deletions

View File

@ -183,9 +183,9 @@ class NullLongExpr extends ConstantExpr<Long>
class LongArrayExpr extends ConstantExpr<Long[]>
{
LongArrayExpr(Long[] value)
LongArrayExpr(@Nullable Long[] value)
{
super(ExprType.LONG_ARRAY, Preconditions.checkNotNull(value, "value"));
super(ExprType.LONG_ARRAY, value);
}
@Override
@ -320,9 +320,9 @@ class NullDoubleExpr extends ConstantExpr<Double>
class DoubleArrayExpr extends ConstantExpr<Double[]>
{
DoubleArrayExpr(Double[] value)
DoubleArrayExpr(@Nullable Double[] value)
{
super(ExprType.DOUBLE_ARRAY, Preconditions.checkNotNull(value, "value"));
super(ExprType.DOUBLE_ARRAY, value);
}
@Override
@ -426,9 +426,9 @@ class StringExpr extends ConstantExpr<String>
class StringArrayExpr extends ConstantExpr<String[]>
{
StringArrayExpr(String[] value)
StringArrayExpr(@Nullable String[] value)
{
super(ExprType.STRING_ARRAY, Preconditions.checkNotNull(value, "value"));
super(ExprType.STRING_ARRAY, value);
}
@Override

View File

@ -23,15 +23,357 @@ import com.google.common.primitives.Doubles;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.common.guava.GuavaUtils;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.UOE;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
/**
* Generic result holder for evaluated {@link Expr} containing the value and {@link ExprType} of the value
*/
public abstract class ExprEval<T>
{
private static final int NULL_LENGTH = -1;
/**
 * Deserialize an expression result stored in a bytebuffer, e.g. for an agg. This reads the layout written by
 * {@link #serialize(ByteBuffer, int, ExprEval, int)}: a single type byte (see {@link ExprType#getId()}) followed by a
 * type specific payload. All reads use absolute offsets and the position of the buffer is left unchanged.
 *
 * This should be refactored to be consolidated with some of the standard type handling of aggregators probably
 *
 * @param buffer   buffer containing the serialized value
 * @param position absolute offset in the buffer of the type byte of the serialized value
 * @return the deserialized value
 * @throws ISE if the byte at {@code position} does not map to any {@link ExprType}
 */
public static ExprEval deserialize(ByteBuffer buffer, int position)
{
  // | expression type (byte) | expression bytes |
  final byte typeId = buffer.get(position);
  final ExprType type = ExprType.fromByte(typeId);
  if (type == null) {
    // switching on a null enum would fail with a bare NullPointerException, report the unexpected byte instead
    throw new ISE("Unable to deserialize expression, unknown type id [%s] at position [%s]", typeId, position);
  }
  int offset = position + 1;
  switch (type) {
    case LONG:
      // | expression type (byte) | is null (byte) | long bytes |
      if (buffer.get(offset++) == NullHandling.IS_NOT_NULL_BYTE) {
        return of(buffer.getLong(offset));
      }
      return ofLong(null);
    case DOUBLE:
      // | expression type (byte) | is null (byte) | double bytes |
      if (buffer.get(offset++) == NullHandling.IS_NOT_NULL_BYTE) {
        return of(buffer.getDouble(offset));
      }
      return ofDouble(null);
    case STRING:
      // | expression type (byte) | string length (int) | string bytes |
      final int length = buffer.getInt(offset);
      if (length < 0) {
        // a negative length marks a null string
        return of(null);
      }
      return of(StringUtils.fromUtf8(readBytes(buffer, offset + Integer.BYTES, length)));
    case LONG_ARRAY:
      // | expression type (byte) | array length (int) | array bytes |
      final int longArrayLength = buffer.getInt(offset);
      offset += Integer.BYTES;
      if (longArrayLength < 0) {
        // a negative length marks a null array
        return ofLongArray(null);
      }
      final Long[] longs = new Long[longArrayLength];
      for (int i = 0; i < longArrayLength; i++) {
        final byte isNull = buffer.get(offset);
        offset += Byte.BYTES;
        if (isNull == NullHandling.IS_NOT_NULL_BYTE) {
          // | is null (byte) | long bytes |
          longs[i] = buffer.getLong(offset);
          offset += Long.BYTES;
        } else {
          // | is null (byte) |
          longs[i] = null;
        }
      }
      return ofLongArray(longs);
    case DOUBLE_ARRAY:
      // | expression type (byte) | array length (int) | array bytes |
      final int doubleArrayLength = buffer.getInt(offset);
      offset += Integer.BYTES;
      if (doubleArrayLength < 0) {
        // a negative length marks a null array
        return ofDoubleArray(null);
      }
      final Double[] doubles = new Double[doubleArrayLength];
      for (int i = 0; i < doubleArrayLength; i++) {
        final byte isNull = buffer.get(offset);
        offset += Byte.BYTES;
        if (isNull == NullHandling.IS_NOT_NULL_BYTE) {
          // | is null (byte) | double bytes |
          doubles[i] = buffer.getDouble(offset);
          offset += Double.BYTES;
        } else {
          // | is null (byte) |
          doubles[i] = null;
        }
      }
      return ofDoubleArray(doubles);
    case STRING_ARRAY:
      // | expression type (byte) | array length (int) | array bytes |
      final int stringArrayLength = buffer.getInt(offset);
      offset += Integer.BYTES;
      if (stringArrayLength < 0) {
        // a negative length marks a null array
        return ofStringArray(null);
      }
      final String[] stringArray = new String[stringArrayLength];
      for (int i = 0; i < stringArrayLength; i++) {
        final int stringElementLength = buffer.getInt(offset);
        offset += Integer.BYTES;
        if (stringElementLength < 0) {
          // | string length (int) |
          stringArray[i] = null;
        } else {
          // | string length (int) | string bytes |
          stringArray[i] = StringUtils.fromUtf8(readBytes(buffer, offset, stringElementLength));
          offset += stringElementLength;
        }
      }
      return ofStringArray(stringArray);
    default:
      throw new UOE("how can this be?");
  }
}

/**
 * Copy {@code length} bytes starting at the absolute {@code offset} out of the buffer. The bulk read requires
 * temporarily moving the position of the buffer, which is restored afterwards even if the read fails.
 */
private static byte[] readBytes(ByteBuffer buffer, int offset, int length)
{
  final byte[] bytes = new byte[length];
  final int oldPosition = buffer.position();
  try {
    buffer.position(offset);
    buffer.get(bytes, 0, length);
  }
  finally {
    buffer.position(oldPosition);
  }
  return bytes;
}
/**
 * Write an expression result to a bytebuffer, throwing an {@link ISE} if the data exceeds a maximum size. Primitive
 * numeric types are not validated to be lower than max size, so it is expected to be at least 10 bytes. Callers
 * of this method should enforce this themselves (instead of doing it here, which might be done every row)
 *
 * The value is written using absolute offsets as a single type byte (see {@link ExprType#getId()}) followed by a
 * type specific payload, and the position of the buffer is left unchanged. The type byte is written before any size
 * validation happens, and string array elements are validated one at a time, so part of the value may already have
 * been written when the size check fails.
 *
 * This should be refactored to be consolidated with some of the standard type handling of aggregators probably
 *
 * @param buffer       buffer to write the value to
 * @param position     absolute offset in the buffer to start writing at
 * @param eval         value to write
 * @param maxSizeBytes maximum number of bytes, including the type byte, that a string or array value may occupy
 * @throws ISE if a string or array value requires more than {@code maxSizeBytes} bytes
 */
public static void serialize(ByteBuffer buffer, int position, ExprEval<?> eval, int maxSizeBytes)
{
  int offset = position;
  buffer.put(offset++, eval.type().getId());
  switch (eval.type()) {
    case LONG:
      // | expression type (byte) | is null (byte) | long bytes |
      if (eval.isNumericNull()) {
        buffer.put(offset, NullHandling.IS_NULL_BYTE);
      } else {
        buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
        buffer.putLong(offset, eval.asLong());
      }
      break;
    case DOUBLE:
      // | expression type (byte) | is null (byte) | double bytes |
      if (eval.isNumericNull()) {
        buffer.put(offset, NullHandling.IS_NULL_BYTE);
      } else {
        buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
        buffer.putDouble(offset, eval.asDouble());
      }
      break;
    case STRING:
      final byte[] stringBytes = StringUtils.toUtf8Nullable(eval.asString());
      if (stringBytes != null) {
        // | expression type (byte) | string length (int) | string bytes |
        checkMaxBytes(eval.type(), 1 + Integer.BYTES + stringBytes.length, maxSizeBytes);
        buffer.putInt(offset, stringBytes.length);
        writeBytes(buffer, offset + Integer.BYTES, stringBytes);
      } else {
        // | expression type (byte) | string length (int) |
        checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
        buffer.putInt(offset, NULL_LENGTH);
      }
      break;
    case LONG_ARRAY:
      final Long[] longs = eval.asLongArray();
      if (longs == null) {
        // | expression type (byte) | array length (int) |
        checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
        buffer.putInt(offset, NULL_LENGTH);
      } else {
        // | expression type (byte) | array length (int) | array bytes |
        checkMaxBytes(eval.type(), arraySizeBytes(longs, Long.BYTES), maxSizeBytes);
        buffer.putInt(offset, longs.length);
        offset += Integer.BYTES;
        for (Long aLong : longs) {
          if (aLong != null) {
            // | is null (byte) | long bytes |
            buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
            buffer.putLong(offset, aLong);
            offset += Long.BYTES;
          } else {
            // | is null (byte) |
            buffer.put(offset++, NullHandling.IS_NULL_BYTE);
          }
        }
      }
      break;
    case DOUBLE_ARRAY:
      final Double[] doubles = eval.asDoubleArray();
      if (doubles == null) {
        // | expression type (byte) | array length (int) |
        checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
        buffer.putInt(offset, NULL_LENGTH);
      } else {
        // | expression type (byte) | array length (int) | array bytes |
        checkMaxBytes(eval.type(), arraySizeBytes(doubles, Double.BYTES), maxSizeBytes);
        buffer.putInt(offset, doubles.length);
        offset += Integer.BYTES;
        for (Double aDouble : doubles) {
          if (aDouble != null) {
            // | is null (byte) | double bytes |
            buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
            buffer.putDouble(offset, aDouble);
            offset += Double.BYTES;
          } else {
            // | is null (byte) |
            buffer.put(offset++, NullHandling.IS_NULL_BYTE);
          }
        }
      }
      break;
    case STRING_ARRAY:
      final String[] strings = eval.asStringArray();
      // the array length is always written, so the header must fit even for null and empty arrays
      int stringArraySizeBytes = 1 + Integer.BYTES;
      checkMaxBytes(eval.type(), stringArraySizeBytes, maxSizeBytes);
      if (strings == null) {
        // | expression type (byte) | array length (int) |
        buffer.putInt(offset, NULL_LENGTH);
      } else {
        // | expression type (byte) | array length (int) | array bytes |
        buffer.putInt(offset, strings.length);
        offset += Integer.BYTES;
        for (String string : strings) {
          if (string == null) {
            // | string length (int) |
            stringArraySizeBytes += Integer.BYTES;
            checkMaxBytes(eval.type(), stringArraySizeBytes, maxSizeBytes);
            buffer.putInt(offset, NULL_LENGTH);
            offset += Integer.BYTES;
          } else {
            // | string length (int) | string bytes |
            final byte[] stringElementBytes = StringUtils.toUtf8(string);
            stringArraySizeBytes += Integer.BYTES + stringElementBytes.length;
            checkMaxBytes(eval.type(), stringArraySizeBytes, maxSizeBytes);
            buffer.putInt(offset, stringElementBytes.length);
            offset += Integer.BYTES;
            writeBytes(buffer, offset, stringElementBytes);
            offset += stringElementBytes.length;
          }
        }
      }
      break;
    default:
      throw new UOE("how can this be?");
  }
}

/**
 * Compute the exact number of bytes needed to write a non-null numeric array: the type byte, the array length, a
 * null marker byte for every element, plus {@code elementSizeBytes} for every element which is not null.
 */
private static int arraySizeBytes(Object[] values, int elementSizeBytes)
{
  int sizeBytes = 1 + Integer.BYTES + (Byte.BYTES * values.length);
  for (Object value : values) {
    if (value != null) {
      sizeBytes += elementSizeBytes;
    }
  }
  return sizeBytes;
}

/**
 * Copy all of {@code bytes} into the buffer starting at the absolute {@code offset}. The bulk write requires
 * temporarily moving the position of the buffer, which is restored afterwards even if the write fails.
 */
private static void writeBytes(ByteBuffer buffer, int offset, byte[] bytes)
{
  final int oldPosition = buffer.position();
  try {
    buffer.position(offset);
    buffer.put(bytes, 0, bytes.length);
  }
  finally {
    buffer.position(oldPosition);
  }
}
/**
 * Ensure that a serialized value of the given type fits within the allowed number of bytes.
 *
 * @throws ISE if {@code sizeBytes} is larger than {@code maxSizeBytes}
 */
private static void checkMaxBytes(ExprType type, int sizeBytes, int maxSizeBytes)
{
  final boolean fits = sizeBytes <= maxSizeBytes;
  if (fits) {
    return;
  }
  throw new ISE("Unable to serialize [%s], size [%s] is larger than max [%s]", type, sizeBytes, maxSizeBytes);
}
/**
 * Converts a List to an appropriate array type, optionally doing some conversion to make multi-valued strings
 * consistent across selector types, which are not consistent in treatment of null, [], and [null].
 *
 * A list with at least one element becomes a {@code Long[]}, {@code Double[]} or {@code String[]}, picked from the
 * common type of its non-null elements, with null elements preserved as null.
 *
 * If homogenizeMultiValueStrings is true, null and [] will be converted to [null], otherwise they are returned
 * without conversion: null stays null and [] becomes an empty {@code Object[]}.
 */
@Nullable
public static Object coerceListToArray(@Nullable List<?> val, boolean homogenizeMultiValueStrings)
{
  final boolean hasElements = val != null && !val.isEmpty();
  if (!hasElements) {
    // null and [] are ambiguous between selector types, so the caller decides how they are treated
    if (homogenizeMultiValueStrings) {
      return new String[]{null};
    }
    if (val == null) {
      return null;
    }
    return val.toArray();
  }

  // with at least 1 element the conversion is unambiguous regardless of the selector
  Class<?> commonType = null;
  for (Object element : val) {
    if (element != null) {
      commonType = convertType(commonType, element.getClass());
    }
  }

  final int size = val.size();
  int index = 0;
  if (commonType == Long.class || commonType == Integer.class) {
    final Long[] longs = new Long[size];
    for (Object element : val) {
      if (element != null) {
        longs[index] = ((Number) element).longValue();
      }
      index++;
    }
    return longs;
  }
  if (commonType == Float.class || commonType == Double.class) {
    final Double[] doubles = new Double[size];
    for (Object element : val) {
      if (element != null) {
        doubles[index] = ((Number) element).doubleValue();
      }
      index++;
    }
    return doubles;
  }
  // everything else, including lists containing only nulls, is treated as strings
  final String[] strings = new String[size];
  for (Object element : val) {
    if (element != null) {
      strings[index] = element.toString();
    }
    index++;
  }
  return strings;
}
/**
 * Find the common type to use between 2 types, useful for choosing the appropriate type for an array given a set
 * of objects with unknown type, following rules similar to Java, our own native Expr, and SQL implicit type
 * conversions. This is used to assist in preparing native java objects for {@link Expr.ObjectBinding} which will
 * later be wrapped in {@link ExprEval} when evaluating {@link IdentifierExpr}.
 *
 * If any type is string, then the result will be string because everything can be converted to a string, but a string
 * cannot be converted to everything.
 *
 * For numbers, integer is the most restrictive type, only chosen if both types are integers. Longs win over integers,
 * floats over longs and integers, and doubles win over everything.
 *
 * @param existing common type found so far, or null if {@code next} is the first type encountered
 * @param next     type of the next object, which must be a {@link Number} type or {@link String}
 * @throws UOE if {@code next} is neither a {@link Number} type nor {@link String}
 */
private static Class convertType(@Nullable Class existing, Class next)
{
  final boolean supported = Number.class.isAssignableFrom(next) || next == String.class;
  if (!supported) {
    throw new UOE("Invalid array expression type: %s", next);
  }
  if (existing == null) {
    // first type seen, nothing to reconcile against yet
    return next;
  }
  if (existing == String.class || next == String.class) {
    // string wins everything
    return String.class;
  }
  if (existing == Integer.class) {
    // all numbers win over Integer
    return next;
  }
  if (existing == Float.class) {
    // only doubles win over floats
    return next == Double.class ? next : existing;
  }
  if (existing == Long.class) {
    // long beats int, while double and float win over longs
    return next == Integer.class ? existing : next;
  }
  // otherwise double
  return Double.class;
}
public static ExprEval ofLong(@Nullable Number longValue)
{
return new LongExprEval(longValue);
@ -118,6 +460,13 @@ public abstract class ExprEval<T>
return new StringArrayExprEval((String[]) val);
}
if (val instanceof List) {
// do not convert empty lists to arrays with a single null element here, because that should have been done
// by the selectors preparing their ObjectBindings if necessary. If we get to this point it was legitimately
// empty
return bestEffortOf(coerceListToArray((List<?>) val, false));
}
return new StringExprEval(val == null ? null : String.valueOf(val));
}

View File

@ -19,6 +19,8 @@
package org.apache.druid.math.expr;
import it.unimi.dsi.fastutil.bytes.Byte2ObjectArrayMap;
import it.unimi.dsi.fastutil.bytes.Byte2ObjectMap;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.column.ValueType;
@ -29,13 +31,32 @@ import javax.annotation.Nullable;
*/
public enum ExprType
{
DOUBLE,
LONG,
STRING,
DOUBLE_ARRAY,
LONG_ARRAY,
STRING_ARRAY;
DOUBLE((byte) 0x01),
LONG((byte) 0x02),
STRING((byte) 0x03),
DOUBLE_ARRAY((byte) 0x04),
LONG_ARRAY((byte) 0x05),
STRING_ARRAY((byte) 0x06);
private static final Byte2ObjectMap<ExprType> TYPE_BYTES = new Byte2ObjectArrayMap<>(ExprType.values().length);
static {
for (ExprType type : ExprType.values()) {
TYPE_BYTES.put(type.getId(), type);
}
}
final byte id;
/**
* @param id byte identifier of this type, registered in TYPE_BYTES and returned by {@link #getId()}
*/
ExprType(byte id)
{
this.id = id;
}
/**
* Byte identifier of this type, the inverse of {@link #fromByte(byte)}. This is the value written as the leading
* type byte when an expression result is serialized to a buffer.
*/
public byte getId()
{
return id;
}
public boolean isNumeric()
{
@ -47,6 +68,11 @@ public enum ExprType
return isScalar(this);
}
/**
* Look up the {@link ExprType} registered under a byte identifier, the inverse of {@link #getId()}.
*
* NOTE(review): no validation happens here, so an id which does not belong to any type yields whatever the
* TYPE_BYTES map returns for a missing key (presumably null, no default return value is set in the visible code),
* callers should be prepared to handle that.
*/
public static ExprType fromByte(byte id)
{
return TYPE_BYTES.get(id);
}
/**
* The expression system does not distinguish between {@link ValueType#FLOAT} and {@link ValueType#DOUBLE}, and
* cannot currently handle {@link ValueType#COMPLEX} inputs. This method will convert {@link ValueType#FLOAT} to
@ -177,5 +203,4 @@ public enum ExprType
}
return elementType;
}
}

View File

@ -42,6 +42,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@ -384,13 +385,13 @@ public interface Function
@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
return ImmutableSet.of(args.get(1));
return ImmutableSet.of(getScalarArgument(args));
}
@Override
public Set<Expr> getArrayInputs(List<Expr> args)
{
return ImmutableSet.of(args.get(0));
return ImmutableSet.of(getArrayArgument(args));
}
@Override
@ -402,14 +403,24 @@ public interface Function
@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
final ExprEval arrayExpr = args.get(0).eval(bindings);
final ExprEval scalarExpr = args.get(1).eval(bindings);
final ExprEval arrayExpr = getArrayArgument(args).eval(bindings);
final ExprEval scalarExpr = getScalarArgument(args).eval(bindings);
if (arrayExpr.asArray() == null) {
return ExprEval.of(null);
}
return doApply(arrayExpr, scalarExpr);
}
/**
* Select the scalar input from the function arguments, by default the second argument. Functions which accept
* their arguments in the opposite order (e.g. array_prepend) override this.
*/
Expr getScalarArgument(List<Expr> args)
{
return args.get(1);
}
/**
* Select the array input from the function arguments, by default the first argument. Functions which accept
* their arguments in the opposite order (e.g. array_prepend) override this.
*/
Expr getArrayArgument(List<Expr> args)
{
return args.get(0);
}
abstract ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr);
}
@ -450,8 +461,11 @@ public interface Function
final ExprEval arrayExpr1 = args.get(0).eval(bindings);
final ExprEval arrayExpr2 = args.get(1).eval(bindings);
if (arrayExpr1.asArray() == null || arrayExpr2.asArray() == null) {
return ExprEval.of(null);
if (arrayExpr1.asArray() == null) {
return arrayExpr1;
}
if (arrayExpr2.asArray() == null) {
return arrayExpr2;
}
return doApply(arrayExpr1, arrayExpr2);
@ -460,6 +474,118 @@ public interface Function
abstract ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr);
}
/**
 * Scaffolding for a 2 argument {@link Function} which accepts one array and one scalar input and adds the scalar
 * input to the array in some way, as decided by the implementation of {@link #add(Object[], Object)}.
 */
abstract class ArrayAddElementFunction extends ArrayScalarFunction
{
  @Override
  public boolean hasArrayOutput()
  {
    return true;
  }

  /**
   * The output is the array form of the type of the array argument, falling back to the type of the array argument
   * itself when no array form is available.
   */
  @Nullable
  @Override
  public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
  {
    final ExprType inputType = getArrayArgument(args).getOutputType(inspector);
    final ExprType inputAsArrayType = ExprType.asArrayType(inputType);
    return inputAsArrayType != null ? inputAsArrayType : inputType;
  }

  /**
   * Coerce both inputs to the element type of the array argument, and build a new array of that type from the
   * result of {@link #add(Object[], Object)}. A null numeric scalar is added as a null element.
   */
  @Override
  ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
  {
    switch (arrayExpr.type()) {
      case STRING:
      case STRING_ARRAY: {
        final String[] strings = arrayExpr.asStringArray();
        final String element = scalarExpr.asString();
        return ExprEval.ofStringArray(add(strings, element).toArray(String[]::new));
      }
      case LONG:
      case LONG_ARRAY: {
        final Long[] longs = arrayExpr.asLongArray();
        final Long element = scalarExpr.isNumericNull() ? null : scalarExpr.asLong();
        return ExprEval.ofLongArray(add(longs, element).toArray(Long[]::new));
      }
      case DOUBLE:
      case DOUBLE_ARRAY: {
        final Double[] doubles = arrayExpr.asDoubleArray();
        final Double element = scalarExpr.isNumericNull() ? null : scalarExpr.asDouble();
        return ExprEval.ofDoubleArray(add(doubles, element).toArray(Double[]::new));
      }
      default:
        throw new RE("Unable to add to unknown array type %s", arrayExpr.type());
    }
  }

  /**
   * Produce the elements of the output array from the elements of the input array and the value to add.
   */
  abstract <T> Stream<T> add(T[] array, @Nullable T val);
}
/**
 * Base scaffolding for functions which accept 2 array arguments and combine them in some way, as decided by the
 * implementation of {@link #merge(Object[], Object[])}.
 */
abstract class ArraysMergeFunction extends ArraysFunction
{
  @Override
  public Set<Expr> getArrayInputs(List<Expr> args)
  {
    return ImmutableSet.copyOf(args);
  }

  @Override
  public boolean hasArrayOutput()
  {
    return true;
  }

  /**
   * The output is the array form of the type of the first argument, falling back to the type of the first argument
   * itself when no array form is available.
   */
  @Nullable
  @Override
  public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
  {
    final ExprType firstArgType = args.get(0).getOutputType(inspector);
    final ExprType firstArgAsArrayType = ExprType.asArrayType(firstArgType);
    return firstArgAsArrayType != null ? firstArgAsArrayType : firstArgType;
  }

  /**
   * Coerce both inputs to the element type of the left hand side, and build a new array of that type from the
   * result of {@link #merge(Object[], Object[])}. A null left hand side produces null, a null right hand side
   * produces the left hand side as is.
   */
  @Override
  ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr)
  {
    final Object[] lhsValues = lhsExpr.asArray();
    final Object[] rhsValues = rhsExpr.asArray();

    if (lhsValues == null) {
      return ExprEval.of(null);
    }
    if (rhsValues == null) {
      return lhsExpr;
    }

    switch (lhsExpr.type()) {
      case STRING:
      case STRING_ARRAY: {
        final Stream<String> merged = merge(lhsExpr.asStringArray(), rhsExpr.asStringArray());
        return ExprEval.ofStringArray(merged.toArray(String[]::new));
      }
      case LONG:
      case LONG_ARRAY: {
        final Stream<Long> merged = merge(lhsExpr.asLongArray(), rhsExpr.asLongArray());
        return ExprEval.ofLongArray(merged.toArray(Long[]::new));
      }
      case DOUBLE:
      case DOUBLE_ARRAY: {
        final Stream<Double> merged = merge(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray());
        return ExprEval.ofDoubleArray(merged.toArray(Double[]::new));
      }
      default:
        throw new RE("Unable to concatenate to unknown type %s", lhsExpr.type());
    }
  }

  /**
   * Produce the elements of the output array from the elements of the 2 input arrays.
   */
  abstract <T> Stream<T> merge(T[] array1, T[] array2);
}
abstract class ReduceFunction implements Function
{
private final DoubleBinaryOperator doubleReducer;
@ -3168,7 +3294,7 @@ public interface Function
}
}
class ArrayAppendFunction extends ArrayScalarFunction
class ArrayAppendFunction extends ArrayAddElementFunction
{
@Override
public String name()
@ -3177,48 +3303,7 @@ public interface Function
}
@Override
public boolean hasArrayOutput()
{
return true;
}
@Nullable
@Override
public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
{
ExprType arrayType = args.get(0).getOutputType(inspector);
return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
}
@Override
ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
{
switch (arrayExpr.type()) {
case STRING:
case STRING_ARRAY:
return ExprEval.ofStringArray(this.append(arrayExpr.asStringArray(), scalarExpr.asString()).toArray(String[]::new));
case LONG:
case LONG_ARRAY:
return ExprEval.ofLongArray(
this.append(
arrayExpr.asLongArray(),
scalarExpr.isNumericNull() ? null : scalarExpr.asLong()).toArray(Long[]::new
)
);
case DOUBLE:
case DOUBLE_ARRAY:
return ExprEval.ofDoubleArray(
this.append(
arrayExpr.asDoubleArray(),
scalarExpr.isNumericNull() ? null : scalarExpr.asDouble()).toArray(Double[]::new
)
);
}
throw new RE("Unable to append to unknown type %s", arrayExpr.type());
}
private <T> Stream<T> append(T[] array, T val)
<T> Stream<T> add(T[] array, @Nullable T val)
{
List<T> l = new ArrayList<>(Arrays.asList(array));
l.add(val);
@ -3226,7 +3311,36 @@ public interface Function
}
}
class ArrayConcatFunction extends ArraysFunction
/**
 * array_prepend(scalar, array): produces a new array with the scalar added in front of the elements of the array.
 * Unlike the other {@link ArrayAddElementFunction}, the scalar is the first argument and the array is the second.
 */
class ArrayPrependFunction extends ArrayAddElementFunction
{
  @Override
  public String name()
  {
    return "array_prepend";
  }

  @Override
  Expr getScalarArgument(List<Expr> args)
  {
    // the scalar comes first for prepend
    return args.get(0);
  }

  @Override
  Expr getArrayArgument(List<Expr> args)
  {
    // the array comes second for prepend
    return args.get(1);
  }

  @Override
  <T> Stream<T> add(T[] array, @Nullable T val)
  {
    final List<T> prepended = new ArrayList<>(array.length + 1);
    prepended.add(val);
    Collections.addAll(prepended, array);
    return prepended.stream();
  }
}
class ArrayConcatFunction extends ArraysMergeFunction
{
@Override
public String name()
@ -3235,59 +3349,7 @@ public interface Function
}
@Override
public Set<Expr> getArrayInputs(List<Expr> args)
{
return ImmutableSet.copyOf(args);
}
@Override
public boolean hasArrayOutput()
{
return true;
}
@Nullable
@Override
public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
{
ExprType arrayType = args.get(0).getOutputType(inspector);
return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
}
@Override
ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr)
{
final Object[] array1 = lhsExpr.asArray();
final Object[] array2 = rhsExpr.asArray();
if (array1 == null) {
return ExprEval.of(null);
}
if (array2 == null) {
return lhsExpr;
}
switch (lhsExpr.type()) {
case STRING:
case STRING_ARRAY:
return ExprEval.ofStringArray(
cat(lhsExpr.asStringArray(), rhsExpr.asStringArray()).toArray(String[]::new)
);
case LONG:
case LONG_ARRAY:
return ExprEval.ofLongArray(
cat(lhsExpr.asLongArray(), rhsExpr.asLongArray()).toArray(Long[]::new)
);
case DOUBLE:
case DOUBLE_ARRAY:
return ExprEval.ofDoubleArray(
cat(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray()).toArray(Double[]::new)
);
}
throw new RE("Unable to concatenate to unknown type %s", lhsExpr.type());
}
private <T> Stream<T> cat(T[] array1, T[] array2)
<T> Stream<T> merge(T[] array1, T[] array2)
{
List<T> l = new ArrayList<>(Arrays.asList(array1));
l.addAll(Arrays.asList(array2));
@ -3295,6 +3357,40 @@ public interface Function
}
}
/**
* array_set_add(array, scalar): treats the array as a set and adds the scalar to it, so the output contains each
* distinct value of the input array, and the scalar, exactly once.
*/
class ArraySetAddFunction extends ArrayAddElementFunction
{
@Override
public String name()
{
return "array_set_add";
}
@Override
<T> Stream<T> add(T[] array, @Nullable T val)
{
// de-duplicate through a HashSet, so the order of the output is the iteration order of the set rather than the
// order of the input array
Set<T> l = new HashSet<>(Arrays.asList(array));
l.add(val);
return l.stream();
}
}
/**
* array_set_add_all(array1, array2): treats both arrays as sets and combines them, so the output contains each
* distinct value of either input array exactly once.
*/
class ArraySetAddAllFunction extends ArraysMergeFunction
{
@Override
public String name()
{
return "array_set_add_all";
}
@Override
<T> Stream<T> merge(T[] array1, T[] array2)
{
// de-duplicate through a HashSet, so the order of the output is the iteration order of the set rather than the
// order of the input arrays
Set<T> l = new HashSet<>(Arrays.asList(array1));
l.addAll(Arrays.asList(array2));
return l.stream();
}
}
class ArrayContainsFunction extends ArraysFunction
{
@Override
@ -3438,93 +3534,4 @@ public interface Function
throw new RE("Unable to slice to unknown type %s", expr.type());
}
}
class ArrayPrependFunction implements Function
{
@Override
public String name()
{
return "array_prepend";
}
@Override
public void validateArguments(List<Expr> args)
{
if (args.size() != 2) {
throw new IAE("Function[%s] needs 2 arguments", name());
}
}
@Nullable
@Override
public ExprType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
{
ExprType arrayType = args.get(1).getOutputType(inspector);
return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType);
}
@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
return ImmutableSet.of(args.get(0));
}
@Override
public Set<Expr> getArrayInputs(List<Expr> args)
{
return ImmutableSet.of(args.get(1));
}
@Override
public boolean hasArrayInputs()
{
return true;
}
@Override
public boolean hasArrayOutput()
{
return true;
}
@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
final ExprEval scalarExpr = args.get(0).eval(bindings);
final ExprEval arrayExpr = args.get(1).eval(bindings);
if (arrayExpr.asArray() == null) {
return ExprEval.of(null);
}
switch (arrayExpr.type()) {
case STRING:
case STRING_ARRAY:
return ExprEval.ofStringArray(this.prepend(scalarExpr.asString(), arrayExpr.asStringArray()).toArray(String[]::new));
case LONG:
case LONG_ARRAY:
return ExprEval.ofLongArray(
this.prepend(
scalarExpr.isNumericNull() ? null : scalarExpr.asLong(),
arrayExpr.asLongArray()).toArray(Long[]::new
)
);
case DOUBLE:
case DOUBLE_ARRAY:
return ExprEval.ofDoubleArray(
this.prepend(
scalarExpr.isNumericNull() ? null : scalarExpr.asDouble(),
arrayExpr.asDoubleArray()).toArray(Double[]::new
)
);
}
throw new RE("Unable to prepend to unknown type %s", arrayExpr.type());
}
private <T> Stream<T> prepend(T val, T[] array)
{
List<T> l = new ArrayList<>(Arrays.asList(array));
l.add(0, val);
return l.stream();
}
}
}

View File

@ -22,6 +22,7 @@ package org.apache.druid.math.expr;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
@ -35,12 +36,14 @@ import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.antlr.ExprLexer;
import org.apache.druid.math.expr.antlr.ExprParser;
import javax.annotation.Nullable;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
public class Parser
@ -96,17 +99,33 @@ public class Parser
}
/**
* Parse a string into a flattened {@link Expr}. There is some overhead to this, and these objects are all immutable,
* so re-use instead of re-creating whenever possible.
* Create a memoized lazy supplier to parse a string into a flattened {@link Expr}. There is some overhead to this,
* and these objects are all immutable, so this assists in the goal of re-using instead of re-creating whenever
* possible.
*
* Lazy form of {@link #parse(String, ExprMacroTable)}
*
* @param in expression to parse
* @param macroTable additional extensions to expression language
*/
public static Supplier<Expr> lazyParse(@Nullable String in, ExprMacroTable macroTable)
{
  // parsing is deferred until the supplier is first asked for the expression, a null input supplies a null Expr
  return Suppliers.memoize(
      () -> {
        if (in == null) {
          return null;
        }
        return Parser.parse(in, macroTable);
      }
  );
}
/**
* Parse a string into a flattened {@link Expr}. There is some overhead to this, and these objects are all immutable,
* so re-use instead of re-creating whenever possible.
*
* @param in expression to parse
* @param macroTable additional extensions to expression language
* @return the parsed expression, as produced by {@link #parse(String, ExprMacroTable, boolean)} with flattening
*         enabled
*/
public static Expr parse(String in, ExprMacroTable macroTable)
{
return parse(in, macroTable, true);
}
@VisibleForTesting
public static Expr parse(String in, ExprMacroTable macroTable, boolean withFlatten)
{
@ -164,47 +183,43 @@ public class Parser
/**
* Applies a transformation to an {@link Expr} given a list of known (or unknown) multi-value input columns that are
* used in a scalar manner, walking the {@link Expr} tree and lifting array variables into the {@link LambdaExpr} of
* {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr}
* @param expr expression to visit and rewrite
* @param bindingsToApply
* @return
* {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr} as necessary.
*
* This function applies a transformation for "map" style uses, such as column selectors, where the supplied
* expression will be transformed to return an array of results instead of the scalar result (or appropriately
* rewritten into existing apply expressions to produce correct results when referenced from a scalar context).
*
* This function and {@link #foldUnappliedBindings(Expr, Expr.BindingAnalysis, List, String)} exist to handle
* "multi-valued" string dimensions, which exist in a superposition of both single and multi-valued during realtime
* ingestion, until they are written to a segment and become locked into either single or multi-valued. This also
* means that multi-valued-ness can vary for a column from segment to segment, so this family of transformation
* functions exist so that multi-valued strings can be expressed in either an array or scalar context, which is
* important because the writer of the query might not actually know if the column is definitively always single or
* multi-valued (and it might in fact not be).
*
* @see #foldUnappliedBindings(Expr, Expr.BindingAnalysis, List, String)
*/
public static Expr applyUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List<String> bindingsToApply)
public static Expr applyUnappliedBindings(
Expr expr,
Expr.BindingAnalysis bindingAnalysis,
List<String> bindingsToApply
)
{
if (bindingsToApply.isEmpty()) {
// nothing to do, expression is fine as is
return expr;
}
// filter the list of bindings to those which are used in this expression
List<String> unappliedBindingsInExpression = bindingsToApply.stream()
.filter(x -> bindingAnalysis.getRequiredBindings().contains(x))
.collect(Collectors.toList());
List<String> unappliedBindingsInExpression =
bindingsToApply.stream()
.filter(x -> bindingAnalysis.getRequiredBindings().contains(x))
.collect(Collectors.toList());
// any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten
Expr newExpr = expr.visit(
childExpr -> {
if (childExpr instanceof ApplyFunctionExpr) {
// try to lift unapplied arguments into the apply function lambda
return liftApplyLambda((ApplyFunctionExpr) childExpr, unappliedBindingsInExpression);
} else if (childExpr instanceof FunctionExpr) {
// check array function arguments for unapplied identifiers to transform if necessary
FunctionExpr fnExpr = (FunctionExpr) childExpr;
Set<Expr> arrayInputs = fnExpr.function.getArrayInputs(fnExpr.args);
List<Expr> newArgs = new ArrayList<>();
for (Expr arg : fnExpr.args) {
if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) {
Expr newArg = applyUnappliedBindings(arg, bindingAnalysis, unappliedBindingsInExpression);
newArgs.add(newArg);
} else {
newArgs.add(arg);
}
}
FunctionExpr newFnExpr = new FunctionExpr(fnExpr.function, fnExpr.function.name(), newArgs);
return newFnExpr;
}
return childExpr;
}
Expr newExpr = rewriteUnappliedSubExpressions(
expr,
unappliedBindingsInExpression,
(arg) -> applyUnappliedBindings(arg, bindingAnalysis, bindingsToApply)
);
Expr.BindingAnalysis newExprBindings = newExpr.analyzeInputs();
@ -221,9 +236,123 @@ public class Parser
return applyUnapplied(newExpr, remainingUnappliedBindings);
}
/**
 * Applies a transformation to an {@link Expr} given a list of known (or unknown) multi-value input columns that are
 * used in a scalar manner, walking the {@link Expr} tree and lifting array variables into the {@link LambdaExpr} of
 * {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr} as necessary.
 *
 * This function applies a transformation for "fold" style uses, such as aggregators, where the supplied
 * expression will be transformed to accumulate the result of applying the expression to each value of the unapplied
 * input (or appropriately rewritten into existing apply expressions to produce correct results when referenced from
 * a scalar context). This rewriting assumes that there exists some accumulator variable, which is re-used as the
 * accumulator for this fold rewrite, so that evaluating each expression can be accumulated into the larger external
 * fold operation that an aggregator might be performing.
 *
 * This function and {@link #applyUnappliedBindings(Expr, Expr.BindingAnalysis, List)} exist to handle
 * "multi-valued" string dimensions, which exist in a superposition of both single and multi-valued during realtime
 * ingestion, until they are written to a segment and become locked into either single or multi-valued. This also
 * means that multi-valued-ness can vary for a column from segment to segment, so this family of transformation
 * functions exist so that multi-valued strings can be expressed in either an array or scalar context, which is
 * important because the writer of the query might not actually know if the column is definitively always single or
 * multi-valued (and it might in fact not be).
 *
 * @param expr            expression to visit and rewrite
 * @param bindingAnalysis analysis of the inputs required by {@code expr}
 * @param bindingsToApply identifiers of the multi-value inputs which must be folded over
 * @param accumulatorId   identifier of the accumulator variable to re-use for the fold
 *
 * @see #applyUnappliedBindings(Expr, Expr.BindingAnalysis, List)
 */
public static Expr foldUnappliedBindings(
    Expr expr,
    Expr.BindingAnalysis bindingAnalysis,
    List<String> bindingsToApply,
    String accumulatorId
)
{
  if (bindingsToApply.isEmpty()) {
    // nothing to do, expression is fine as is
    return expr;
  }

  // only the bindings which are actually used in this expression need to be considered
  final List<String> usedBindings = new ArrayList<>();
  for (String binding : bindingsToApply) {
    if (bindingAnalysis.getRequiredBindings().contains(binding)) {
      usedBindings.add(binding);
    }
  }

  final Expr rewritten = rewriteUnappliedSubExpressions(
      expr,
      usedBindings,
      subExpr -> foldUnappliedBindings(subExpr, bindingAnalysis, bindingsToApply, accumulatorId)
  );

  // bindings which the rewritten expression now consumes as arrays are taken care of
  final Set<String> arrayVariables = rewritten.analyzeInputs().getArrayVariables();
  final List<String> stillUnapplied = new ArrayList<>();
  for (String binding : usedBindings) {
    if (!arrayVariables.contains(binding)) {
      stillUnapplied.add(binding);
    }
  }

  if (stillUnapplied.isEmpty()) {
    // lifting the lambdas got rid of all missing bindings, the transformed expression is complete
    return rewritten;
  }
  return foldUnapplied(rewritten, stillUnapplied, accumulatorId);
}
/**
 * Walks an {@link Expr} and rewrites the sub-expressions which are affected by unapplied bindings.
 *
 * Unapplied bindings which appear inside the lambda of an {@link ApplyFunctionExpr} require that lambda to be
 * rewritten so the identifiers are "lifted" into lambda arguments and the apply function is transformed to match.
 *
 * For example:
 *    if "y" is unapplied:
 *      map((x) -> x + y, x) => cartesian_map((x,y) -> x + y, x, y)
 *
 * @see #liftApplyLambda(ApplyFunctionExpr, List)
 *
 * Array inputs of a {@link FunctionExpr} which are themselves expressions (rather than plain identifiers) might also
 * need to be transformed, so each of them is passed through the supplied unapplied binding transformation function.
 *
 * For example:
 *    if "y" is unapplied:
 *      in array_length(filter((x) -> x > y, x)) the filter expression is handed to the supplied function
 *
 * @param expr                          expression to walk
 * @param unappliedBindingsInExpression unapplied bindings which are used by {@code expr}
 * @param applyUnappliedFn              transformation applied to non-identifier array inputs of functions
 */
private static Expr rewriteUnappliedSubExpressions(
    Expr expr,
    List<String> unappliedBindingsInExpression,
    UnaryOperator<Expr> applyUnappliedFn
)
{
  return expr.visit(
      visited -> {
        if (visited instanceof ApplyFunctionExpr) {
          // attempt to lift the unapplied identifiers into the lambda of the apply function
          return liftApplyLambda((ApplyFunctionExpr) visited, unappliedBindingsInExpression);
        }
        if (!(visited instanceof FunctionExpr)) {
          // nothing to rewrite for any other kind of expression
          return visited;
        }

        // rebuild the function call, transforming array inputs which are expressions rather than identifiers
        final FunctionExpr functionCall = (FunctionExpr) visited;
        final Set<Expr> arrayArguments = functionCall.function.getArrayInputs(functionCall.args);
        final List<Expr> rewrittenArguments = new ArrayList<>();
        for (Expr argument : functionCall.args) {
          final boolean isArrayExpression =
              argument.getIdentifierIfIdentifier() == null && arrayArguments.contains(argument);
          rewrittenArguments.add(isArrayExpression ? applyUnappliedFn.apply(argument) : argument);
        }
        return new FunctionExpr(functionCall.function, functionCall.function.name(), rewrittenArguments);
      }
  );
}
/**
* translate an {@link Expr} into an {@link ApplyFunctionExpr} for {@link ApplyFunction.MapFunction} or
* {@link ApplyFunction.CartesianMapFunction} if there are multiple unbound arguments to be applied
* {@link ApplyFunction.CartesianMapFunction} if there are multiple unbound arguments to be applied.
*
* For example:
* if "x" is unapplied:
* x + y => map((x) -> x + y, x)
* if "x" and "y" are unapplied:
* x + y => cartesian_map((x, y) -> x + y, x, y)
*/
private static Expr applyUnapplied(Expr expr, List<String> unappliedBindings)
{
@ -274,6 +403,72 @@ public class Parser
return magic;
}
/**
 * Wraps an {@link Expr} in an {@link ApplyFunctionExpr} for {@link ApplyFunction.FoldFunction}, or for
 * {@link ApplyFunction.CartesianFoldFunction} when more than one unbound argument has to be applied.
 *
 * An externally known accumulator identifier is re-used as both the accumulator lambda argument and the accumulator
 * input of the generated fold.
 *
 * For example given an accumulator "__acc":
 *    if "x" is unapplied:
 *      x + __acc => fold((x, __acc) -> x + __acc, x, __acc)
 *    if "x" and "y" are unapplied:
 *      __acc + x + y => cartesian_fold((x, y, __acc) -> __acc + x + y, x, y, __acc)
 *
 * @param expr              expression which becomes the lambda body
 * @param unappliedBindings bindings which must be applied by the fold
 * @param accumulatorId     identifier of the accumulator variable
 */
private static Expr foldUnapplied(Expr expr, List<String> unappliedBindings, String accumulatorId)
{
  // Identifiers in a lambda body have artificial 'binding' values which are the same as the 'identifier', since the
  // bindings are supplied by the wrapping apply function. One lambda identifier is created per distinct binding
  // (not per identifier), because repeated use of the same input should not result in a cartesian product.
  final Map<String, IdentifierExpr> lambdaIdentifierByBinding = new HashMap<>();
  final List<IdentifierExpr> lambdaArgs = new ArrayList<>();
  for (IdentifierExpr freeVariable : expr.analyzeInputs().getFreeVariables()) {
    final String binding = freeVariable.getBinding();
    // only free variables backed by an unapplied binding are of interest, and each binding only once
    if (!unappliedBindings.contains(binding) || lambdaIdentifierByBinding.containsKey(binding)) {
      continue;
    }
    final IdentifierExpr lambdaIdentifier = new IdentifierExpr(binding);
    lambdaArgs.add(lambdaIdentifier);
    lambdaIdentifierByBinding.put(binding, lambdaIdentifier);
  }
  // the accumulator is always the last lambda argument
  lambdaArgs.add(new IdentifierExpr(accumulatorId));

  // swap identifiers of the future lambda body for the lambda identifiers constructed above
  final Expr lambdaBody = expr.visit(child -> {
    if (!(child instanceof IdentifierExpr)) {
      return child;
    }
    final IdentifierExpr replacement = lambdaIdentifierByBinding.get(((IdentifierExpr) child).getBinding());
    return replacement != null ? replacement : child;
  });

  // a single unapplied input plus the accumulator is a plain fold, anything else needs cartesian_fold
  final ApplyFunction fn = lambdaArgs.size() == 2
                           ? new ApplyFunction.FoldFunction()
                           : new ApplyFunction.CartesianFoldFunction();
  return new ApplyFunctionExpr(fn, fn.name(), new LambdaExpr(lambdaArgs, lambdaBody), ImmutableList.copyOf(lambdaArgs));
}
/**
* Performs partial lifting of free identifiers of the lambda expression of an {@link ApplyFunctionExpr}, constrained
* by a list of "unapplied" identifiers, and translating them into arguments of a new {@link LambdaExpr} and

View File

@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.math.expr;
import com.google.common.collect.Maps;
import javax.annotation.Nullable;
import java.util.HashMap;
import java.util.Map;
/**
 * {@link Expr.ObjectBinding} which keeps its values in a plain map that callers populate through
 * {@link #withBinding(String, Object)}.
 */
public class SettableObjectBinding implements Expr.ObjectBinding
{
  private final Map<String, Object> valuesByName;

  /**
   * Creates an empty binding backed by a default sized map.
   */
  public SettableObjectBinding()
  {
    valuesByName = new HashMap<>();
  }

  /**
   * Creates an empty binding backed by a map sized for {@code expectedSize} entries.
   */
  public SettableObjectBinding(int expectedSize)
  {
    valuesByName = Maps.newHashMapWithExpectedSize(expectedSize);
  }

  /**
   * Looks up the value bound to {@code name}; the result is null when nothing (or a null value) is bound.
   */
  @Nullable
  @Override
  public Object get(String name)
  {
    return valuesByName.get(name);
  }

  /**
   * Binds {@code value} to {@code name}, replacing any earlier value, and returns this object to allow chaining.
   */
  public SettableObjectBinding withBinding(String name, @Nullable Object value)
  {
    valuesByName.put(name, value);
    return this;
  }
}

View File

@ -0,0 +1,220 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.math.expr;
import com.google.common.collect.ImmutableList;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
public class ExprEvalTest extends InitializedNullHandlingTest
{
private static int MAX_SIZE_BYTES = 1 << 13;
@Rule
public ExpectedException expectedException = ExpectedException.none();
ByteBuffer buffer = ByteBuffer.allocate(1 << 16);
@Test
public void testStringSerde()
{
assertExpr(0, "hello");
assertExpr(1234, "hello");
assertExpr(0, ExprEval.bestEffortOf(null));
}
@Test
public void testStringSerdeTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.STRING, 16, 10));
assertExpr(0, ExprEval.of("hello world"), 10);
}
@Test
public void testLongSerde()
{
assertExpr(0, 1L);
assertExpr(1234, 1L);
assertExpr(1234, ExprEval.ofLong(null));
}
@Test
public void testDoubleSerde()
{
assertExpr(0, 1.123);
assertExpr(1234, 1.123);
assertExpr(1234, ExprEval.ofDouble(null));
}
@Test
public void testStringArraySerde()
{
assertExpr(0, new String[] {"hello", "hi", "hey"});
assertExpr(1024, new String[] {"hello", null, "hi", "hey"});
assertExpr(2048, new String[] {});
}
@Test
public void testStringArraySerdeToBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.STRING_ARRAY, 14, 10));
assertExpr(0, ExprEval.ofStringArray(new String[] {"hello", "hi", "hey"}), 10);
}
@Test
public void testLongArraySerde()
{
assertExpr(0, new Long[] {1L, 2L, 3L});
assertExpr(1234, new Long[] {1L, 2L, null, 3L});
assertExpr(1234, new Long[] {});
}
@Test
public void testLongArraySerdeTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.LONG_ARRAY, 29, 10));
assertExpr(0, ExprEval.ofLongArray(new Long[] {1L, 2L, 3L}), 10);
}
@Test
public void testDoubleArraySerde()
{
assertExpr(0, new Double[] {1.1, 2.2, 3.3});
assertExpr(1234, new Double[] {1.1, 2.2, null, 3.3});
assertExpr(1234, new Double[] {});
}
@Test
public void testDoubleArraySerdeTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.DOUBLE_ARRAY, 29, 10));
assertExpr(0, ExprEval.ofDoubleArray(new Double[] {1.1, 2.2, 3.3}), 10);
}
@Test
public void test_coerceListToArray()
{
Assert.assertNull(ExprEval.coerceListToArray(null, false));
Assert.assertArrayEquals(new Object[0], (Object[]) ExprEval.coerceListToArray(ImmutableList.of(), false));
Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(null, true));
Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(ImmutableList.of(), true));
List<Long> longList = ImmutableList.of(1L, 2L, 3L);
Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(longList, false));
List<Integer> intList = ImmutableList.of(1, 2, 3);
Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(intList, false));
List<Float> floatList = ImmutableList.of(1.0f, 2.0f, 3.0f);
Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(floatList, false));
List<Double> doubleList = ImmutableList.of(1.0, 2.0, 3.0);
Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(doubleList, false));
List<String> stringList = ImmutableList.of("a", "b", "c");
Assert.assertArrayEquals(new String[]{"a", "b", "c"}, (String[]) ExprEval.coerceListToArray(stringList, false));
List<String> withNulls = new ArrayList<>();
withNulls.add("a");
withNulls.add(null);
withNulls.add("c");
Assert.assertArrayEquals(new String[]{"a", null, "c"}, (String[]) ExprEval.coerceListToArray(withNulls, false));
List<Long> withNumberNulls = new ArrayList<>();
withNumberNulls.add(1L);
withNumberNulls.add(null);
withNumberNulls.add(3L);
Assert.assertArrayEquals(new Long[]{1L, null, 3L}, (Long[]) ExprEval.coerceListToArray(withNumberNulls, false));
List<Object> withStringMix = ImmutableList.of(1L, "b", 3L);
Assert.assertArrayEquals(
new String[]{"1", "b", "3"},
(String[]) ExprEval.coerceListToArray(withStringMix, false)
);
List<Number> withIntsAndLongs = ImmutableList.of(1, 2L, 3);
Assert.assertArrayEquals(
new Long[]{1L, 2L, 3L},
(Long[]) ExprEval.coerceListToArray(withIntsAndLongs, false)
);
List<Number> withFloatsAndLongs = ImmutableList.of(1, 2L, 3.0f);
Assert.assertArrayEquals(
new Double[]{1.0, 2.0, 3.0},
(Double[]) ExprEval.coerceListToArray(withFloatsAndLongs, false)
);
List<Number> withDoublesAndLongs = ImmutableList.of(1, 2L, 3.0);
Assert.assertArrayEquals(
new Double[]{1.0, 2.0, 3.0},
(Double[]) ExprEval.coerceListToArray(withDoublesAndLongs, false)
);
List<Number> withFloatsAndDoubles = ImmutableList.of(1L, 2.0f, 3.0);
Assert.assertArrayEquals(
new Double[]{1.0, 2.0, 3.0},
(Double[]) ExprEval.coerceListToArray(withFloatsAndDoubles, false)
);
List<String> withAllNulls = new ArrayList<>();
withAllNulls.add(null);
withAllNulls.add(null);
withAllNulls.add(null);
Assert.assertArrayEquals(
new String[]{null, null, null},
(String[]) ExprEval.coerceListToArray(withAllNulls, false)
);
}
private void assertExpr(int position, Object expected)
{
assertExpr(position, ExprEval.bestEffortOf(expected));
}
private void assertExpr(int position, ExprEval expected)
{
assertExpr(position, expected, MAX_SIZE_BYTES);
}
private void assertExpr(int position, ExprEval expected, int maxSizeBytes)
{
ExprEval.serialize(buffer, position, expected, maxSizeBytes);
if (ExprType.isArray(expected.type())) {
Assert.assertArrayEquals(expected.asArray(), ExprEval.deserialize(buffer, position).asArray());
} else {
Assert.assertEquals(expected.value(), ExprEval.deserialize(buffer, position).value());
}
}
}

View File

@ -290,6 +290,27 @@ public class FunctionTest extends InitializedNullHandlingTest
assertArrayExpr("array_concat(0, 1)", new Long[]{0L, 1L});
}
/**
 * Checks results of the array_set_add expression function: the expected arrays below contain each distinct element
 * once, and the expected array type follows the type of the array argument.
 */
@Test
public void testArraySetAdd()
{
assertArrayExpr("array_set_add([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
// a string added to a long array is expected to show up as a null element, listed first
assertArrayExpr("array_set_add([1, 2, 3], 'bar')", new Long[]{null, 1L, 2L, 3L});
// duplicates already present in the input array are expected to be removed
assertArrayExpr("array_set_add([1, 2, 2], 1)", new Long[]{1L, 2L});
// an untyped empty array literal is expected to produce a string array, an explicitly typed one keeps its type
assertArrayExpr("array_set_add([], 1)", new String[]{"1"});
assertArrayExpr("array_set_add(<LONG>[], 1)", new Long[]{1L});
assertArrayExpr("array_set_add(<LONG>[], null)", new Long[]{null});
}
/**
 * Checks results of the array_set_add_all expression function: the expected arrays below are the distinct elements
 * of both arguments combined.
 */
@Test
public void testArraySetAddAll()
{
assertArrayExpr("array_set_add_all([1, 2, 3], [2, 4, 6])", new Long[]{1L, 2L, 3L, 4L, 6L});
// scalar arguments on either side are expected to be accepted as well
assertArrayExpr("array_set_add_all([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
assertArrayExpr("array_set_add_all(0, [1, 2, 3])", new Long[]{0L, 1L, 2L, 3L});
// NOTE(review): "b" is a binding supplied by the test fixture outside this view; the expected output is consistent
// with b being [1, 2, 3, 4, 5] - confirm against the fixture
assertArrayExpr("array_set_add_all(map(y -> y * 3, b), [1, 2, 3])", new Long[]{1L, 2L, 3L, 6L, 9L, 12L, 15L});
assertArrayExpr("array_set_add_all(0, 1)", new Long[]{0L, 1L});
}
@Test
public void testArrayToString()
{

View File

@ -528,6 +528,48 @@ public class ParserTest extends InitializedNullHandlingTest
);
}
/**
 * Exercises {@code Parser.foldUnappliedBindings} through {@code validateFoldUnapplied}; each case supplies the
 * expression, its expected parsed form, its expected transformed form, the identifiers to treat as unapplied and
 * the accumulator identifier.
 */
@Test
public void testFoldUnapplied()
{
// no identifiers to apply: the expression is expected to be unchanged
validateFoldUnapplied("x + __acc", "(+ x __acc)", "(+ x __acc)", ImmutableList.of(), "__acc");
// the only identifier to apply is not used by the expression: expected to be unchanged
validateFoldUnapplied("x + __acc", "(+ x __acc)", "(+ x __acc)", ImmutableList.of("z"), "__acc");
// a single unapplied input is expected to be wrapped in a fold
validateFoldUnapplied(
"x + __acc",
"(+ x __acc)",
"(fold ([x, __acc] -> (+ x __acc)), [x, __acc])",
ImmutableList.of("x"),
"__acc"
);
// two unapplied inputs are expected to be wrapped in a cartesian_fold
validateFoldUnapplied(
"x + y + __acc",
"(+ (+ x y) __acc)",
"(cartesian_fold ([x, y, __acc] -> (+ (+ x y) __acc)), [x, y, __acc])",
ImmutableList.of("x", "y"),
"__acc"
);
// only "z" is unapplied: the existing inner fold is expected to be left alone, with a new outer fold over "z"
validateFoldUnapplied(
"__acc + z + fold((x, acc) -> acc + x + y, x, 0)",
"(+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))",
"(fold ([z, __acc] -> (+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))), [z, __acc])",
ImmutableList.of("z"),
"__acc"
);
// "y" and "z" are unapplied: "y" is expected to be lifted into the inner fold (now a cartesian_fold), and "z" is
// expected to be handled by a new outer fold
validateFoldUnapplied(
"__acc + z + fold((x, acc) -> acc + x + y, x, 0)",
"(+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))",
"(fold ([z, __acc] -> (+ (+ __acc z) (cartesian_fold ([x, y, acc] -> (+ (+ acc x) y)), [x, y, 0]))), [z, __acc])",
ImmutableList.of("y", "z"),
"__acc"
);
// only "y" is unapplied and it is only used inside the inner fold: no outer fold is expected
validateFoldUnapplied(
"__acc + fold((x, acc) -> x + y + acc, x, __acc)",
"(+ __acc (fold ([x, acc] -> (+ (+ x y) acc)), [x, __acc]))",
"(+ __acc (cartesian_fold ([x, y, acc] -> (+ (+ x y) acc)), [x, y, __acc]))",
ImmutableList.of("y"),
"__acc"
);
}
@Test
public void testUniquify()
{
@ -666,6 +708,33 @@ public class ParserTest extends InitializedNullHandlingTest
Assert.assertEquals(transformed.stringify(), transformedRoundTrip.stringify());
}
/**
 * Parses {@code expression}, applies {@code Parser.foldUnappliedBindings} with the given identifiers and
 * accumulator, and checks the string forms before and after the transformation. The same checks are then repeated
 * on an expression obtained by parsing without flattening, stringifying, and parsing again, and finally the
 * stringified forms of both variants are compared with each other.
 *
 * Every string-form assertion uses {@code expression} as its failure message so that a failing case can be
 * identified from the test output.
 *
 * @param expression  expression to parse
 * @param unapplied   expected {@code toString()} of the parsed expression
 * @param applied     expected {@code toString()} of the transformed expression
 * @param identifiers bindings to treat as unapplied
 * @param accumulator accumulator identifier handed to the transformation
 */
private void validateFoldUnapplied(
    String expression,
    String unapplied,
    String applied,
    List<String> identifiers,
    String accumulator
)
{
  final Expr parsed = Parser.parse(expression, ExprMacroTable.nil());
  Expr.BindingAnalysis deets = parsed.analyzeInputs();
  Parser.validateExpr(parsed, deets);
  final Expr transformed = Parser.foldUnappliedBindings(parsed, deets, identifiers, accumulator);
  Assert.assertEquals(expression, unapplied, parsed.toString());
  Assert.assertEquals(expression, applied, transformed.toString());

  // repeat on an expression which went through a stringify and parse round trip
  final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false);
  final Expr parsedRoundTrip = Parser.parse(parsedNoFlatten.stringify(), ExprMacroTable.nil());
  Expr.BindingAnalysis roundTripDeets = parsedRoundTrip.analyzeInputs();
  Parser.validateExpr(parsedRoundTrip, roundTripDeets);
  final Expr transformedRoundTrip = Parser.foldUnappliedBindings(parsedRoundTrip, roundTripDeets, identifiers, accumulator);
  Assert.assertEquals(expression, unapplied, parsedRoundTrip.toString());
  Assert.assertEquals(expression, applied, transformedRoundTrip.toString());

  // both variants must stringify identically
  Assert.assertEquals(parsed.stringify(), parsedRoundTrip.stringify());
  Assert.assertEquals(transformed.stringify(), transformedRoundTrip.stringify());
}
private void validateConstantExpression(String expression, Object expected)
{
Expr parsed = Parser.parse(expression, ExprMacroTable.nil());

View File

@ -415,29 +415,6 @@ public class VectorExprSanityTest extends InitializedNullHandlingTest
.toArray(String[][]::new);
}
/**
 * {@link Expr.ObjectBinding} for tests which keeps its values in a plain map that is populated through
 * {@link #withBinding(String, Object)}.
 */
static class SettableObjectBinding implements Expr.ObjectBinding
{
  private final Map<String, Object> valuesByName;

  SettableObjectBinding()
  {
    valuesByName = new HashMap<>();
  }

  /**
   * Looks up the value bound to {@code name}; the result is null when nothing (or a null value) is bound.
   */
  @Nullable
  @Override
  public Object get(String name)
  {
    return valuesByName.get(name);
  }

  /**
   * Binds {@code value} to {@code name}, replacing any earlier value, and returns this object to allow chaining.
   */
  public SettableObjectBinding withBinding(String name, @Nullable Object value)
  {
    valuesByName.put(name, value);
    return this;
  }
}
static class SettableVectorInputBinding implements Expr.VectorInputBinding
{
private final Map<String, boolean[]> nulls;

View File

@ -177,14 +177,17 @@ See javadoc of java.lang.Math for detailed explanation for each function.
| array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `-1` or `null` if `druid.generic.useDefaultValueForNull=false` if no matching elements exist in the array. |
| array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `-1` or `null` if `druid.generic.useDefaultValueForNull=false` if no matching elements exist in the array. |
| array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array |
| array_append(arr1,expr) | appends expr to arr, the resulting array type determined by the type of the first array |
| array_append(arr,expr) | appends expr to arr, the resulting array type determined by the type of the first array |
| array_concat(arr1,arr2) | concatenates 2 arrays, the resulting array type determined by the type of the first array |
| array_set_add(arr,expr) | adds expr to arr and converts the array to a new array composed of the unique set of elements. The resulting array type is determined by the type of the array |
| array_set_add_all(arr1,arr2) | combines the unique set of elements of 2 arrays, the resulting array type determined by the type of the first array |
| array_slice(arr,start,end) | return the subarray of arr from the 0 based index start(inclusive) to end(exclusive), or `null`, if start is less than 0, greater than length of arr or greater than end |
| array_to_string(arr,str) | joins all elements of arr by the delimiter specified by str |
| string_to_array(str1,str2) | splits str1 into an array on the delimiter specified by str2 |
## Apply functions
Apply functions allow for special 'lambda' expressions to be defined and applied to array inputs to enable free-form transformations.
| function | description |
| --- | --- |
@ -197,6 +200,26 @@ See javadoc of java.lang.Math for detailed explanation for each function.
| all(lambda,arr) | returns 1 if all elements in the array matches the lambda expression, else 0 |
### Lambda expressions syntax
Lambda expressions are a sort of function definition, where new identifiers can be defined and passed as input to the expression body.
```
(identifier1 ...) -> expr
```
e.g.
```
(x, y) -> x + y
```
The identifier arguments of a lambda expression correspond to the elements of the array it is being applied to. For example:
```
map((x) -> x + 1, some_multi_value_column)
```
will map each element of `some_multi_value_column` to the identifier `x` so that the lambda expression body can be evaluated for each `x`. The scoping rules are that lambda arguments will override identifiers which are defined externally from the lambda expression body. Using the same example:
```
map((x) -> x + 1, x)
```
in this case, the `x` when evaluating `x + 1` is the lambda argument, thus an element of the multi-valued column `x`, rather than the column `x` itself.
## Reduction functions
Reduction functions operate on zero or more expressions and return a single expression. If no expressions are passed as

View File

@ -27,6 +27,7 @@ import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMinAggregatorFactory;
@ -120,7 +121,8 @@ public class AggregatorsModule extends SimpleModule
@JsonSubTypes.Type(name = "floatAny", value = FloatAnyAggregatorFactory.class),
@JsonSubTypes.Type(name = "doubleAny", value = DoubleAnyAggregatorFactory.class),
@JsonSubTypes.Type(name = "stringAny", value = StringAnyAggregatorFactory.class),
@JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class)
@JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class),
@JsonSubTypes.Type(name = "expression", value = ExpressionLambdaAggregatorFactory.class)
})
public interface AggregatorFactoryMixin
{

View File

@ -137,6 +137,9 @@ public class AggregatorUtil
// GROUPING aggregator
public static final byte GROUPING_CACHE_TYPE_ID = 0x46;
// expression lambda aggregator
// NOTE(review): presumably the cache key type id of ExpressionLambdaAggregatorFactory; the value follows
// GROUPING_CACHE_TYPE_ID (0x46) - confirm it does not collide with any other cache type id
public static final byte EXPRESSION_LAMBDA_CACHE_TYPE_ID = 0x47;
/**
* returns the list of dependent postAggregators that should be calculated in order to calculate given postAgg

View File

@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation;
import org.apache.druid.math.expr.Expr;
import javax.annotation.Nullable;
/**
 * {@link Aggregator} which evaluates an {@link Expr} against a set of input bindings on every call to
 * {@link #aggregate()} and hands the result back to those bindings, which also expose the accumulated value that the
 * getters of this class read.
 */
public class ExpressionLambdaAggregator implements Aggregator
{
  private final Expr foldExpression;
  private final ExpressionLambdaAggregatorInputBindings inputBindings;

  /**
   * @param lambda   expression evaluated on each {@link #aggregate()} call
   * @param bindings bindings the expression is evaluated with, which also hold the accumulator
   */
  public ExpressionLambdaAggregator(Expr lambda, ExpressionLambdaAggregatorInputBindings bindings)
  {
    foldExpression = lambda;
    inputBindings = bindings;
  }

  @Override
  public void aggregate()
  {
    // evaluate with the current bindings and feed the result back into the bindings
    inputBindings.accumulate(foldExpression.eval(inputBindings));
  }

  /**
   * Returns the value of the accumulator.
   */
  @Nullable
  @Override
  public Object get()
  {
    return inputBindings.getAccumulator().value();
  }

  /**
   * Returns the accumulator as a double, narrowed to a float.
   */
  @Override
  public float getFloat()
  {
    final double accumulated = inputBindings.getAccumulator().asDouble();
    return (float) accumulated;
  }

  @Override
  public long getLong()
  {
    return inputBindings.getAccumulator().asLong();
  }

  @Override
  public double getDouble()
  {
    return inputBindings.getAccumulator().asDouble();
  }

  /**
   * Reports whether the accumulator is a numeric null.
   */
  @Override
  public boolean isNull()
  {
    return inputBindings.getAccumulator().isNumericNull();
  }

  @Override
  public void close()
  {
    // nothing to close
  }
}

View File

@ -0,0 +1,516 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation;
import com.fasterxml.jackson.annotation.JacksonInject;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.Iterables;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Comparators;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.math.expr.SettableObjectBinding;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.virtual.ExpressionPlan;
import org.apache.druid.segment.virtual.ExpressionPlanner;
import org.apache.druid.segment.virtual.ExpressionSelectors;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
{
// NOTE(review): the three identifiers below are presumably the variable names made available to the finalize
// ("o") and compare ("o1", "o2") expressions; their uses are outside this view - confirm
private static final String FINALIZE_IDENTIFIER = "o";
private static final String COMPARE_O1 = "o1";
private static final String COMPARE_O2 = "o2";
// accumulator identifier used when none is supplied to the constructor
private static final String DEFAULT_ACCUMULATOR_ID = "__acc";
// minimum permitted agg size is 10 bytes so it is at least large enough to hold primitive numerics (long, double)
// | expression type byte | is_null byte | primitive value (8 bytes) |
private static final int MIN_SIZE_BYTES = 10;
// size limit used when no maxSizeBytes is supplied to the constructor: 1L << 10 = 1024 bytes
private static final HumanReadableBytes DEFAULT_MAX_SIZE_BYTES = new HumanReadableBytes(1L << 10);
private final String name;
// names of the input fields; per the constructor this is null for the combining/merging aggregator
@Nullable
private final Set<String> fields;
// never null, the constructor falls back to DEFAULT_ACCUMULATOR_ID
private final String accumulatorId;
// expression strings as supplied (or derived) by the constructor
private final String foldExpressionString;
private final String initialValueExpressionString;
// falls back to the initial value expression when no initial combine value is supplied
private final String initialCombineValueExpressionString;
// derived from the fold expression when no combine expression is supplied
private final String combineExpressionString;
@Nullable
private final String compareExpressionString;
@Nullable
private final String finalizeExpressionString;
private final ExprMacroTable macroTable;
// memoized suppliers which parse and evaluate the initial value expressions on first use
private final Supplier<ExprEval<?>> initialValue;
private final Supplier<ExprEval<?>> initialCombineValue;
// suppliers created with Parser.lazyParse from the expression strings above
private final Supplier<Expr> foldExpression;
private final Supplier<Expr> combineExpression;
private final Supplier<Expr> compareExpression;
private final Supplier<Expr> finalizeExpression;
private final HumanReadableBytes maxSizeBytes;
// memoized binding objects, one per kind of expression and sized for the number of identifiers it uses
// NOTE(review): each is a single mutable SettableObjectBinding per factory instance; confirm they are never used
// from more than one thread at a time
private final Supplier<SettableObjectBinding> compareBindings =
Suppliers.memoize(() -> new SettableObjectBinding(2));
private final Supplier<SettableObjectBinding> combineBindings =
Suppliers.memoize(() -> new SettableObjectBinding(2));
private final Supplier<SettableObjectBinding> finalizeBindings =
Suppliers.memoize(() -> new SettableObjectBinding(1));
@JsonCreator
public ExpressionLambdaAggregatorFactory(
@JsonProperty("name") String name,
@JsonProperty("fields") @Nullable final Set<String> fields,
@JsonProperty("accumulatorIdentifier") @Nullable final String accumulatorIdentifier,
@JsonProperty("initialValue") final String initialValue,
@JsonProperty("initialCombineValue") @Nullable final String initialCombineValue,
@JsonProperty("fold") final String foldExpression,
@JsonProperty("combine") @Nullable final String combineExpression,
@JsonProperty("compare") @Nullable final String compareExpression,
@JsonProperty("finalize") @Nullable final String finalizeExpression,
@JsonProperty("maxSizeBytes") @Nullable final HumanReadableBytes maxSizeBytes,
@JacksonInject ExprMacroTable macroTable
)
{
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
this.name = name;
this.fields = fields;
this.accumulatorId = accumulatorIdentifier != null ? accumulatorIdentifier : DEFAULT_ACCUMULATOR_ID;
this.initialValueExpressionString = initialValue;
this.initialCombineValueExpressionString = initialCombineValue == null ? initialValue : initialCombineValue;
this.foldExpressionString = foldExpression;
if (combineExpression != null) {
this.combineExpressionString = combineExpression;
} else {
// if the combine expression is null, allow single input aggregator expressions to be rewritten to replace the
// field with the aggregator name. Fields is null for the combining/merging aggregator, but the expression should
// already be set with the rewritten value at that point
Preconditions.checkArgument(
fields != null && fields.size() == 1,
"Must have a single input field if no combine expression is supplied"
);
this.combineExpressionString = StringUtils.replace(foldExpression, Iterables.getOnlyElement(fields), name);
}
this.compareExpressionString = compareExpression;
this.finalizeExpressionString = finalizeExpression;
this.macroTable = macroTable;
this.initialValue = Suppliers.memoize(() -> {
Expr parsed = Parser.parse(initialValue, macroTable);
Preconditions.checkArgument(parsed.isLiteral(), "initial value must be constant");
return parsed.eval(ExprUtils.nilBindings());
});
this.initialCombineValue = Suppliers.memoize(() -> {
Expr parsed = Parser.parse(this.initialCombineValueExpressionString, macroTable);
Preconditions.checkArgument(parsed.isLiteral(), "initial combining value must be constant");
return parsed.eval(ExprUtils.nilBindings());
});
this.foldExpression = Parser.lazyParse(foldExpressionString, macroTable);
this.combineExpression = Parser.lazyParse(combineExpressionString, macroTable);
this.compareExpression = Parser.lazyParse(compareExpressionString, macroTable);
this.finalizeExpression = Parser.lazyParse(finalizeExpressionString, macroTable);
this.maxSizeBytes = maxSizeBytes != null ? maxSizeBytes : DEFAULT_MAX_SIZE_BYTES;
Preconditions.checkArgument(this.maxSizeBytes.getBytesInInt() >= MIN_SIZE_BYTES);
}
/**
 * Output name of this aggregator.
 */
@JsonProperty
@Override
public String getName()
{
return name;
}
/**
 * Input fields as passed to the constructor; null for the factory created by {@link #getCombiningFactory()}.
 */
@JsonProperty
@Nullable
public Set<String> getFields()
{
return fields;
}
/**
 * Identifier of the accumulator inside the expressions. The constructor substitutes DEFAULT_ACCUMULATOR_ID when
 * none was supplied.
 */
@JsonProperty
@Nullable
public String getAccumulatorIdentifier()
{
return accumulatorId;
}
/**
 * Expression string that produces the seed for the fold expression.
 */
@JsonProperty("initialValue")
public String getInitialValueExpressionString()
{
return initialValueExpressionString;
}
/**
 * Expression string that produces the seed for the combine expression; equal to the initial value expression when
 * no dedicated one was supplied.
 */
@JsonProperty("initialCombineValue")
public String getInitialCombineValueExpressionString()
{
return initialCombineValueExpressionString;
}
/**
 * Fold expression string as passed to the constructor.
 */
@JsonProperty("fold")
public String getFoldExpressionString()
{
return foldExpressionString;
}
/**
 * Combine expression string; either supplied explicitly or derived from the fold expression by the constructor.
 */
@JsonProperty("combine")
public String getCombineExpressionString()
{
return combineExpressionString;
}
/**
 * Optional compare expression string, null when none was supplied.
 */
@JsonProperty("compare")
@Nullable
public String getCompareExpressionString()
{
return compareExpressionString;
}
/**
 * Optional finalize expression string, null when none was supplied.
 */
@JsonProperty("finalize")
@Nullable
public String getFinalizeExpressionString()
{
return finalizeExpressionString;
}
/**
 * Size limit in effect; the constructor substitutes DEFAULT_MAX_SIZE_BYTES when none was supplied.
 */
@JsonProperty("maxSizeBytes")
public HumanReadableBytes getMaxSizeBytes()
{
return maxSizeBytes;
}
/**
 * Builds the cache key from every setting that influences the computed result.
 *
 * The accumulator identifier is part of the key because it decides which identifier of the expressions is bound
 * to the accumulator: two factories with identical expression strings but different accumulator identifiers
 * compute different results and must not share cache entries.
 */
@Override
public byte[] getCacheKey()
{
  return new CacheKeyBuilder(AggregatorUtil.EXPRESSION_LAMBDA_CACHE_TYPE_ID)
      // fields is null for the combining factory; use an empty set rather than passing a null collection on
      .appendStrings(fields == null ? Collections.<String>emptySet() : fields)
      .appendString(accumulatorId)
      .appendString(initialValueExpressionString)
      .appendString(initialCombineValueExpressionString)
      .appendString(foldExpressionString)
      .appendString(combineExpressionString)
      .appendString(compareExpressionString)
      .appendString(finalizeExpressionString)
      .appendInt(maxSizeBytes.getBytesInInt())
      .build();
}
/**
 * Creates the on-heap aggregator from the expression and bindings chosen by a {@link FactorizePlan}.
 */
@Override
public Aggregator factorize(ColumnSelectorFactory metricFactory)
{
  final FactorizePlan factorizePlan = new FactorizePlan(metricFactory);
  final Expr lambda = factorizePlan.getExpression();
  return new ExpressionLambdaAggregator(lambda, factorizePlan.getBindings());
}
/**
 * Creates the buffer aggregator from the expression, seed and bindings chosen by a {@link FactorizePlan}, limited
 * to {@code maxSizeBytes} per serialized value.
 */
@Override
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
{
  final FactorizePlan factorizePlan = new FactorizePlan(metricFactory);
  final Expr lambda = factorizePlan.getExpression();
  final ExprEval<?> seed = factorizePlan.getInitialValue();
  return new ExpressionLambdaBufferAggregator(
      lambda,
      seed,
      factorizePlan.getBindings(),
      maxSizeBytes.getBytesInInt()
  );
}
/**
 * Returns the ordering of aggregated values: the compare expression when one is configured, otherwise a default
 * picked from the type of the initial value.
 */
@Override
public Comparator getComparator()
{
  final Expr comparison = compareExpression.get();
  if (comparison == null) {
    // no custom ordering, fall back on the seed type
    switch (initialValue.get().type()) {
      case LONG:
        return LongSumAggregator.COMPARATOR;
      case DOUBLE:
        return DoubleSumAggregator.COMPARATOR;
      default:
        return Comparators.naturalNullsFirst();
    }
  }
  return (left, right) ->
      comparison.eval(compareBindings.get().withBinding(COMPARE_O1, left).withBinding(COMPARE_O2, right)).asInt();
}
/**
 * Merges two intermediate results by evaluating the combine expression with {@code lhs} bound to the accumulator
 * identifier and {@code rhs} bound to the aggregator name.
 */
@Nullable
@Override
public Object combine(@Nullable Object lhs, @Nullable Object rhs)
{
  final Expr combiner = combineExpression.get();
  // which side plays the accumulator is arbitrary; this assignment lets the combine expression be reused as is
  final ExprEval<?> merged = combiner.eval(
      combineBindings.get().withBinding(accumulatorId, lhs).withBinding(name, rhs)
  );
  return merged.value();
}
/**
 * Returns the given object unchanged; no conversion is applied here.
 */
@Override
public Object deserialize(Object object)
{
return object;
}
/**
 * Applies the finalize expression to the aggregated value, which is bound to {@code FINALIZE_IDENTIFIER}. Without a
 * finalize expression the value is returned untouched.
 */
@Nullable
@Override
public Object finalizeComputation(@Nullable Object object)
{
  final Expr finalizer = finalizeExpression.get();
  if (finalizer == null) {
    return object;
  }
  return finalizer.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, object)).value();
}
/**
 * Lists the bindings required by the active expression: the combine expression when {@code fields} is null,
 * otherwise the fold expression.
 */
@Override
public List<String> requiredFields()
{
  final Supplier<Expr> active = fields == null ? combineExpression : foldExpression;
  return active.get().analyzeInputs().getRequiredBindingsList();
}
/**
 * Creates the factory used to merge intermediate results. It carries the same settings as this one except that
 * {@code fields} is null, which makes the copy plan and evaluate the combine expression instead of the fold
 * expression.
 */
@Override
public AggregatorFactory getCombiningFactory()
{
return new ExpressionLambdaAggregatorFactory(
name,
null,
accumulatorId,
initialValueExpressionString,
initialCombineValueExpressionString,
foldExpressionString,
combineExpressionString,
compareExpressionString,
finalizeExpressionString,
maxSizeBytes,
macroTable
);
}
/**
 * Returns a single-element list holding a copy of this factory built from identical settings, including
 * {@code fields}.
 */
@Override
public List<AggregatorFactory> getRequiredColumns()
{
return Collections.singletonList(
new ExpressionLambdaAggregatorFactory(
name,
fields,
accumulatorId,
initialValueExpressionString,
initialCombineValueExpressionString,
foldExpressionString,
combineExpressionString,
compareExpressionString,
finalizeExpressionString,
maxSizeBytes,
macroTable
)
);
}
/**
 * Intermediate type of this aggregator, taken from the seed in use: the initial combine value when {@code fields}
 * is null, otherwise the initial value.
 */
@Override
public ValueType getType()
{
  final ExprEval<?> seed = fields == null ? initialCombineValue.get() : initialValue.get();
  return ExprType.toValueType(seed.type());
}
/**
 * Type of the finalized value. With a finalize expression the type is found by evaluating that expression against
 * the initial combine value; without one it is the type of the initial combine value itself.
 */
@Override
public ValueType getFinalizedType()
{
  final Expr finalizer = finalizeExpression.get();
  final ExprEval<?> combineSeed = initialCombineValue.get();
  if (finalizer == null) {
    return ExprType.toValueType(combineSeed.type());
  }
  final ExprEval<?> finalized = finalizer.eval(
      finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, combineSeed)
  );
  return ExprType.toValueType(finalized.type());
}
/**
 * Upper bound of the serialized intermediate value in bytes.
 */
@Override
public int getMaxIntermediateSize()
{
  if (getType().isNumeric()) {
    // numeric expressions are either longs or doubles; the first 2 bytes are used for the expression type byte
    // and the is_null byte
    return 2 + Long.BYTES;
  }
  // with strings or arrays the max size is unknown, so the configured limit applies
  return maxSizeBytes.getBytesInInt();
}
/**
 * Two factories are equal when they are of the same class and all configured settings match; the macro table and
 * the lazily parsed expressions do not take part.
 */
@Override
public boolean equals(Object o)
{
  if (this == o) {
    return true;
  }
  if (o == null || getClass() != o.getClass()) {
    return false;
  }
  final ExpressionLambdaAggregatorFactory other = (ExpressionLambdaAggregatorFactory) o;
  if (!maxSizeBytes.equals(other.maxSizeBytes)) {
    return false;
  }
  if (!name.equals(other.name)) {
    return false;
  }
  if (!Objects.equals(fields, other.fields)) {
    return false;
  }
  if (!accumulatorId.equals(other.accumulatorId)) {
    return false;
  }
  if (!foldExpressionString.equals(other.foldExpressionString)) {
    return false;
  }
  if (!initialValueExpressionString.equals(other.initialValueExpressionString)) {
    return false;
  }
  if (!initialCombineValueExpressionString.equals(other.initialCombineValueExpressionString)) {
    return false;
  }
  if (!combineExpressionString.equals(other.combineExpressionString)) {
    return false;
  }
  if (!Objects.equals(compareExpressionString, other.compareExpressionString)) {
    return false;
  }
  return Objects.equals(finalizeExpressionString, other.finalizeExpressionString);
}
/**
 * Hashes the same settings that {@link #equals(Object)} compares.
 */
@Override
public int hashCode()
{
return Objects.hash(
name,
fields,
accumulatorId,
foldExpressionString,
initialValueExpressionString,
initialCombineValueExpressionString,
combineExpressionString,
compareExpressionString,
finalizeExpressionString,
maxSizeBytes
);
}
/**
 * Renders the configured settings (name, fields, accumulator identifier, expression strings and size limit).
 */
@Override
public String toString()
{
return "ExpressionLambdaAggregatorFactory{" +
"name='" + name + '\'' +
", fields=" + fields +
", accumulatorId='" + accumulatorId + '\'' +
", foldExpressionString='" + foldExpressionString + '\'' +
", initialValueExpressionString='" + initialValueExpressionString + '\'' +
", initialCombineValueExpressionString='" + initialCombineValueExpressionString + '\'' +
", combineExpressionString='" + combineExpressionString + '\'' +
", compareExpressionString='" + compareExpressionString + '\'' +
", finalizeExpressionString='" + finalizeExpressionString + '\'' +
", maxSizeBytes=" + maxSizeBytes +
'}';
}
/**
 * Determine how to factorize the aggregator: which expression to evaluate (fold or combine), which seed the
 * accumulator starts from, and which bindings feed the expression.
 */
private class FactorizePlan
{
  private final ExpressionPlan plan;
  private final ExprEval<?> seed;
  private final ExpressionLambdaAggregatorInputBindings bindings;

  /**
   * Plans the fold expression when {@code fields} is set, otherwise the combine expression.
   *
   * The seed is resolved before planning so that the accumulator is described to the planner with the type of the
   * seed that is actually used. When merging intermediate results this is the initial combine value, whose type
   * may differ from the initial value.
   */
  FactorizePlan(ColumnSelectorFactory metricFactory)
  {
    final Expr expression;
    if (fields != null) {
      // if fields are set, we are accumulating from raw inputs, use fold expression
      seed = initialValue.get();
      expression = foldExpression.get();
    } else {
      // else we are merging intermediary results, use combine expression
      seed = initialCombineValue.get();
      expression = combineExpression.get();
    }
    plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory, seed.type()), expression);

    final List<String> columns = plan.getAnalysis().getRequiredBindingsList();
    bindings = new ExpressionLambdaAggregatorInputBindings(
        ExpressionSelectors.createBindings(metricFactory, columns),
        accumulatorId,
        seed
    );
  }

  /**
   * Expression the aggregator evaluates per row.
   */
  public Expr getExpression()
  {
    if (fields == null) {
      return plan.getExpression();
    }
    // for fold expressions, check to see if it needs transformation due to scalar use of multi-valued or unknown
    // inputs
    return plan.getAppliedFoldExpression(accumulatorId);
  }

  /**
   * Seed the accumulator starts from.
   */
  public ExprEval<?> getInitialValue()
  {
    return seed;
  }

  /**
   * Bindings that expose the accumulator next to the selector backed inputs.
   */
  public ExpressionLambdaAggregatorInputBindings getBindings()
  {
    return bindings;
  }

  /**
   * Wraps an inspector so that the accumulator identifier resolves to the given type; every other column is
   * delegated to the wrapped inspector.
   */
  private ColumnInspector inspectorWithAccumulator(ColumnInspector inspector, ExprType accumulatorType)
  {
    return new ColumnInspector()
    {
      @Nullable
      @Override
      public ColumnCapabilities getColumnCapabilities(String column)
      {
        if (accumulatorId.equals(column)) {
          return ColumnCapabilitiesImpl.createDefault().setType(ExprType.toValueType(accumulatorType));
        }
        return inspector.getColumnCapabilities(column);
      }

      @Nullable
      @Override
      public ExprType getType(String name)
      {
        if (accumulatorId.equals(name)) {
          return accumulatorType;
        }
        return inspector.getType(name);
      }
    };
  }
}
}

View File

@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import javax.annotation.Nullable;
/**
 * {@link Expr.ObjectBinding} used by {@link ExpressionLambdaAggregatorFactory}.
 *
 * Besides delegating ordinary identifiers to the bindings of the underlying selector inputs, it keeps the current
 * value of a dedicated 'accumulator' identifier. Storing the result of one evaluation here makes it visible to the
 * next evaluation under that identifier.
 */
public class ExpressionLambdaAggregatorInputBindings implements Expr.ObjectBinding
{
  private final Expr.ObjectBinding delegate;
  private final String accumulatorId;
  private ExprEval<?> current;

  /**
   * @param inputBindings         bindings consulted for every identifier other than the accumulator
   * @param accumulatorIdentifier identifier that resolves to the accumulator
   * @param initialValue          value the accumulator holds until it is replaced
   */
  public ExpressionLambdaAggregatorInputBindings(
      Expr.ObjectBinding inputBindings,
      String accumulatorIdentifier,
      ExprEval<?> initialValue
  )
  {
    this.delegate = inputBindings;
    this.accumulatorId = accumulatorIdentifier;
    this.current = initialValue;
  }

  /**
   * Resolves the accumulator identifier to the current accumulator value and anything else through the delegate.
   */
  @Nullable
  @Override
  public Object get(String name)
  {
    return accumulatorId.equals(name) ? current.value() : delegate.get(name);
  }

  /**
   * Replaces the accumulator with the result of the latest evaluation.
   */
  public void accumulate(ExprEval<?> eval)
  {
    current = eval;
  }

  /**
   * Current accumulator, including its type information.
   */
  public ExprEval<?> getAccumulator()
  {
    return current;
  }

  /**
   * Replaces the accumulator with the given value.
   */
  public void setAccumulator(ExprEval<?> acc)
  {
    current = acc;
  }
}

View File

@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
 * {@link BufferAggregator} that keeps its accumulator serialized in the buffer and updates it by evaluating an
 * expression: every call to {@link #aggregate} reads the stored value, exposes it through the bindings, evaluates
 * the expression and writes the result back.
 */
public class ExpressionLambdaBufferAggregator implements BufferAggregator
{
  private final Expr lambda;
  private final ExprEval<?> initialValue;
  private final ExpressionLambdaAggregatorInputBindings bindings;
  private final int maxSizeBytes;

  /**
   * @param lambda       expression evaluated for every aggregated row
   * @param initialValue value written to the buffer by {@link #init}
   * @param bindings     bindings that receive the stored accumulator before each evaluation
   * @param maxSizeBytes size limit passed to the serializer
   */
  public ExpressionLambdaBufferAggregator(
      Expr lambda,
      ExprEval<?> initialValue,
      ExpressionLambdaAggregatorInputBindings bindings,
      int maxSizeBytes
  )
  {
    this.lambda = lambda;
    this.initialValue = initialValue;
    this.bindings = bindings;
    this.maxSizeBytes = maxSizeBytes;
  }

  @Override
  public void init(ByteBuffer buf, int position)
  {
    store(buf, position, initialValue);
  }

  @Override
  public void aggregate(ByteBuffer buf, int position)
  {
    bindings.setAccumulator(load(buf, position));
    store(buf, position, lambda.eval(bindings));
  }

  @Nullable
  @Override
  public Object get(ByteBuffer buf, int position)
  {
    return load(buf, position).value();
  }

  @Override
  public float getFloat(ByteBuffer buf, int position)
  {
    final double stored = load(buf, position).asDouble();
    return (float) stored;
  }

  @Override
  public double getDouble(ByteBuffer buf, int position)
  {
    return load(buf, position).asDouble();
  }

  @Override
  public long getLong(ByteBuffer buf, int position)
  {
    return load(buf, position).asLong();
  }

  @Override
  public void close()
  {
    // nothing to close
  }

  /**
   * Reads the value stored at the given position.
   */
  private static ExprEval<?> load(ByteBuffer buf, int position)
  {
    return ExprEval.deserialize(buf, position);
  }

  /**
   * Writes a value at the given position, subject to the configured size limit.
   */
  private void store(ByteBuffer buf, int position, ExprEval<?> value)
  {
    ExprEval.serialize(buf, position, value, maxSizeBytes);
  }
}

View File

@ -23,7 +23,6 @@ package org.apache.druid.query.aggregation;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Parser;
@ -72,7 +71,7 @@ public abstract class SimpleDoubleAggregatorFactory extends NullableNumericAggre
this.fieldName = fieldName;
this.expression = expression;
this.storeDoubleAsFloat = ColumnHolder.storeDoubleAsFloat();
this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable));
this.fieldExpression = Parser.lazyParse(expression, macroTable);
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
Preconditions.checkArgument(
fieldName == null ^ expression == null,

View File

@ -23,7 +23,6 @@ package org.apache.druid.query.aggregation;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Parser;
@ -63,7 +62,7 @@ public abstract class SimpleFloatAggregatorFactory extends NullableNumericAggreg
this.name = name;
this.fieldName = fieldName;
this.expression = expression;
this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable));
this.fieldExpression = Parser.lazyParse(expression, macroTable);
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
Preconditions.checkArgument(
fieldName == null ^ expression == null,

View File

@ -23,7 +23,6 @@ package org.apache.druid.query.aggregation;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Parser;
@ -69,7 +68,7 @@ public abstract class SimpleLongAggregatorFactory extends NullableNumericAggrega
this.name = name;
this.fieldName = fieldName;
this.expression = expression;
this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable));
this.fieldExpression = Parser.lazyParse(expression, macroTable);
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
Preconditions.checkArgument(
fieldName == null ^ expression == null,

View File

@ -92,7 +92,7 @@ public class ExpressionPostAggregator implements PostAggregator
ordering,
macroTable,
ImmutableMap.of(),
Suppliers.memoize(() -> Parser.parse(expression, macroTable))
Parser.lazyParse(expression, macroTable)
);
}

View File

@ -25,7 +25,6 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.RangeSet;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
@ -53,7 +52,7 @@ public class ExpressionDimFilter extends AbstractOptimizableDimFilter implements
{
this.expression = expression;
this.filterTuning = filterTuning;
this.parsed = Suppliers.memoize(() -> Parser.parse(expression, macroTable));
this.parsed = Parser.lazyParse(expression, macroTable);
}
@VisibleForTesting

View File

@ -26,6 +26,7 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import org.apache.druid.data.input.Row;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.segment.column.ColumnHolder;
@ -106,7 +107,7 @@ public class ExpressionTransform implements Transform
} else {
Object raw = row.getRaw(column);
if (raw instanceof List) {
return ExpressionSelectors.coerceListToArray((List) raw);
return ExprEval.coerceListToArray((List) raw, true);
}
return raw;
}

View File

@ -122,6 +122,14 @@ public class ExpressionPlan
return expression;
}
public Expr getAppliedFoldExpression(String accumulatorId)
{
if (is(Trait.NEEDS_APPLIED)) {
return Parser.foldUnappliedBindings(expression, analysis, unappliedInputs, accumulatorId);
}
return expression;
}
public Expr.BindingAnalysis getAnalysis()
{
return analysis;

View File

@ -24,7 +24,6 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.collect.Iterables;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.Parser;
@ -242,13 +241,26 @@ public class ExpressionSelectors
* provides the set of identifiers which need a binding (list of required columns), and context of whether or not they
* are used as array or scalar inputs
*/
private static Expr.ObjectBinding createBindings(
public static Expr.ObjectBinding createBindings(
Expr.BindingAnalysis bindingAnalysis,
ColumnSelectorFactory columnSelectorFactory
)
{
final Map<String, Supplier<Object>> suppliers = new HashMap<>();
final List<String> columns = bindingAnalysis.getRequiredBindingsList();
return createBindings(columnSelectorFactory, columns);
}
/**
* Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and the list of column names (identifiers)
* which need a binding. A value supplier is created for each listed column based on its
* {@link ColumnCapabilities}, if any are available.
*/
public static Expr.ObjectBinding createBindings(
ColumnSelectorFactory columnSelectorFactory,
List<String> columns
)
{
final Map<String, Supplier<Object>> suppliers = new HashMap<>();
for (String columnName : columns) {
final ColumnCapabilities columnCapabilities = columnSelectorFactory.getColumnCapabilities(columnName);
final ValueType nativeType = columnCapabilities != null ? columnCapabilities.getType() : null;
@ -269,8 +281,8 @@ public class ExpressionSelectors
columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(columnName, columnName)),
multiVal
);
} else if (nativeType == null) {
// Unknown ValueType. Try making an Object selector and see if that gives us anything useful.
} else if (nativeType == null || ValueType.isArray(nativeType)) {
// Unknown ValueType or array type. Try making an Object selector and see if that gives us anything useful.
supplier = supplierFromObjectSelector(columnSelectorFactory.makeColumnValueSelector(columnName));
} else {
// Unhandleable ValueType (COMPLEX).
@ -370,10 +382,10 @@ public class ExpressionSelectors
// Might be Numbers and Strings. Use a selector that double-checks.
return () -> {
final Object val = selector.getObject();
if (val instanceof Number || val instanceof String) {
if (val instanceof Number || val instanceof String || (val != null && val.getClass().isArray())) {
return val;
} else if (val instanceof List) {
return coerceListToArray((List) val);
return ExprEval.coerceListToArray((List) val, true);
} else {
return null;
}
@ -382,7 +394,7 @@ public class ExpressionSelectors
return () -> {
final Object val = selector.getObject();
if (val != null) {
return coerceListToArray((List) val);
return ExprEval.coerceListToArray((List) val, true);
}
return null;
};
@ -392,70 +404,6 @@ public class ExpressionSelectors
}
}
/**
* Selectors are not consistent in treatment of null, [], and [null], so coerce [] to [null]
*/
public static Object coerceListToArray(@Nullable List<?> val)
{
if (val != null && val.size() > 0) {
Class coercedType = null;
for (Object elem : val) {
if (elem != null) {
coercedType = convertType(coercedType, elem.getClass());
}
}
if (coercedType == Long.class || coercedType == Integer.class) {
return val.stream().map(x -> x != null ? ((Number) x).longValue() : null).toArray(Long[]::new);
}
if (coercedType == Float.class || coercedType == Double.class) {
return val.stream().map(x -> x != null ? ((Number) x).doubleValue() : null).toArray(Double[]::new);
}
// default to string
return val.stream().map(x -> x != null ? x.toString() : null).toArray(String[]::new);
}
return new String[]{null};
}
private static Class convertType(@Nullable Class existing, Class next)
{
if (Number.class.isAssignableFrom(next) || next == String.class) {
if (existing == null) {
return next;
}
// string wins everything
if (existing == String.class) {
return existing;
}
if (next == String.class) {
return next;
}
// all numbers win over Integer
if (existing == Integer.class) {
return next;
}
if (existing == Float.class) {
// doubles win over floats
if (next == Double.class) {
return next;
}
return existing;
}
if (existing == Long.class) {
if (next == Integer.class) {
// long beats int
return existing;
}
// double and float win over longs
return next;
}
// otherwise double
return Double.class;
}
throw new UOE("Invalid array expression type: %s", next);
}
/**
* Coerces {@link ExprEval} value back to selector friendly {@link List} if the evaluated expression result is an
* array type

View File

@ -72,7 +72,7 @@ public class ExpressionVirtualColumn implements VirtualColumn
this.name = Preconditions.checkNotNull(name, "name");
this.expression = Preconditions.checkNotNull(expression, "expression");
this.outputType = outputType;
this.parsedExpression = Suppliers.memoize(() -> Parser.parse(expression, macroTable));
this.parsedExpression = Parser.lazyParse(expression, macroTable);
}
/**

View File

@ -0,0 +1,570 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.IOException;
public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandlingTest
{
// shared JSON mapper; declared final because it is a constant and must never be reassigned by a test
private static final ObjectMapper MAPPER = TestHelper.makeJsonMapper();
// lets tests declare the exception type and message they expect before running the failing call
@Rule
public ExpectedException expectedException = ExpectedException.none();
/**
 * A factory with every setting supplied must survive a JSON round trip unchanged.
 */
@Test
public void testSerde() throws IOException
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",                                       // name
      ImmutableSet.of("some_column", "some_other_column"),   // fields
      "customAccumulator",                                   // accumulatorIdentifier
      "0.0",                                                 // initialValue
      "10.0",                                                // initialCombineValue
      "customAccumulator + some_column + some_other_column", // fold
      "customAccumulator + expr_agg_name",                   // combine
      "if (o1 > o2, if (o1 == o2, 0, 1), -1)",               // compare
      "o + 100",                                             // finalize
      new HumanReadableBytes(2048),                          // maxSizeBytes
      TestExprMacroTable.INSTANCE
  );
  final byte[] json = MAPPER.writeValueAsBytes(factory);
  Assert.assertEquals(factory, MAPPER.readValue(json, ExpressionLambdaAggregatorFactory.class));
}
/**
 * Verifies the equals/hashCode contract. The ignored fields are the macro table and the lazily computed
 * suppliers, which equals() and hashCode() of the factory do not consider.
 */
@Test
public void testEqualsAndHashCode()
{
EqualsVerifier.forClass(ExpressionLambdaAggregatorFactory.class)
.usingGetClass()
.withIgnoredFields(
"macroTable",
"initialValue",
"initialCombineValue",
"foldExpression",
"combineExpression",
"compareExpression",
"finalizeExpression",
"compareBindings",
"combineBindings",
"finalizeBindings"
)
.verify();
}
/**
 * A non-constant initial value is rejected; the check is lazy, so it only fires once the type is requested.
 */
@Test
public void testInitialValueMustBeConstant()
{
  expectedException.expect(IllegalArgumentException.class);
  expectedException.expectMessage("initial value must be constant");

  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "x + y", // not a literal
      null,
      "__acc + some_column + some_other_column",
      "__acc + expr_agg_name",
      null,
      null,
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );
  factory.getType();
}
/**
 * A non-constant initial combine value is rejected once the finalized type is requested.
 */
@Test
public void testInitialCombineValueMustBeConstant()
{
  expectedException.expect(IllegalArgumentException.class);
  expectedException.expectMessage("initial combining value must be constant");

  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "0.0",
      "x + y", // not a literal
      "__acc + some_column + some_other_column",
      "__acc + expr_agg_name",
      null,
      null,
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );
  factory.getFinalizedType();
}
/**
 * With a single input field the combine expression may be omitted; it is then derived from the fold expression.
 */
@Test
public void testSingleInputCombineExpressionIsOptional()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("x"),
      null,
      "0",
      null,
      "__acc + x",
      null, // no combine expression
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );
  final Object combined = factory.combine(0L, 1L);
  Assert.assertEquals(1L, combined);
}
/**
 * The finalize expression is applied to the aggregated value.
 */
@Test
public void testFinalizeCanDo()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("x"),
      null,
      "0",
      null,
      "__acc + x",
      null,
      null,
      "o + 100", // finalize
      null,
      TestExprMacroTable.INSTANCE
  );
  final Object finalized = factory.finalizeComputation(0L);
  Assert.assertEquals(100L, finalized);
}
/**
 * The finalize expression accepts array-like aggregated values, whether given as a Java array or as a list.
 */
@Test
public void testFinalizeCanDoArrays()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("x"),
      null,
      "0",
      null,
      "array_set_add(__acc, x)",
      "array_set_add_all(__acc, expr_agg_name)",
      null,
      "array_to_string(o, ',')", // finalize
      null,
      TestExprMacroTable.INSTANCE
  );
  final Object fromArray = factory.finalizeComputation(new String[]{"a", "b", "c"});
  Assert.assertEquals("a,b,c", fromArray);
  final Object fromList = factory.finalizeComputation(ImmutableList.of("a", "b", "c"));
  Assert.assertEquals("a,b,c", fromList);
}
/**
 * String seeds yield STRING for the intermediate, combining and finalized types.
 */
@Test
public void testStringType()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "''",
      "''",
      "concat(__acc, some_column, some_other_column)",
      "concat(__acc, expr_agg_name)",
      null,
      null,
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );

  Assert.assertEquals(ValueType.STRING, factory.getType());
  Assert.assertEquals(ValueType.STRING, factory.getCombiningFactory().getType());
  Assert.assertEquals(ValueType.STRING, factory.getFinalizedType());
}
/**
 * A long seed yields LONG for the intermediate, combining and finalized types.
 */
@Test
public void testLongType()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "0",
      null,
      "__acc + some_column + some_other_column",
      "__acc + expr_agg_name",
      null,
      null,
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );

  Assert.assertEquals(ValueType.LONG, factory.getType());
  Assert.assertEquals(ValueType.LONG, factory.getCombiningFactory().getType());
  Assert.assertEquals(ValueType.LONG, factory.getFinalizedType());
}
/**
 * A double seed yields DOUBLE for the intermediate, combining and finalized types.
 */
@Test
public void testDoubleType()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "0.0",
      null,
      "__acc + some_column + some_other_column",
      "__acc + expr_agg_name",
      null,
      null,
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );

  Assert.assertEquals(ValueType.DOUBLE, factory.getType());
  Assert.assertEquals(ValueType.DOUBLE, factory.getCombiningFactory().getType());
  Assert.assertEquals(ValueType.DOUBLE, factory.getFinalizedType());
}
/**
 * A string seed with a string array combine seed: the intermediate type is STRING, while the combining and
 * finalized types are STRING_ARRAY.
 */
@Test
public void testStringArrayType()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "''",
      "<STRING>[]",
      "concat(__acc, some_column, some_other_column)",
      "array_set_add(__acc, expr_agg_name)",
      null,
      null,
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );

  Assert.assertEquals(ValueType.STRING, factory.getType());
  Assert.assertEquals(ValueType.STRING_ARRAY, factory.getCombiningFactory().getType());
  Assert.assertEquals(ValueType.STRING_ARRAY, factory.getFinalizedType());
}
/**
 * Same as the string array case, but a finalize expression turns the array back into a STRING.
 */
@Test
public void testStringArrayTypeFinalized()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "''",
      "<STRING>[]",
      "concat(__acc, some_column, some_other_column)",
      "array_set_add(__acc, expr_agg_name)",
      null,
      "array_to_string(o, ';')", // finalize
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );

  Assert.assertEquals(ValueType.STRING, factory.getType());
  Assert.assertEquals(ValueType.STRING_ARRAY, factory.getCombiningFactory().getType());
  Assert.assertEquals(ValueType.STRING, factory.getFinalizedType());
}
/**
 * With the literals {@code "0"} and {@code "<LONG>[]"} and no finalize expression, the factory itself must
 * report {@link ValueType#LONG}, while both its combining factory and its finalized type must be
 * {@link ValueType#LONG_ARRAY}.
 */
@Test
public void testLongArrayType()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "0",
      "<LONG>[]",
      "__acc + some_column + some_other_column",
      "array_set_add(__acc, expr_agg_name)",
      null,
      null,
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );

  final ValueType type = factory.getType();
  Assert.assertEquals(ValueType.LONG, type);

  final ValueType combiningType = factory.getCombiningFactory().getType();
  Assert.assertEquals(ValueType.LONG_ARRAY, combiningType);

  final ValueType finalizedType = factory.getFinalizedType();
  Assert.assertEquals(ValueType.LONG_ARRAY, finalizedType);
}
/**
 * Same setup as the long-array case, but with {@code array_to_string(o, ';')} supplied as the last expression:
 * the factory type stays {@link ValueType#LONG}, the combining factory type is {@link ValueType#LONG_ARRAY},
 * and the finalized type is expected to be {@link ValueType#STRING}.
 */
@Test
public void testLongArrayTypeFinalized()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "0",
      "<LONG>[]",
      "__acc + some_column + some_other_column",
      "array_set_add(__acc, expr_agg_name)",
      null,
      "array_to_string(o, ';')",
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );

  final ValueType type = factory.getType();
  Assert.assertEquals(ValueType.LONG, type);

  final ValueType combiningType = factory.getCombiningFactory().getType();
  Assert.assertEquals(ValueType.LONG_ARRAY, combiningType);

  final ValueType finalizedType = factory.getFinalizedType();
  Assert.assertEquals(ValueType.STRING, finalizedType);
}
/**
 * With the literals {@code "0.0"} and {@code "<DOUBLE>[]"} and no finalize expression, the factory itself must
 * report {@link ValueType#DOUBLE}, while both its combining factory and its finalized type must be
 * {@link ValueType#DOUBLE_ARRAY}.
 */
@Test
public void testDoubleArrayType()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "0.0",
      "<DOUBLE>[]",
      "__acc + some_column + some_other_column",
      "array_set_add(__acc, expr_agg_name)",
      null,
      null,
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );

  final ValueType type = factory.getType();
  Assert.assertEquals(ValueType.DOUBLE, type);

  final ValueType combiningType = factory.getCombiningFactory().getType();
  Assert.assertEquals(ValueType.DOUBLE_ARRAY, combiningType);

  final ValueType finalizedType = factory.getFinalizedType();
  Assert.assertEquals(ValueType.DOUBLE_ARRAY, finalizedType);
}
/**
 * Same setup as the double-array case, but with {@code array_to_string(o, ';')} supplied as the last
 * expression: the factory type stays {@link ValueType#DOUBLE}, the combining factory type is
 * {@link ValueType#DOUBLE_ARRAY}, and the finalized type is expected to be {@link ValueType#STRING}.
 */
@Test
public void testDoubleArrayTypeFinalized()
{
  final ExpressionLambdaAggregatorFactory factory = new ExpressionLambdaAggregatorFactory(
      "expr_agg_name",
      ImmutableSet.of("some_column", "some_other_column"),
      null,
      "0.0",
      "<DOUBLE>[]",
      "__acc + some_column + some_other_column",
      "array_set_add(__acc, expr_agg_name)",
      null,
      "array_to_string(o, ';')",
      new HumanReadableBytes(2048),
      TestExprMacroTable.INSTANCE
  );

  final ValueType type = factory.getType();
  Assert.assertEquals(ValueType.DOUBLE, type);

  final ValueType combiningType = factory.getCombiningFactory().getType();
  Assert.assertEquals(ValueType.DOUBLE_ARRAY, combiningType);

  final ValueType finalizedType = factory.getFinalizedType();
  Assert.assertEquals(ValueType.STRING, finalizedType);
}
/**
 * Builds a timeseries query carrying nine expression aggregators (scalar, array, and array-with-finalizer
 * variants for string, double and long) plus plain and finalizing field-access post-aggregators over the
 * finalized variants, and asserts the signature produced by
 * {@link TimeseriesQueryQueryToolChest#resultArraySignature}.
 *
 * As the expectations below encode, an aggregator column is typed only when its type equals its finalized
 * type; otherwise the column type is {@code null}.
 *
 * Every combine expression refers to {@code __acc} and to the aggregator's own output name, and every field
 * set lists exactly the columns that the fold expression reads, so the fixtures are self-consistent.
 */
@Test
public void testResultArraySignature()
{
  final TimeseriesQuery query =
      Druids.newTimeseriesQueryBuilder()
            .dataSource("dummy")
            .intervals("2000/3000")
            .granularity(Granularities.HOUR)
            .aggregators(
                new ExpressionLambdaAggregatorFactory(
                    "string_expr",
                    ImmutableSet.of("some_column", "some_other_column"),
                    null,
                    "''",
                    "''",
                    "concat(__acc, some_column, some_other_column)",
                    "concat(__acc, string_expr)",
                    null,
                    null,
                    new HumanReadableBytes(2048),
                    TestExprMacroTable.INSTANCE
                ),
                new ExpressionLambdaAggregatorFactory(
                    "double_expr",
                    ImmutableSet.of("some_column", "some_other_column"),
                    null,
                    "0.0",
                    null,
                    "__acc + some_column + some_other_column",
                    "__acc + double_expr",
                    null,
                    null,
                    new HumanReadableBytes(2048),
                    TestExprMacroTable.INSTANCE
                ),
                new ExpressionLambdaAggregatorFactory(
                    "long_expr",
                    ImmutableSet.of("some_column", "some_other_column"),
                    null,
                    "0",
                    null,
                    "__acc + some_column + some_other_column",
                    "__acc + long_expr",
                    null,
                    null,
                    new HumanReadableBytes(2048),
                    TestExprMacroTable.INSTANCE
                ),
                new ExpressionLambdaAggregatorFactory(
                    "string_array_expr",
                    ImmutableSet.of("some_column", "some_other_column"),
                    null,
                    "<STRING>[]",
                    "<STRING>[]",
                    "array_set_add(__acc, concat(some_column, some_other_column))",
                    "array_set_add_all(__acc, string_array_expr)",
                    null,
                    null,
                    new HumanReadableBytes(2048),
                    TestExprMacroTable.INSTANCE
                ),
                new ExpressionLambdaAggregatorFactory(
                    "double_array_expr",
                    // field set now matches the columns the fold expression actually reads
                    ImmutableSet.of("some_column", "some_other_column"),
                    null,
                    "0.0",
                    "<DOUBLE>[]",
                    "__acc + some_column + some_other_column",
                    // combine refers to this aggregator's own output name
                    "array_set_add(__acc, double_array_expr)",
                    null,
                    null,
                    new HumanReadableBytes(2048),
                    TestExprMacroTable.INSTANCE
                ),
                new ExpressionLambdaAggregatorFactory(
                    "long_array_expr",
                    ImmutableSet.of("some_column", "some_other_column"),
                    null,
                    "0",
                    "<LONG>[]",
                    "__acc + some_column + some_other_column",
                    "array_set_add(__acc, long_array_expr)",
                    null,
                    null,
                    new HumanReadableBytes(2048),
                    TestExprMacroTable.INSTANCE
                ),
                new ExpressionLambdaAggregatorFactory(
                    "string_array_expr_finalized",
                    ImmutableSet.of("some_column", "some_other_column"),
                    null,
                    "''",
                    "<STRING>[]",
                    "concat(__acc, some_column, some_other_column)",
                    // combine refers to this aggregator's own output name
                    "array_set_add(__acc, string_array_expr_finalized)",
                    null,
                    "array_to_string(o, ';')",
                    new HumanReadableBytes(2048),
                    TestExprMacroTable.INSTANCE
                ),
                new ExpressionLambdaAggregatorFactory(
                    "double_array_expr_finalized",
                    // field set now matches the columns the fold expression actually reads
                    ImmutableSet.of("some_column", "some_other_column"),
                    null,
                    "0.0",
                    "<DOUBLE>[]",
                    "__acc + some_column + some_other_column",
                    // combine refers to this aggregator's own output name
                    "array_set_add(__acc, double_array_expr_finalized)",
                    null,
                    "array_to_string(o, ';')",
                    new HumanReadableBytes(2048),
                    TestExprMacroTable.INSTANCE
                ),
                new ExpressionLambdaAggregatorFactory(
                    "long_array_expr_finalized",
                    ImmutableSet.of("some_column", "some_other_column"),
                    null,
                    "0",
                    "<LONG>[]",
                    "__acc + some_column + some_other_column",
                    // combine refers to this aggregator's own output name
                    "array_set_add(__acc, long_array_expr_finalized)",
                    null,
                    "fold((x, acc) -> x + acc, o, 0)",
                    new HumanReadableBytes(2048),
                    TestExprMacroTable.INSTANCE
                )
            )
            .postAggregators(
                new FieldAccessPostAggregator("string-array-expr-access", "string_array_expr_finalized"),
                new FinalizingFieldAccessPostAggregator("string-array-expr-finalize", "string_array_expr_finalized"),
                new FieldAccessPostAggregator("double-array-expr-access", "double_array_expr_finalized"),
                new FinalizingFieldAccessPostAggregator("double-array-expr-finalize", "double_array_expr_finalized"),
                new FieldAccessPostAggregator("long-array-expr-access", "long_array_expr_finalized"),
                new FinalizingFieldAccessPostAggregator("long-array-expr-finalize", "long_array_expr_finalized")
            )
            .build();

  Assert.assertEquals(
      RowSignature.builder()
                  .addTimeColumn()
                  .add("string_expr", ValueType.STRING)
                  .add("double_expr", ValueType.DOUBLE)
                  .add("long_expr", ValueType.LONG)
                  .add("string_array_expr", ValueType.STRING_ARRAY)
                  // type does not equal finalized type. (combining factory type does equal finalized type,
                  // but this signature doesn't use combining factory)
                  .add("double_array_expr", null)
                  // type does not equal finalized type. (combining factory type does equal finalized type,
                  // but this signature doesn't use combining factory)
                  .add("long_array_expr", null)
                  // string because fold type equals finalized type, even though merge type is array
                  .add("string_array_expr_finalized", ValueType.STRING)
                  // type (double) does not equal finalized type (string), so the column is untyped
                  .add("double_array_expr_finalized", null)
                  // long because fold type equals finalized type, even though merge type is array
                  .add("long_array_expr_finalized", ValueType.LONG)
                  // fold type is string
                  .add("string-array-expr-access", ValueType.STRING)
                  // finalized type is string
                  .add("string-array-expr-finalize", ValueType.STRING)
                  // double because fold type is double
                  .add("double-array-expr-access", ValueType.DOUBLE)
                  // string because finalize type is string
                  .add("double-array-expr-finalize", ValueType.STRING)
                  // long because fold type is long
                  .add("long-array-expr-access", ValueType.LONG)
                  // finalized type is long
                  .add("long-array-expr-finalize", ValueType.LONG)
                  .build(),
      new TimeseriesQueryQueryToolChest().resultArraySignature(query)
  );
}
}

View File

@ -25,6 +25,7 @@ import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
@ -67,6 +68,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.FloatSumAggregatorFactory;
import org.apache.druid.query.aggregation.JavaScriptAggregatorFactory;
@ -11138,6 +11140,704 @@ public class GroupByQueryRunnerTest extends InitializedNullHandlingTest
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
}
/**
 * Runs a day-granularity groupBy on {@code quality} (output name {@code alias}) with two expression
 * aggregators: {@code rows}, folded with {@code __acc + 1}, and {@code idx}, folded with
 * {@code __acc + index}. The result is compared against one expected row per quality per day.
 */
@Test
public void testGroupByWithExpressionAggregator()
{
  // expression agg not yet vectorized
  cannotVectorize();

  final ExpressionLambdaAggregatorFactory rowCounter = new ExpressionLambdaAggregatorFactory(
      "rows",
      Collections.emptySet(),
      null,
      "0",
      null,
      "__acc + 1",
      "__acc + rows",
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );
  final ExpressionLambdaAggregatorFactory indexSum = new ExpressionLambdaAggregatorFactory(
      "idx",
      ImmutableSet.of("index"),
      null,
      "0.0",
      null,
      "__acc + index",
      null,
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );

  final GroupByQuery groupByQuery = makeQueryBuilder()
      .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
      .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
      .setDimensions(new DefaultDimensionSpec("quality", "alias"))
      .setAggregatorSpecs(rowCounter, indexSum)
      .setGranularity(QueryRunnerTestHelper.DAY_GRAN)
      .build();

  final List<ResultRow> expected = Arrays.asList(
      // 2011-04-01
      makeRow(groupByQuery, "2011-04-01", "alias", "automotive", "rows", 1L, "idx", 135.88510131835938d),
      makeRow(groupByQuery, "2011-04-01", "alias", "business", "rows", 1L, "idx", 118.57034),
      makeRow(groupByQuery, "2011-04-01", "alias", "entertainment", "rows", 1L, "idx", 158.747224),
      makeRow(groupByQuery, "2011-04-01", "alias", "health", "rows", 1L, "idx", 120.134704),
      makeRow(groupByQuery, "2011-04-01", "alias", "mezzanine", "rows", 3L, "idx", 2871.8866900000003d),
      makeRow(groupByQuery, "2011-04-01", "alias", "news", "rows", 1L, "idx", 121.58358d),
      makeRow(groupByQuery, "2011-04-01", "alias", "premium", "rows", 3L, "idx", 2900.798647d),
      makeRow(groupByQuery, "2011-04-01", "alias", "technology", "rows", 1L, "idx", 78.622547d),
      makeRow(groupByQuery, "2011-04-01", "alias", "travel", "rows", 1L, "idx", 119.922742d),
      // 2011-04-02
      makeRow(groupByQuery, "2011-04-02", "alias", "automotive", "rows", 1L, "idx", 147.42593d),
      makeRow(groupByQuery, "2011-04-02", "alias", "business", "rows", 1L, "idx", 112.987027d),
      makeRow(groupByQuery, "2011-04-02", "alias", "entertainment", "rows", 1L, "idx", 166.016049d),
      makeRow(groupByQuery, "2011-04-02", "alias", "health", "rows", 1L, "idx", 113.446008d),
      makeRow(groupByQuery, "2011-04-02", "alias", "mezzanine", "rows", 3L, "idx", 2448.830613d),
      makeRow(groupByQuery, "2011-04-02", "alias", "news", "rows", 1L, "idx", 114.290141d),
      makeRow(groupByQuery, "2011-04-02", "alias", "premium", "rows", 3L, "idx", 2506.415148d),
      makeRow(groupByQuery, "2011-04-02", "alias", "technology", "rows", 1L, "idx", 97.387433d),
      makeRow(groupByQuery, "2011-04-02", "alias", "travel", "rows", 1L, "idx", 126.411364d)
  );

  final Iterable<ResultRow> actual = GroupByQueryRunnerTestHelper.runQuery(factory, runner, groupByQuery);
  TestHelper.assertExpectedObjects(expected, actual, "groupBy");
}
/**
 * Same query as the scalar expression-aggregator groupBy test, plus an {@code array_agg_distinct} aggregator
 * that folds {@code market} with {@code array_set_add} and combines with {@code array_set_add_all}.
 * When the configured default strategy is v1, an {@link IllegalStateException} about {@code STRING_ARRAY}
 * is expected instead of results.
 */
@Test
public void testGroupByWithExpressionAggregatorWithArrays()
{
  // expression agg not yet vectorized
  cannotVectorize();

  // array types don't work with group by v1
  if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) {
    expectedException.expect(IllegalStateException.class);
    expectedException.expectMessage("Unable to handle type[STRING_ARRAY] for AggregatorFactory[class org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory]");
  }

  final ExpressionLambdaAggregatorFactory rowCounter = new ExpressionLambdaAggregatorFactory(
      "rows",
      Collections.emptySet(),
      null,
      "0",
      null,
      "__acc + 1",
      "__acc + rows",
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );
  final ExpressionLambdaAggregatorFactory indexSum = new ExpressionLambdaAggregatorFactory(
      "idx",
      ImmutableSet.of("index"),
      null,
      "0.0",
      null,
      "__acc + index",
      null,
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );
  final ExpressionLambdaAggregatorFactory distinctMarkets = new ExpressionLambdaAggregatorFactory(
      "array_agg_distinct",
      ImmutableSet.of(QueryRunnerTestHelper.MARKET_DIMENSION),
      "acc",
      "[]",
      null,
      "array_set_add(acc, market)",
      "array_set_add_all(acc, array_agg_distinct)",
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );

  final GroupByQuery groupByQuery = makeQueryBuilder()
      .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
      .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
      .setDimensions(new DefaultDimensionSpec("quality", "alias"))
      .setAggregatorSpecs(rowCounter, indexSum, distinctMarkets)
      .setGranularity(QueryRunnerTestHelper.DAY_GRAN)
      .build();

  final List<ResultRow> expected = Arrays.asList(
      // 2011-04-01
      makeRow(groupByQuery, "2011-04-01", "alias", "automotive", "rows", 1L, "idx", 135.88510131835938d,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "business", "rows", 1L, "idx", 118.57034,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "entertainment", "rows", 1L, "idx", 158.747224,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "health", "rows", 1L, "idx", 120.134704,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "mezzanine", "rows", 3L, "idx", 2871.8866900000003d,
              "array_agg_distinct", new String[] {"upfront", "spot", "total_market"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "news", "rows", 1L, "idx", 121.58358d,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "premium", "rows", 3L, "idx", 2900.798647d,
              "array_agg_distinct", new String[] {"upfront", "spot", "total_market"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "technology", "rows", 1L, "idx", 78.622547d,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "travel", "rows", 1L, "idx", 119.922742d,
              "array_agg_distinct", new String[] {"spot"}),
      // 2011-04-02
      makeRow(groupByQuery, "2011-04-02", "alias", "automotive", "rows", 1L, "idx", 147.42593d,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "business", "rows", 1L, "idx", 112.987027d,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "entertainment", "rows", 1L, "idx", 166.016049d,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "health", "rows", 1L, "idx", 113.446008d,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "mezzanine", "rows", 3L, "idx", 2448.830613d,
              "array_agg_distinct", new String[] {"upfront", "spot", "total_market"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "news", "rows", 1L, "idx", 114.290141d,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "premium", "rows", 3L, "idx", 2506.415148d,
              "array_agg_distinct", new String[] {"upfront", "spot", "total_market"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "technology", "rows", 1L, "idx", 97.387433d,
              "array_agg_distinct", new String[] {"spot"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "travel", "rows", 1L, "idx", 126.411364d,
              "array_agg_distinct", new String[] {"spot"})
  );

  final Iterable<ResultRow> actual = GroupByQueryRunnerTestHelper.runQuery(factory, runner, groupByQuery);
  TestHelper.assertExpectedObjects(expected, actual, "groupBy");
}
/**
 * Groups by {@code quality} at day granularity with a single {@code array_agg_distinct} expression aggregator
 * that folds the {@code placementish} column with {@code array_set_add}. Each expected row carries a
 * two-element string array. When the configured default strategy is v1, an {@link IllegalStateException}
 * about {@code STRING_ARRAY} is expected instead of results.
 */
@Test
public void testGroupByExpressionAggregatorArrayMultiValue()
{
  // expression agg not yet vectorized
  cannotVectorize();

  // array types don't work with group by v1
  if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) {
    expectedException.expect(IllegalStateException.class);
    expectedException.expectMessage("Unable to handle type[STRING_ARRAY] for AggregatorFactory[class org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory]");
  }

  final ExpressionLambdaAggregatorFactory distinctPlacements = new ExpressionLambdaAggregatorFactory(
      "array_agg_distinct",
      ImmutableSet.of(QueryRunnerTestHelper.PLACEMENTISH_DIMENSION),
      "acc",
      "[]",
      null,
      "array_set_add(acc, placementish)",
      "array_set_add_all(acc, array_agg_distinct)",
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );

  final GroupByQuery groupByQuery = makeQueryBuilder()
      .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
      .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
      .setDimensions(new DefaultDimensionSpec("quality", "alias"))
      .setAggregatorSpecs(distinctPlacements)
      .setGranularity(QueryRunnerTestHelper.DAY_GRAN)
      .build();

  final List<ResultRow> expected = Arrays.asList(
      // 2011-04-01
      makeRow(groupByQuery, "2011-04-01", "alias", "automotive", "array_agg_distinct", new String[] {"a", "preferred"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "business", "array_agg_distinct", new String[] {"b", "preferred"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "entertainment", "array_agg_distinct", new String[] {"e", "preferred"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "health", "array_agg_distinct", new String[] {"h", "preferred"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "mezzanine", "array_agg_distinct", new String[] {"m", "preferred"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "news", "array_agg_distinct", new String[] {"n", "preferred"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "premium", "array_agg_distinct", new String[] {"p", "preferred"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "technology", "array_agg_distinct", new String[] {"t", "preferred"}),
      makeRow(groupByQuery, "2011-04-01", "alias", "travel", "array_agg_distinct", new String[] {"t", "preferred"}),
      // 2011-04-02
      makeRow(groupByQuery, "2011-04-02", "alias", "automotive", "array_agg_distinct", new String[] {"a", "preferred"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "business", "array_agg_distinct", new String[] {"b", "preferred"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "entertainment", "array_agg_distinct", new String[] {"e", "preferred"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "health", "array_agg_distinct", new String[] {"h", "preferred"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "mezzanine", "array_agg_distinct", new String[] {"m", "preferred"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "news", "array_agg_distinct", new String[] {"n", "preferred"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "premium", "array_agg_distinct", new String[] {"p", "preferred"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "technology", "array_agg_distinct", new String[] {"t", "preferred"}),
      makeRow(groupByQuery, "2011-04-02", "alias", "travel", "array_agg_distinct", new String[] {"t", "preferred"})
  );

  final Iterable<ResultRow> actual = GroupByQueryRunnerTestHelper.runQuery(factory, runner, groupByQuery);
  TestHelper.assertExpectedObjects(expected, actual, "groupBy");
}
private static ResultRow makeRow(final GroupByQuery query, final String timestamp, final Object... vals)
{
return GroupByQueryRunnerTestHelper.createExpectedRow(query, timestamp, vals);

View File

@ -303,32 +303,4 @@ public class GroupByTimeseriesQueryRunnerTest extends TimeseriesQueryRunnerTest
// Skip this test because the timeseries test expects a day that doesn't have a filter match to be filled in,
// but group by just doesn't return a value if the filter doesn't match.
}
/**
 * Delegates to the inherited test only when {@code vectorize} is false; otherwise does nothing.
 */
@Override
public void testTimeseriesWithTimestampResultFieldContextForArrayResponse()
{
  if (vectorize) {
    // Cannot vectorize with an expression virtual column
    return;
  }
  super.testTimeseriesWithTimestampResultFieldContextForArrayResponse();
}
/**
 * Delegates to the inherited test only when {@code vectorize} is false; otherwise does nothing.
 */
@Override
public void testTimeseriesWithTimestampResultFieldContextForMapResponse()
{
  if (vectorize) {
    // Cannot vectorize with an expression virtual column
    return;
  }
  super.testTimeseriesWithTimestampResultFieldContextForMapResponse();
}
/**
 * Delegates to the inherited test only when {@code vectorize} is false; otherwise does nothing.
 */
@Override
@Test
public void testTimeseriesWithPostAggregatorReferencingTimestampResultField()
{
  if (vectorize) {
    // Cannot vectorize with an expression virtual column
    return;
  }
  super.testTimeseriesWithPostAggregatorReferencingTimestampResultField();
}
}

View File

@ -21,6 +21,7 @@ package org.apache.druid.query.timeseries;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.primitives.Doubles;
@ -44,6 +45,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory;
@ -2935,6 +2937,104 @@ public class TimeseriesQueryRunnerTest extends InitializedNullHandlingTest
assertExpectedResults(expectedResults, results);
}
/**
 * Runs a day-granularity timeseries query with four expression aggregators: a count
 * ({@code __acc + 1}), a sum of {@code index}, the same sum accumulated into a {@code <DOUBLE>[]} and
 * reduced with a {@code fold} finalizer, and a distinct-set of {@code market}. The two sums are expected
 * to be equal on each day.
 */
@Test
public void testTimeseriesWithExpressionAggregator()
{
  // expression agg cannot vectorize
  cannotVectorize();

  final ExpressionLambdaAggregatorFactory diyCount = new ExpressionLambdaAggregatorFactory(
      "diy_count",
      ImmutableSet.of(),
      null,
      "0",
      null,
      "__acc + 1",
      "__acc + diy_count",
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );
  final ExpressionLambdaAggregatorFactory diySum = new ExpressionLambdaAggregatorFactory(
      "diy_sum",
      ImmutableSet.of("index"),
      null,
      "0.0",
      null,
      "__acc + index",
      null,
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );
  final ExpressionLambdaAggregatorFactory diyDecomposedSum = new ExpressionLambdaAggregatorFactory(
      "diy_decomposed_sum",
      ImmutableSet.of("index"),
      null,
      "0.0",
      "<DOUBLE>[]",
      "__acc + index",
      "array_concat(__acc, diy_decomposed_sum)",
      null,
      "fold((x, acc) -> x + acc, o, 0.0)",
      null,
      TestExprMacroTable.INSTANCE
  );
  final ExpressionLambdaAggregatorFactory distinctMarkets = new ExpressionLambdaAggregatorFactory(
      "array_agg_distinct",
      ImmutableSet.of(QueryRunnerTestHelper.MARKET_DIMENSION),
      "acc",
      "[]",
      null,
      "array_set_add(acc, market)",
      "array_set_add_all(acc, array_agg_distinct)",
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );

  final TimeseriesQuery timeseriesQuery =
      Druids.newTimeseriesQueryBuilder()
            .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
            .granularity(QueryRunnerTestHelper.DAY_GRAN)
            .intervals(QueryRunnerTestHelper.FIRST_TO_THIRD)
            .aggregators(Arrays.asList(diyCount, diySum, diyDecomposedSum, distinctMarkets))
            .descending(descending)
            .context(makeContext())
            .build();

  final TimeseriesResultValue firstDay = new TimeseriesResultValue(
      ImmutableMap.of(
          "diy_count", 13L,
          "diy_sum", 6626.151569,
          "diy_decomposed_sum", 6626.151569,
          "array_agg_distinct", new String[] {"upfront", "spot", "total_market"}
      )
  );
  final TimeseriesResultValue secondDay = new TimeseriesResultValue(
      ImmutableMap.of(
          "diy_count", 13L,
          "diy_sum", 5833.209718,
          "diy_decomposed_sum", 5833.209718,
          "array_agg_distinct", new String[] {"upfront", "spot", "total_market"}
      )
  );
  final List<Result<TimeseriesResultValue>> expected = Arrays.asList(
      new Result<>(DateTimes.of("2011-04-01"), firstDay),
      new Result<>(DateTimes.of("2011-04-02"), secondDay)
  );

  final Iterable<Result<TimeseriesResultValue>> actual = runner.run(QueryPlus.wrap(timeseriesQuery)).toList();
  assertExpectedResults(expected, actual);
}
private Map<String, Object> makeContext()
{
return makeContext(ImmutableMap.of());

View File

@ -22,6 +22,7 @@ package org.apache.druid.query.topn;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
@ -52,6 +53,7 @@ import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMinAggregatorFactory;
@ -5964,6 +5966,110 @@ public class TopNQueryRunnerTest extends InitializedNullHandlingTest
assertExpectedResults(expectedResults, query);
}
/**
 * Runs a topN over {@code market} ordered by the {@code array_agg_distinct} metric, whose compare
 * expression orders values by {@code array_length}. Alongside it are a count, a sum of {@code index}, and
 * the same sum accumulated into a {@code <DOUBLE>[]} and reduced with a {@code fold} finalizer.
 */
@Test
public void testExpressionAggregator()
{
  final ExpressionLambdaAggregatorFactory diyCount = new ExpressionLambdaAggregatorFactory(
      "diy_count",
      Collections.emptySet(),
      null,
      "0",
      null,
      "__acc + 1",
      "__acc + diy_count",
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );
  final ExpressionLambdaAggregatorFactory diySum = new ExpressionLambdaAggregatorFactory(
      "diy_sum",
      ImmutableSet.of("index"),
      null,
      "0.0",
      null,
      "__acc + index",
      null,
      null,
      null,
      null,
      TestExprMacroTable.INSTANCE
  );
  final ExpressionLambdaAggregatorFactory diyDecomposedSum = new ExpressionLambdaAggregatorFactory(
      "diy_decomposed_sum",
      ImmutableSet.of("index"),
      null,
      "0.0",
      "<DOUBLE>[]",
      "__acc + index",
      "array_concat(__acc, diy_decomposed_sum)",
      null,
      "fold((x, acc) -> x + acc, o, 0.0)",
      null,
      TestExprMacroTable.INSTANCE
  );
  final ExpressionLambdaAggregatorFactory distinctQualities = new ExpressionLambdaAggregatorFactory(
      "array_agg_distinct",
      ImmutableSet.of(QueryRunnerTestHelper.QUALITY_DIMENSION),
      "acc",
      "[]",
      null,
      "array_set_add(acc, quality)",
      "array_set_add_all(acc, array_agg_distinct)",
      "if(array_length(o1) > array_length(o2), 1, if (array_length(o1) == array_length(o2), 0, -1))",
      null,
      null,
      TestExprMacroTable.INSTANCE
  );

  // sorted by array length of array_agg_distinct
  final TopNQuery topNQuery = new TopNQueryBuilder()
      .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
      .granularity(QueryRunnerTestHelper.ALL_GRAN)
      .dimension(QueryRunnerTestHelper.MARKET_DIMENSION)
      .metric("array_agg_distinct")
      .threshold(4)
      .intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
      .aggregators(Lists.newArrayList(Arrays.asList(diyCount, diySum, diyDecomposedSum, distinctQualities)))
      .build();

  final Map<String, Object> spotRow =
      ImmutableMap.<String, Object>builder()
                  .put(QueryRunnerTestHelper.MARKET_DIMENSION, "spot")
                  .put("diy_count", 837L)
                  .put("diy_sum", 95606.57232284546D)
                  .put("diy_decomposed_sum", 95606.57232284546D)
                  .put(
                      "array_agg_distinct",
                      new String[]{
                          "mezzanine", "news", "premium", "business", "entertainment",
                          "health", "technology", "automotive", "travel"
                      }
                  )
                  .build();
  final Map<String, Object> totalMarketRow =
      ImmutableMap.<String, Object>builder()
                  .put(QueryRunnerTestHelper.MARKET_DIMENSION, "total_market")
                  .put("diy_count", 186L)
                  .put("diy_sum", 215679.82879638672D)
                  .put("diy_decomposed_sum", 215679.82879638672D)
                  .put("array_agg_distinct", new String[]{"mezzanine", "premium"})
                  .build();
  final Map<String, Object> upfrontRow =
      ImmutableMap.<String, Object>builder()
                  .put(QueryRunnerTestHelper.MARKET_DIMENSION, "upfront")
                  .put("diy_count", 186L)
                  .put("diy_sum", 192046.1060180664D)
                  .put("diy_decomposed_sum", 192046.1060180664D)
                  .put("array_agg_distinct", new String[]{"mezzanine", "premium"})
                  .build();

  final List<Result<TopNResultValue>> expected = Collections.singletonList(
      new Result<>(
          DateTimes.of("2011-01-12T00:00:00.000Z"),
          new TopNResultValue(Arrays.<Map<String, Object>>asList(spotRow, totalMarketRow, upfrontRow))
      )
  );
  assertExpectedResults(expected, topNQuery);
}
private static Map<String, Object> makeRowWithNulls(
String dimName,
@Nullable Object dimValue,

View File

@ -32,6 +32,7 @@ import org.apache.druid.guice.GuiceAnnotationIntrospector;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.Result;
import org.apache.druid.query.expression.TestExprMacroTable;
@ -352,7 +353,9 @@ public class TestHelper
final Object expectedValue = expectedMap.get(key);
final Object actualValue = actualMap.get(key);
if (expectedValue instanceof Float || expectedValue instanceof Double) {
if (expectedValue != null && expectedValue.getClass().isArray()) {
Assert.assertArrayEquals((Object[]) expectedValue, (Object[]) actualValue);
} else if (expectedValue instanceof Float || expectedValue instanceof Double) {
Assert.assertEquals(
StringUtils.format("%s: key[%s]", msg, key),
((Number) expectedValue).doubleValue(),
@ -382,7 +385,23 @@ public class TestHelper
final Object expectedValue = expected.get(i);
final Object actualValue = actual.get(i);
if (expectedValue instanceof Float || expectedValue instanceof Double) {
if (expectedValue != null && expectedValue.getClass().isArray()) {
// spilled results will materialize into lists, coerce them back to arrays if we expected arrays
if (actualValue instanceof List) {
Assert.assertEquals(
message,
(Object[]) expectedValue,
(Object[]) ExprEval.coerceListToArray((List) actualValue, true)
);
} else {
Assert.assertArrayEquals(
message,
(Object[]) expectedValue,
(Object[]) actualValue
);
}
} else if (expectedValue instanceof Float || expectedValue instanceof Double) {
Assert.assertEquals(
message,
((Number) expectedValue).doubleValue(),

View File

@ -243,77 +243,6 @@ public class ExpressionSelectorsTest extends InitializedNullHandlingTest
}
/**
 * Exercises {@code ExpressionSelectors.coerceListToArray} with homogeneous lists, lists containing nulls,
 * and mixed-type lists. As the assertions encode: integral inputs come back as {@code Long[]}, any
 * floating-point element yields {@code Double[]}, and a list containing a string (or only nulls) yields
 * {@code String[]}.
 */
@Test
public void test_coerceListToArray()
{
  // homogeneous lists
  final List<Long> longs = ImmutableList.of(1L, 2L, 3L);
  Assert.assertArrayEquals(
      new Long[]{1L, 2L, 3L},
      (Long[]) ExpressionSelectors.coerceListToArray(longs)
  );

  final List<Integer> ints = ImmutableList.of(1, 2, 3);
  Assert.assertArrayEquals(
      new Long[]{1L, 2L, 3L},
      (Long[]) ExpressionSelectors.coerceListToArray(ints)
  );

  final List<Float> floats = ImmutableList.of(1.0f, 2.0f, 3.0f);
  Assert.assertArrayEquals(
      new Double[]{1.0, 2.0, 3.0},
      (Double[]) ExpressionSelectors.coerceListToArray(floats)
  );

  final List<Double> doubles = ImmutableList.of(1.0, 2.0, 3.0);
  Assert.assertArrayEquals(
      new Double[]{1.0, 2.0, 3.0},
      (Double[]) ExpressionSelectors.coerceListToArray(doubles)
  );

  final List<String> strings = ImmutableList.of("a", "b", "c");
  Assert.assertArrayEquals(
      new String[]{"a", "b", "c"},
      (String[]) ExpressionSelectors.coerceListToArray(strings)
  );

  // lists with null elements (ImmutableList rejects nulls, hence ArrayList)
  final List<String> stringsWithNull = new ArrayList<>();
  stringsWithNull.add("a");
  stringsWithNull.add(null);
  stringsWithNull.add("c");
  Assert.assertArrayEquals(
      new String[]{"a", null, "c"},
      (String[]) ExpressionSelectors.coerceListToArray(stringsWithNull)
  );

  final List<Long> longsWithNull = new ArrayList<>();
  longsWithNull.add(1L);
  longsWithNull.add(null);
  longsWithNull.add(3L);
  Assert.assertArrayEquals(
      new Long[]{1L, null, 3L},
      (Long[]) ExpressionSelectors.coerceListToArray(longsWithNull)
  );

  // mixed element types
  final List<Object> longsAndString = ImmutableList.of(1L, "b", 3L);
  Assert.assertArrayEquals(
      new String[]{"1", "b", "3"},
      (String[]) ExpressionSelectors.coerceListToArray(longsAndString)
  );

  final List<Number> intsAndLongs = ImmutableList.of(1, 2L, 3);
  Assert.assertArrayEquals(
      new Long[]{1L, 2L, 3L},
      (Long[]) ExpressionSelectors.coerceListToArray(intsAndLongs)
  );

  final List<Number> floatsAndLongs = ImmutableList.of(1, 2L, 3.0f);
  Assert.assertArrayEquals(
      new Double[]{1.0, 2.0, 3.0},
      (Double[]) ExpressionSelectors.coerceListToArray(floatsAndLongs)
  );

  final List<Number> doublesAndLongs = ImmutableList.of(1, 2L, 3.0);
  Assert.assertArrayEquals(
      new Double[]{1.0, 2.0, 3.0},
      (Double[]) ExpressionSelectors.coerceListToArray(doublesAndLongs)
  );

  final List<Number> floatsAndDoubles = ImmutableList.of(1L, 2.0f, 3.0);
  Assert.assertArrayEquals(
      new Double[]{1.0, 2.0, 3.0},
      (Double[]) ExpressionSelectors.coerceListToArray(floatsAndDoubles)
  );

  // nothing but nulls
  final List<String> onlyNulls = new ArrayList<>();
  onlyNulls.add(null);
  onlyNulls.add(null);
  onlyNulls.add(null);
  Assert.assertArrayEquals(
      new String[]{null, null, null},
      (String[]) ExpressionSelectors.coerceListToArray(onlyNulls)
  );
}
@Test
public void test_coerceEvalToSelectorObject()
{

View File

@ -1100,6 +1100,8 @@ arr1
arr2
array_append
array_concat
array_set_add
array_set_add_all
array_contains
array_length
array_offset