diff --git a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java index dd64629e269..5edb2fe7069 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java @@ -183,9 +183,9 @@ class NullLongExpr extends ConstantExpr class LongArrayExpr extends ConstantExpr { - LongArrayExpr(Long[] value) + LongArrayExpr(@Nullable Long[] value) { - super(ExprType.LONG_ARRAY, Preconditions.checkNotNull(value, "value")); + super(ExprType.LONG_ARRAY, value); } @Override @@ -320,9 +320,9 @@ class NullDoubleExpr extends ConstantExpr class DoubleArrayExpr extends ConstantExpr { - DoubleArrayExpr(Double[] value) + DoubleArrayExpr(@Nullable Double[] value) { - super(ExprType.DOUBLE_ARRAY, Preconditions.checkNotNull(value, "value")); + super(ExprType.DOUBLE_ARRAY, value); } @Override @@ -426,9 +426,9 @@ class StringExpr extends ConstantExpr class StringArrayExpr extends ConstantExpr { - StringArrayExpr(String[] value) + StringArrayExpr(@Nullable String[] value) { - super(ExprType.STRING_ARRAY, Preconditions.checkNotNull(value, "value")); + super(ExprType.STRING_ARRAY, value); } @Override diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 52e57309670..1a9a9628705 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -23,15 +23,357 @@ import com.google.common.primitives.Doubles; import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.guava.GuavaUtils; import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.UOE; import javax.annotation.Nullable; +import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.List; /** * Generic result holder for evaluated {@link Expr} containing the value and {@link ExprType} of the value to allow */ public abstract class ExprEval { + private static final int NULL_LENGTH = -1; + + /** + * Deserialize an expression stored in a bytebuffer, e.g. for an agg. + * + * This should be refactored to be consolidated with some of the standard type handling of aggregators probably + */ + public static ExprEval deserialize(ByteBuffer buffer, int position) + { + // | expression type (byte) | expression bytes | + ExprType type = ExprType.fromByte(buffer.get(position)); + int offset = position + 1; + switch (type) { + case LONG: + // | expression type (byte) | is null (byte) | long bytes | + if (buffer.get(offset++) == NullHandling.IS_NOT_NULL_BYTE) { + return of(buffer.getLong(offset)); + } + return ofLong(null); + case DOUBLE: + // | expression type (byte) | is null (byte) | double bytes | + if (buffer.get(offset++) == NullHandling.IS_NOT_NULL_BYTE) { + return of(buffer.getDouble(offset)); + } + return ofDouble(null); + case STRING: + // | expression type (byte) | string length (int) | string bytes | + final int length = buffer.getInt(offset); + if (length < 0) { + return of(null); + } + final byte[] stringBytes = new byte[length]; + final int oldPosition = buffer.position(); + buffer.position(offset + Integer.BYTES); + buffer.get(stringBytes, 0, length); + buffer.position(oldPosition); + return of(StringUtils.fromUtf8(stringBytes)); + case LONG_ARRAY: + // | expression type (byte) | array length (int) | array bytes | + final int longArrayLength = buffer.getInt(offset); + offset += Integer.BYTES; + if (longArrayLength < 0) { + return ofLongArray(null); + } + final Long[] longs = new Long[longArrayLength]; + for (int i = 0; i < longArrayLength; i++) { + final byte isNull = buffer.get(offset); + offset += Byte.BYTES; + if (isNull == NullHandling.IS_NOT_NULL_BYTE) { + // | is null (byte) | long bytes | + longs[i] = buffer.getLong(offset); + offset += Long.BYTES; + } else { + // | is null (byte) | + longs[i] = null; + } + } + return ofLongArray(longs); + case DOUBLE_ARRAY: + // | expression type (byte) | array length (int) | array bytes | + final int doubleArrayLength = buffer.getInt(offset); + offset += Integer.BYTES; + if (doubleArrayLength < 0) { + return ofDoubleArray(null); + } + final Double[] doubles = new Double[doubleArrayLength]; + for (int i = 0; i < doubleArrayLength; i++) { + final byte isNull = buffer.get(offset); + offset += Byte.BYTES; + if (isNull == NullHandling.IS_NOT_NULL_BYTE) { + // | is null (byte) | double bytes | + doubles[i] = buffer.getDouble(offset); + offset += Double.BYTES; + } else { + // | is null (byte) | + doubles[i] = null; + } + } + return ofDoubleArray(doubles); + case STRING_ARRAY: + // | expression type (byte) | array length (int) | array bytes | + final int stringArrayLength = buffer.getInt(offset); + offset += Integer.BYTES; + if (stringArrayLength < 0) { + return ofStringArray(null); + } + final String[] stringArray = new String[stringArrayLength]; + for (int i = 0; i < stringArrayLength; i++) { + final int stringElementLength = buffer.getInt(offset); + offset += Integer.BYTES; + if (stringElementLength < 0) { + // | string length (int) | + stringArray[i] = null; + } else { + // | string length (int) | string bytes | + final byte[] stringElementBytes = new byte[stringElementLength]; + final int oldPosition2 = buffer.position(); + buffer.position(offset); + buffer.get(stringElementBytes, 0, stringElementLength); + buffer.position(oldPosition2); + stringArray[i] = StringUtils.fromUtf8(stringElementBytes); + offset += stringElementLength; + } + } + return ofStringArray(stringArray); + default: + throw new UOE("how can this be?"); + } + } + + /** + * Write an expression result to a bytebuffer, throwing an {@link ISE} if the data exceeds a maximum size. Primitive + * numeric types are not validated to be lower than max size, so it is expected to be at least 10 bytes. Callers + * of this method should enforce this themselves (instead of doing it here, which might be done every row) + * + * This should be refactored to be consolidated with some of the standard type handling of aggregators probably + */ + public static void serialize(ByteBuffer buffer, int position, ExprEval eval, int maxSizeBytes) + { + int offset = position; + buffer.put(offset++, eval.type().getId()); + switch (eval.type()) { + case LONG: + if (eval.isNumericNull()) { + buffer.put(offset, NullHandling.IS_NULL_BYTE); + } else { + buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE); + buffer.putLong(offset, eval.asLong()); + } + break; + case DOUBLE: + if (eval.isNumericNull()) { + buffer.put(offset, NullHandling.IS_NULL_BYTE); + } else { + buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE); + buffer.putDouble(offset, eval.asDouble()); + } + break; + case STRING: + final byte[] stringBytes = StringUtils.toUtf8Nullable(eval.asString()); + if (stringBytes != null) { + // | expression type (byte) | string length (int) | string bytes | + checkMaxBytes(eval.type(), 1 + Integer.BYTES + stringBytes.length, maxSizeBytes); + buffer.putInt(offset, stringBytes.length); + offset += Integer.BYTES; + final int oldPosition = buffer.position(); + buffer.position(offset); + buffer.put(stringBytes, 0, stringBytes.length); + buffer.position(oldPosition); + } else { + checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes); + buffer.putInt(offset, NULL_LENGTH); + } + break; + case LONG_ARRAY: + Long[] longs = eval.asLongArray(); + if (longs == null) { + // | expression type (byte) | array length (int) | + checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes); + buffer.putInt(offset, NULL_LENGTH); + } else { + // | expression type (byte) | array length (int) | array bytes | + final int sizeBytes = 1 + Integer.BYTES + (Long.BYTES * longs.length); + checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes); + buffer.putInt(offset, longs.length); + offset += Integer.BYTES; + for (Long aLong : longs) { + if (aLong != null) { + buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE); + offset++; + buffer.putLong(offset, aLong); + offset += Long.BYTES; + } else { + buffer.put(offset++, NullHandling.IS_NULL_BYTE); + } + } + } + break; + case DOUBLE_ARRAY: + Double[] doubles = eval.asDoubleArray(); + if (doubles == null) { + // | expression type (byte) | array length (int) | + checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes); + buffer.putInt(offset, NULL_LENGTH); + } else { + // | expression type (byte) | array length (int) | array bytes | + final int sizeBytes = 1 + Integer.BYTES + (Double.BYTES * doubles.length); + checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes); + buffer.putInt(offset, doubles.length); + offset += Integer.BYTES; + + for (Double aDouble : doubles) { + if (aDouble != null) { + buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE); + offset++; + buffer.putDouble(offset, aDouble); + offset += Long.BYTES; + } else { + buffer.put(offset++, NullHandling.IS_NULL_BYTE); + } + } + } + break; + case STRING_ARRAY: + String[] strings = eval.asStringArray(); + if (strings == null) { + // | expression type (byte) | array length (int) | + checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes); + buffer.putInt(offset, NULL_LENGTH); + } else { + // | expression type (byte) | array length (int) | array bytes | + buffer.putInt(offset, strings.length); + offset += Integer.BYTES; + int sizeBytes = 1 + Integer.BYTES; + for (String string : strings) { + if (string == null) { + // | string length (int) | + sizeBytes += Integer.BYTES; + checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes); + buffer.putInt(offset, NULL_LENGTH); + offset += Integer.BYTES; + } else { + // | string length (int) | string bytes | + final byte[] stringElementBytes = StringUtils.toUtf8(string); + sizeBytes += Integer.BYTES + stringElementBytes.length; + checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes); + buffer.putInt(offset, stringElementBytes.length); + offset += Integer.BYTES; + final int oldPosition = buffer.position(); + buffer.position(offset); + buffer.put(stringElementBytes, 0, stringElementBytes.length); + buffer.position(oldPosition); + offset += stringElementBytes.length; + } + } + } + break; + default: + throw new UOE("how can this be?"); + } + } + + private static void checkMaxBytes(ExprType type, int sizeBytes, int maxSizeBytes) + { + if (sizeBytes > maxSizeBytes) { + throw new ISE("Unable to serialize [%s], size [%s] is larger than max [%s]", type, sizeBytes, maxSizeBytes); + } + } + + /** + * Converts a List to an appropriate array type, optionally doing some conversion to make multi-valued strings + * consistent across selector types, which are not consistent in treatment of null, [], and [null]. + * + * If homogenizeMultiValueStrings is true, null and [] will be converted to [null], otherwise they will retain + */ + @Nullable + public static Object coerceListToArray(@Nullable List val, boolean homogenizeMultiValueStrings) + { + // if value is not null and has at least 1 element, conversion is unambigous regardless of the selector + if (val != null && val.size() > 0) { + Class coercedType = null; + + for (Object elem : val) { + if (elem != null) { + coercedType = convertType(coercedType, elem.getClass()); + } + } + + if (coercedType == Long.class || coercedType == Integer.class) { + return val.stream().map(x -> x != null ? ((Number) x).longValue() : null).toArray(Long[]::new); + } + if (coercedType == Float.class || coercedType == Double.class) { + return val.stream().map(x -> x != null ? ((Number) x).doubleValue() : null).toArray(Double[]::new); + } + // default to string + return val.stream().map(x -> x != null ? x.toString() : null).toArray(String[]::new); + } + if (homogenizeMultiValueStrings) { + return new String[]{null}; + } else { + if (val != null) { + return val.toArray(); + } + return null; + } + } + + /** + * Find the common type to use between 2 types, useful for choosing the appropriate type for an array given a set + * of objects with unknown type, following rules similar to Java, our own native Expr, and SQL implicit type + * conversions. This is used to assist in preparing native java objects for {@link Expr.ObjectBinding} which will + * later be wrapped in {@link ExprEval} when evaluating {@link IdentifierExpr}. + * + * If any type is string, then the result will be string because everything can be converted to a string, but a string + * cannot be converted to everything. + * + * For numbers, integer is the most restrictive type, only chosen if both types are integers. Longs win over integers, + * floats over longs and integers, and doubles win over everything. + */ + private static Class convertType(@Nullable Class existing, Class next) + { + if (Number.class.isAssignableFrom(next) || next == String.class) { + if (existing == null) { + return next; + } + // string wins everything + if (existing == String.class) { + return existing; + } + if (next == String.class) { + return next; + } + // all numbers win over Integer + if (existing == Integer.class) { + return next; + } + if (existing == Float.class) { + // doubles win over floats + if (next == Double.class) { + return next; + } + return existing; + } + if (existing == Long.class) { + if (next == Integer.class) { + // long beats int + return existing; + } + // double and float win over longs + return next; + } + // otherwise double + return Double.class; + } + throw new UOE("Invalid array expression type: %s", next); + } + public static ExprEval ofLong(@Nullable Number longValue) { return new LongExprEval(longValue); @@ -118,6 +460,13 @@ public abstract class ExprEval return new StringArrayExprEval((String[]) val); } + if (val instanceof List) { + // do not convert empty lists to arrays with a single null element here, because that should have been done + // by the selectors preparing their ObjectBindings if necessary. If we get to this point it was legitimately + // empty + return bestEffortOf(coerceListToArray((List) val, false)); + } + return new StringExprEval(val == null ? null : String.valueOf(val)); } diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index eaacf5612af..80b5e0365b4 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ -19,6 +19,8 @@ package org.apache.druid.math.expr; +import it.unimi.dsi.fastutil.bytes.Byte2ObjectArrayMap; +import it.unimi.dsi.fastutil.bytes.Byte2ObjectMap; import org.apache.druid.java.util.common.ISE; import org.apache.druid.segment.column.ValueType; @@ -29,13 +31,32 @@ import javax.annotation.Nullable; */ public enum ExprType { - DOUBLE, - LONG, - STRING, - DOUBLE_ARRAY, - LONG_ARRAY, - STRING_ARRAY; + DOUBLE((byte) 0x01), + LONG((byte) 0x02), + STRING((byte) 0x03), + DOUBLE_ARRAY((byte) 0x04), + LONG_ARRAY((byte) 0x05), + STRING_ARRAY((byte) 0x06); + private static final Byte2ObjectMap TYPE_BYTES = new Byte2ObjectArrayMap<>(ExprType.values().length); + + static { + for (ExprType type : ExprType.values()) { + TYPE_BYTES.put(type.getId(), type); + } + } + + final byte id; + + ExprType(byte id) + { + this.id = id; + } + + public byte getId() + { + return id; + } public boolean isNumeric() { @@ -47,6 +68,11 @@ public enum ExprType return isScalar(this); } + public static ExprType fromByte(byte id) + { + return TYPE_BYTES.get(id); + } + /** * The expression system does not distinguish between {@link ValueType#FLOAT} and {@link ValueType#DOUBLE}, and * cannot currently handle {@link ValueType#COMPLEX} inputs. This method will convert {@link ValueType#FLOAT} to @@ -177,5 +203,4 @@ public enum ExprType } return elementType; } - } diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index baa5768f0b7..7b9fe5278d3 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -42,6 +42,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -384,13 +385,13 @@ public interface Function @Override public Set getScalarInputs(List args) { - return ImmutableSet.of(args.get(1)); + return ImmutableSet.of(getScalarArgument(args)); } @Override public Set getArrayInputs(List args) { - return ImmutableSet.of(args.get(0)); + return ImmutableSet.of(getArrayArgument(args)); } @Override @@ -402,14 +403,24 @@ public interface Function @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - final ExprEval arrayExpr = args.get(0).eval(bindings); - final ExprEval scalarExpr = args.get(1).eval(bindings); + final ExprEval arrayExpr = getArrayArgument(args).eval(bindings); + final ExprEval scalarExpr = getScalarArgument(args).eval(bindings); if (arrayExpr.asArray() == null) { return ExprEval.of(null); } return doApply(arrayExpr, scalarExpr); } + Expr getScalarArgument(List args) + { + return args.get(1); + } + + Expr getArrayArgument(List args) + { + return args.get(0); + } + abstract ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr); } @@ -450,8 +461,11 @@ public interface Function final ExprEval arrayExpr1 = args.get(0).eval(bindings); final ExprEval arrayExpr2 = args.get(1).eval(bindings); - if (arrayExpr1.asArray() == null || arrayExpr2.asArray() == null) { - return ExprEval.of(null); + if (arrayExpr1.asArray() == null) { + return arrayExpr1; + } + if (arrayExpr2.asArray() == null) { + return arrayExpr2; } return doApply(arrayExpr1, arrayExpr2); @@ -460,6 +474,118 @@ public interface Function abstract ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr); } + /** + * Scaffolding for a 2 argument {@link Function} which accepts one array and one scalar input and adds the scalar + * input to the array in some way. + */ + abstract class ArrayAddElementFunction extends ArrayScalarFunction + { + @Override + public boolean hasArrayOutput() + { + return true; + } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingInspector inspector, List args) + { + ExprType arrayType = getArrayArgument(args).getOutputType(inspector); + return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); + } + + @Override + ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) + { + switch (arrayExpr.type()) { + case STRING: + case STRING_ARRAY: + return ExprEval.ofStringArray(add(arrayExpr.asStringArray(), scalarExpr.asString()).toArray(String[]::new)); + case LONG: + case LONG_ARRAY: + return ExprEval.ofLongArray( + add( + arrayExpr.asLongArray(), + scalarExpr.isNumericNull() ? null : scalarExpr.asLong() + ).toArray(Long[]::new) + ); + case DOUBLE: + case DOUBLE_ARRAY: + return ExprEval.ofDoubleArray( + add( + arrayExpr.asDoubleArray(), + scalarExpr.isNumericNull() ? null : scalarExpr.asDouble() + ).toArray(Double[]::new) + ); + } + + throw new RE("Unable to add to unknown array type %s", arrayExpr.type()); + } + + abstract Stream add(T[] array, @Nullable T val); + } + + /** + * Base scaffolding for functions which accept 2 array arguments and combine them in some way + */ + abstract class ArraysMergeFunction extends ArraysFunction + { + @Override + public Set getArrayInputs(List args) + { + return ImmutableSet.copyOf(args); + } + + @Override + public boolean hasArrayOutput() + { + return true; + } + + @Nullable + @Override + public ExprType getOutputType(Expr.InputBindingInspector inspector, List args) + { + ExprType arrayType = args.get(0).getOutputType(inspector); + return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); + } + + @Override + ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) + { + final Object[] array1 = lhsExpr.asArray(); + final Object[] array2 = rhsExpr.asArray(); + + if (array1 == null) { + return ExprEval.of(null); + } + if (array2 == null) { + return lhsExpr; + } + + switch (lhsExpr.type()) { + case STRING: + case STRING_ARRAY: + return ExprEval.ofStringArray( + merge(lhsExpr.asStringArray(), rhsExpr.asStringArray()).toArray(String[]::new) + ); + case LONG: + case LONG_ARRAY: + return ExprEval.ofLongArray( + merge(lhsExpr.asLongArray(), rhsExpr.asLongArray()).toArray(Long[]::new) + ); + case DOUBLE: + case DOUBLE_ARRAY: + return ExprEval.ofDoubleArray( + merge(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray()).toArray(Double[]::new) + ); + } + throw new RE("Unable to concatenate to unknown type %s", lhsExpr.type()); + } + + abstract Stream merge(T[] array1, T[] array2); + } + abstract class ReduceFunction implements Function { private final DoubleBinaryOperator doubleReducer; @@ -3168,7 +3294,7 @@ public interface Function } } - class ArrayAppendFunction extends ArrayScalarFunction + class ArrayAppendFunction extends ArrayAddElementFunction { @Override public String name() @@ -3177,48 +3303,7 @@ public interface Function } @Override - public boolean hasArrayOutput() - { - return true; - } - - @Nullable - @Override - public ExprType getOutputType(Expr.InputBindingInspector inspector, List args) - { - ExprType arrayType = args.get(0).getOutputType(inspector); - return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); - } - - @Override - ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) - { - switch (arrayExpr.type()) { - case STRING: - case STRING_ARRAY: - return ExprEval.ofStringArray(this.append(arrayExpr.asStringArray(), scalarExpr.asString()).toArray(String[]::new)); - case LONG: - case LONG_ARRAY: - return ExprEval.ofLongArray( - this.append( - arrayExpr.asLongArray(), - scalarExpr.isNumericNull() ? null : scalarExpr.asLong()).toArray(Long[]::new - ) - ); - case DOUBLE: - case DOUBLE_ARRAY: - return ExprEval.ofDoubleArray( - this.append( - arrayExpr.asDoubleArray(), - scalarExpr.isNumericNull() ? null : scalarExpr.asDouble()).toArray(Double[]::new - ) - ); - } - - throw new RE("Unable to append to unknown type %s", arrayExpr.type()); - } - - private Stream append(T[] array, T val) + Stream add(T[] array, @Nullable T val) { List l = new ArrayList<>(Arrays.asList(array)); l.add(val); @@ -3226,7 +3311,36 @@ public interface Function } } - class ArrayConcatFunction extends ArraysFunction + class ArrayPrependFunction extends ArrayAddElementFunction + { + @Override + public String name() + { + return "array_prepend"; + } + + @Override + Expr getScalarArgument(List args) + { + return args.get(0); + } + + @Override + Expr getArrayArgument(List args) + { + return args.get(1); + } + + @Override + Stream add(T[] array, @Nullable T val) + { + List l = new ArrayList<>(Arrays.asList(array)); + l.add(0, val); + return l.stream(); + } + } + + class ArrayConcatFunction extends ArraysMergeFunction { @Override public String name() @@ -3235,59 +3349,7 @@ public interface Function } @Override - public Set getArrayInputs(List args) - { - return ImmutableSet.copyOf(args); - } - - @Override - public boolean hasArrayOutput() - { - return true; - } - - @Nullable - @Override - public ExprType getOutputType(Expr.InputBindingInspector inspector, List args) - { - ExprType arrayType = args.get(0).getOutputType(inspector); - return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); - } - - @Override - ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) - { - final Object[] array1 = lhsExpr.asArray(); - final Object[] array2 = rhsExpr.asArray(); - - if (array1 == null) { - return ExprEval.of(null); - } - if (array2 == null) { - return lhsExpr; - } - - switch (lhsExpr.type()) { - case STRING: - case STRING_ARRAY: - return ExprEval.ofStringArray( - cat(lhsExpr.asStringArray(), rhsExpr.asStringArray()).toArray(String[]::new) - ); - case LONG: - case LONG_ARRAY: - return ExprEval.ofLongArray( - cat(lhsExpr.asLongArray(), rhsExpr.asLongArray()).toArray(Long[]::new) - ); - case DOUBLE: - case DOUBLE_ARRAY: - return ExprEval.ofDoubleArray( - cat(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray()).toArray(Double[]::new) - ); - } - throw new RE("Unable to concatenate to unknown type %s", lhsExpr.type()); - } - - private Stream cat(T[] array1, T[] array2) + Stream merge(T[] array1, T[] array2) { List l = new ArrayList<>(Arrays.asList(array1)); l.addAll(Arrays.asList(array2)); @@ -3295,6 +3357,40 @@ public interface Function } } + class ArraySetAddFunction extends ArrayAddElementFunction + { + @Override + public String name() + { + return "array_set_add"; + } + + @Override + Stream add(T[] array, @Nullable T val) + { + Set l = new HashSet<>(Arrays.asList(array)); + l.add(val); + return l.stream(); + } + } + + class ArraySetAddAllFunction extends ArraysMergeFunction + { + @Override + public String name() + { + return "array_set_add_all"; + } + + @Override + Stream merge(T[] array1, T[] array2) + { + Set l = new HashSet<>(Arrays.asList(array1)); + l.addAll(Arrays.asList(array2)); + return l.stream(); + } + } + class ArrayContainsFunction extends ArraysFunction { @Override @@ -3438,93 +3534,4 @@ public interface Function throw new RE("Unable to slice to unknown type %s", expr.type()); } } - - class ArrayPrependFunction implements Function - { - @Override - public String name() - { - return "array_prepend"; - } - - @Override - public void validateArguments(List args) - { - if (args.size() != 2) { - throw new IAE("Function[%s] needs 2 arguments", name()); - } - } - - @Nullable - @Override - public ExprType getOutputType(Expr.InputBindingInspector inspector, List args) - { - ExprType arrayType = args.get(1).getOutputType(inspector); - return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); - } - - @Override - public Set getScalarInputs(List args) - { - return ImmutableSet.of(args.get(0)); - } - - @Override - public Set getArrayInputs(List args) - { - return ImmutableSet.of(args.get(1)); - } - - @Override - public boolean hasArrayInputs() - { - return true; - } - - @Override - public boolean hasArrayOutput() - { - return true; - } - - @Override - public ExprEval apply(List args, Expr.ObjectBinding bindings) - { - final ExprEval scalarExpr = args.get(0).eval(bindings); - final ExprEval arrayExpr = args.get(1).eval(bindings); - if (arrayExpr.asArray() == null) { - return ExprEval.of(null); - } - switch (arrayExpr.type()) { - case STRING: - case STRING_ARRAY: - return ExprEval.ofStringArray(this.prepend(scalarExpr.asString(), arrayExpr.asStringArray()).toArray(String[]::new)); - case LONG: - case LONG_ARRAY: - return ExprEval.ofLongArray( - this.prepend( - scalarExpr.isNumericNull() ? null : scalarExpr.asLong(), - arrayExpr.asLongArray()).toArray(Long[]::new - ) - ); - case DOUBLE: - case DOUBLE_ARRAY: - return ExprEval.ofDoubleArray( - this.prepend( - scalarExpr.isNumericNull() ? null : scalarExpr.asDouble(), - arrayExpr.asDoubleArray()).toArray(Double[]::new - ) - ); - } - - throw new RE("Unable to prepend to unknown type %s", arrayExpr.type()); - } - - private Stream prepend(T val, T[] array) - { - List l = new ArrayList<>(Arrays.asList(array)); - l.add(0, val); - return l.stream(); - } - } } diff --git a/core/src/main/java/org/apache/druid/math/expr/Parser.java b/core/src/main/java/org/apache/druid/math/expr/Parser.java index c9388bff17f..b0c923c025f 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Parser.java +++ b/core/src/main/java/org/apache/druid/math/expr/Parser.java @@ -22,6 +22,7 @@ package org.apache.druid.math.expr; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; @@ -35,12 +36,14 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.math.expr.antlr.ExprLexer; import org.apache.druid.math.expr.antlr.ExprParser; +import javax.annotation.Nullable; import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.UnaryOperator; import java.util.stream.Collectors; public class Parser @@ -96,17 +99,33 @@ public class Parser } /** - * Parse a string into a flattened {@link Expr}. There is some overhead to this, and these objects are all immutable, - * so re-use instead of re-creating whenever possible. + * Create a memoized lazy supplier to parse a string into a flattened {@link Expr}. There is some overhead to this, + * and these objects are all immutable, so this assists in the goal of re-using instead of re-creating whenever + * possible. + * + * Lazy form of {@link #parse(String, ExprMacroTable)} + * + * @param in expression to parse + * @param macroTable additional extensions to expression language + */ + public static Supplier lazyParse(@Nullable String in, ExprMacroTable macroTable) + { + return Suppliers.memoize(() -> in == null ? null : Parser.parse(in, macroTable)); + } + + /** + * Parse a string into a flattened {@link Expr}. There is some overhead to this, and these objects are all immutable, + * so re-use instead of re-creating whenever possible. + * * @param in expression to parse * @param macroTable additional extensions to expression language - * @return */ public static Expr parse(String in, ExprMacroTable macroTable) { return parse(in, macroTable, true); } + @VisibleForTesting public static Expr parse(String in, ExprMacroTable macroTable, boolean withFlatten) { @@ -164,47 +183,43 @@ public class Parser /** * Applies a transformation to an {@link Expr} given a list of known (or uknown) multi-value input columns that are * used in a scalar manner, walking the {@link Expr} tree and lifting array variables into the {@link LambdaExpr} of - * {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr} - * @param expr expression to visit and rewrite - * @param bindingsToApply - * @return + * {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr} as necessary. + * + * This function applies a transformation for "map" style uses, such as column selectors, where the supplied + * expression will be transformed to return an array of results instead of the scalar result (or appropriately + * rewritten into existing apply expressions to produce correct results when referenced from a scalar context). + * + * This function and {@link #foldUnappliedBindings(Expr, Expr.BindingAnalysis, List, String)} exist to handle + * "multi-valued" string dimensions, which exist in a superposition of both single and multi-valued during realtime + * ingestion, until they are written to a segment and become locked into either single or multi-valued. This also + * means that multi-valued-ness can vary for a column from segment to segment, so this family of transformation + * functions exist so that multi-valued strings can be expressed in either and array or scalar context, which is + * important because the writer of the query might not actually know if the column is definitively always single or + * multi-valued (and it might in fact not be). + * + * @see #foldUnappliedBindings(Expr, Expr.BindingAnalysis, List, String) */ - public static Expr applyUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List bindingsToApply) + public static Expr applyUnappliedBindings( + Expr expr, + Expr.BindingAnalysis bindingAnalysis, + List bindingsToApply + ) { if (bindingsToApply.isEmpty()) { // nothing to do, expression is fine as is return expr; } // filter the list of bindings to those which are used in this expression - List unappliedBindingsInExpression = bindingsToApply.stream() - .filter(x -> bindingAnalysis.getRequiredBindings().contains(x)) - .collect(Collectors.toList()); + List unappliedBindingsInExpression = + bindingsToApply.stream() + .filter(x -> bindingAnalysis.getRequiredBindings().contains(x)) + .collect(Collectors.toList()); // any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten - Expr newExpr = expr.visit( - childExpr -> { - if (childExpr instanceof ApplyFunctionExpr) { - // try to lift unapplied arguments into the apply function lambda - return liftApplyLambda((ApplyFunctionExpr) childExpr, unappliedBindingsInExpression); - } else if (childExpr instanceof FunctionExpr) { - // check array function arguments for unapplied identifiers to transform if necessary - FunctionExpr fnExpr = (FunctionExpr) childExpr; - Set arrayInputs = fnExpr.function.getArrayInputs(fnExpr.args); - List newArgs = new ArrayList<>(); - for (Expr arg : fnExpr.args) { - if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) { - Expr newArg = applyUnappliedBindings(arg, bindingAnalysis, unappliedBindingsInExpression); - newArgs.add(newArg); - } else { - newArgs.add(arg); - } - } - - FunctionExpr newFnExpr = new FunctionExpr(fnExpr.function, fnExpr.function.name(), newArgs); - return newFnExpr; - } - return childExpr; - } + Expr newExpr = rewriteUnappliedSubExpressions( + expr, + unappliedBindingsInExpression, + (arg) -> applyUnappliedBindings(arg, bindingAnalysis, bindingsToApply) ); Expr.BindingAnalysis newExprBindings = newExpr.analyzeInputs(); @@ -221,9 +236,123 @@ public class Parser return applyUnapplied(newExpr, remainingUnappliedBindings); } + + /** + * Applies a transformation to an {@link Expr} given a list of known (or uknown) multi-value input columns that are + * used in a scalar manner, walking the {@link Expr} tree and lifting array variables into the {@link LambdaExpr} of + * {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr} as necessary. + * + * This function applies a transformation for "fold" style uses, such as aggregators, where the supplied + * expression will be transformed to accumulate the result of applying the expression to each value of the unapplied + * input (or appropriately rewritten into existing apply expressions to produce correct results when referenced from + * a scalar context). This rewriting assumes that there exists some accumulator variable, which is re-used as the + * accumulator for this fold rewrite, so that evaluating each expression can be accumulated into the larger external + * fold operation that an aggregator might be performing. + * + * This function and {@link #applyUnappliedBindings(Expr, Expr.BindingAnalysis, List)} exist to handle + * "multi-valued" string dimensions, which exist in a superposition of both single and multi-valued during realtime + * ingestion, until they are written to a segment and become locked into either single or multi-valued. This also + * means that multi-valued-ness can vary for a column from segment to segment, so this family of transformation + * functions exist so that multi-valued strings can be expressed in either and array or scalar context, which is + * important because the writer of the query might not actually know if the column is definitively always single or + * multi-valued (and it might in fact not be). + * + * @see #applyUnappliedBindings(Expr, Expr.BindingAnalysis, List) + */ + public static Expr foldUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List bindingsToApply, String accumulatorId) + { + if (bindingsToApply.isEmpty()) { + // nothing to do, expression is fine as is + return expr; + } + + // filter the list of bindings to those which are used in this expression + List unappliedBindingsInExpression = + bindingsToApply.stream() + .filter(x -> bindingAnalysis.getRequiredBindings().contains(x)) + .collect(Collectors.toList()); + + Expr newExpr = rewriteUnappliedSubExpressions( + expr, + unappliedBindingsInExpression, + (arg) -> foldUnappliedBindings(arg, bindingAnalysis, bindingsToApply, accumulatorId) + ); + + Expr.BindingAnalysis newExprBindings = newExpr.analyzeInputs(); + final Set expectedArrays = newExprBindings.getArrayVariables(); + + List remainingUnappliedBindings = + unappliedBindingsInExpression.stream().filter(x -> !expectedArrays.contains(x)).collect(Collectors.toList()); + + // if lifting the lambdas got rid of all missing bindings, return the transformed expression + if (remainingUnappliedBindings.isEmpty()) { + return newExpr; + } + + return foldUnapplied(newExpr, remainingUnappliedBindings, accumulatorId); + } + + /** + * Any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten to "lift" + * the identifier variables and transform the function. + * + * For example: + * if "y" is unapplied: + * map((x) -> x + y, x) => cartesian_map((x,y) -> x + y, x, y) + * + * @see #liftApplyLambda(ApplyFunctionExpr, List) + * + * Array functions on expressions using unapplied identifiers might also need transformed, so we recursively call the + * unapplied binding transformation function (supplied to this method) on that expression to ensure proper + * transformation and rewrite of these array expressions. + * + * For example: + * if "y" is unapplied: + * array_length(filter((x) -> x > y, x)) + */ + private static Expr rewriteUnappliedSubExpressions( + Expr expr, + List unappliedBindingsInExpression, + UnaryOperator applyUnappliedFn + ) + { + // any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten + return expr.visit( + childExpr -> { + if (childExpr instanceof ApplyFunctionExpr) { + // try to lift unapplied arguments into the apply function lambda + return liftApplyLambda((ApplyFunctionExpr) childExpr, unappliedBindingsInExpression); + } else if (childExpr instanceof FunctionExpr) { + // check array function arguments for unapplied identifiers to transform if necessary + FunctionExpr fnExpr = (FunctionExpr) childExpr; + Set arrayInputs = fnExpr.function.getArrayInputs(fnExpr.args); + List newArgs = new ArrayList<>(); + for (Expr arg : fnExpr.args) { + if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) { + Expr newArg = applyUnappliedFn.apply(arg); + newArgs.add(newArg); + } else { + newArgs.add(arg); + } + } + + FunctionExpr newFnExpr = new FunctionExpr(fnExpr.function, fnExpr.function.name(), newArgs); + return newFnExpr; + } + return childExpr; + } + ); + } + /** * translate an {@link Expr} into an {@link ApplyFunctionExpr} for {@link ApplyFunction.MapFunction} or - * {@link ApplyFunction.CartesianMapFunction} if there are multiple unbound arguments to be applied + * {@link ApplyFunction.CartesianMapFunction} if there are multiple unbound arguments to be applied. + * + * For example: + * if "x" is unapplied: + * x + y => map((x) -> x + y, x) + * if "x" and "y" are unapplied: + * x + y => cartesian_map((x, y) -> x + y, x, y) */ private static Expr applyUnapplied(Expr expr, List unappliedBindings) { @@ -274,6 +403,72 @@ public class Parser return magic; } + /** + * translate an {@link Expr} into an {@link ApplyFunctionExpr} for {@link ApplyFunction.FoldFunction} or + * {@link ApplyFunction.CartesianFoldFunction} if there are multiple unbound arguments to be applied. + * + * This assumes a known {@link IdentifierExpr} is an "accumulator", which is re-used as the accumulator variable and + * input for the translated fold. + * + * For example given an accumulator "__acc": + * if "x" is unapplied: + * __acc + x => fold((x, __acc) -> x + __acc, x, __acc) + * if "x" and "y" are unapplied: + * __acc + x + y => cartesian_fold((x, y, __acc) -> __acc + x + y, x, y, __acc) + * + */ + private static Expr foldUnapplied(Expr expr, List unappliedBindings, String accumulatorId) + { + + // filter to get list of IdentifierExpr that are backed by the unapplied bindings + final List args = expr.analyzeInputs() + .getFreeVariables() + .stream() + .filter(x -> unappliedBindings.contains(x.getBinding())) + .collect(Collectors.toList()); + + final List lambdaArgs = new ArrayList<>(); + + // construct lambda args from list of args to apply. Identifiers in a lambda body have artificial 'binding' values + // that is the same as the 'identifier', because the bindings are supplied by the wrapping apply function + // replacements are done by binding rather than identifier because repeats of the same input should not result + // in a cartesian product + final Map toReplace = new HashMap<>(); + for (IdentifierExpr applyFnArg : args) { + if (!toReplace.containsKey(applyFnArg.getBinding())) { + IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getBinding()); + lambdaArgs.add(lambdaRewrite); + toReplace.put(applyFnArg.getBinding(), lambdaRewrite); + } + } + + lambdaArgs.add(new IdentifierExpr(accumulatorId)); + + // rewrite identifiers in the expression which will become the lambda body, so they match the lambda identifiers we + // are constructing + Expr newExpr = expr.visit(childExpr -> { + if (childExpr instanceof IdentifierExpr) { + if (toReplace.containsKey(((IdentifierExpr) childExpr).getBinding())) { + return toReplace.get(((IdentifierExpr) childExpr).getBinding()); + } + } + return childExpr; + }); + + + // wrap an expression in either fold or cartesian_fold to apply any unapplied identifiers + final LambdaExpr lambdaExpr = new LambdaExpr(lambdaArgs, newExpr); + final ApplyFunction fn; + if (lambdaArgs.size() == 2) { + fn = new ApplyFunction.FoldFunction(); + } else { + fn = new ApplyFunction.CartesianFoldFunction(); + } + + final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(lambdaArgs)); + return magic; + } + /** * Performs partial lifting of free identifiers of the lambda expression of an {@link ApplyFunctionExpr}, constrained * by a list of "unapplied" identifiers, and translating them into arguments of a new {@link LambdaExpr} and diff --git a/core/src/main/java/org/apache/druid/math/expr/SettableObjectBinding.java b/core/src/main/java/org/apache/druid/math/expr/SettableObjectBinding.java new file mode 100644 index 00000000000..8b414e538e2 --- /dev/null +++ b/core/src/main/java/org/apache/druid/math/expr/SettableObjectBinding.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.math.expr; + +import com.google.common.collect.Maps; + +import javax.annotation.Nullable; +import java.util.HashMap; +import java.util.Map; + +/** + * Simple map backed object binding + */ +public class SettableObjectBinding implements Expr.ObjectBinding +{ + private final Map bindings; + + public SettableObjectBinding() + { + this.bindings = new HashMap<>(); + } + + public SettableObjectBinding(int expectedSize) + { + this.bindings = Maps.newHashMapWithExpectedSize(expectedSize); + } + + @Nullable + @Override + public Object get(String name) + { + return bindings.get(name); + } + + public SettableObjectBinding withBinding(String name, @Nullable Object value) + { + bindings.put(name, value); + return this; + } +} diff --git a/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java b/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java new file mode 100644 index 00000000000..b15f321f328 --- /dev/null +++ b/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.math.expr; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +public class ExprEvalTest extends InitializedNullHandlingTest +{ + private static int MAX_SIZE_BYTES = 1 << 13; + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + ByteBuffer buffer = ByteBuffer.allocate(1 << 16); + + @Test + public void testStringSerde() + { + assertExpr(0, "hello"); + assertExpr(1234, "hello"); + assertExpr(0, ExprEval.bestEffortOf(null)); + } + + @Test + public void testStringSerdeTooBig() + { + expectedException.expect(ISE.class); + expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.STRING, 16, 10)); + assertExpr(0, ExprEval.of("hello world"), 10); + } + + + @Test + public void testLongSerde() + { + assertExpr(0, 1L); + assertExpr(1234, 1L); + assertExpr(1234, ExprEval.ofLong(null)); + } + + @Test + public void testDoubleSerde() + { + assertExpr(0, 1.123); + assertExpr(1234, 1.123); + assertExpr(1234, ExprEval.ofDouble(null)); + } + + @Test + public void testStringArraySerde() + { + assertExpr(0, new String[] {"hello", "hi", "hey"}); + assertExpr(1024, new String[] {"hello", null, "hi", "hey"}); + assertExpr(2048, new String[] {}); + } + + @Test + public void testStringArraySerdeToBig() + { + expectedException.expect(ISE.class); + expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.STRING_ARRAY, 14, 10)); + assertExpr(0, ExprEval.ofStringArray(new String[] {"hello", "hi", "hey"}), 10); + } + + @Test + public void testLongArraySerde() + { + assertExpr(0, new Long[] {1L, 2L, 3L}); + assertExpr(1234, new Long[] {1L, 2L, null, 3L}); + assertExpr(1234, new Long[] {}); + } + + @Test + public void testLongArraySerdeTooBig() + { + expectedException.expect(ISE.class); + expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.LONG_ARRAY, 29, 10)); + assertExpr(0, ExprEval.ofLongArray(new Long[] {1L, 2L, 3L}), 10); + } + + @Test + public void testDoubleArraySerde() + { + assertExpr(0, new Double[] {1.1, 2.2, 3.3}); + assertExpr(1234, new Double[] {1.1, 2.2, null, 3.3}); + assertExpr(1234, new Double[] {}); + } + + @Test + public void testDoubleArraySerdeTooBig() + { + expectedException.expect(ISE.class); + expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.DOUBLE_ARRAY, 29, 10)); + assertExpr(0, ExprEval.ofDoubleArray(new Double[] {1.1, 2.2, 3.3}), 10); + } + + @Test + public void test_coerceListToArray() + { + Assert.assertNull(ExprEval.coerceListToArray(null, false)); + Assert.assertArrayEquals(new Object[0], (Object[]) ExprEval.coerceListToArray(ImmutableList.of(), false)); + Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(null, true)); + Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(ImmutableList.of(), true)); + + List longList = ImmutableList.of(1L, 2L, 3L); + Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(longList, false)); + + List intList = ImmutableList.of(1, 2, 3); + Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(intList, false)); + + List floatList = ImmutableList.of(1.0f, 2.0f, 3.0f); + Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(floatList, false)); + + List doubleList = ImmutableList.of(1.0, 2.0, 3.0); + Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(doubleList, false)); + + List stringList = ImmutableList.of("a", "b", "c"); + Assert.assertArrayEquals(new String[]{"a", "b", "c"}, (String[]) ExprEval.coerceListToArray(stringList, false)); + + List withNulls = new ArrayList<>(); + withNulls.add("a"); + withNulls.add(null); + withNulls.add("c"); + Assert.assertArrayEquals(new String[]{"a", null, "c"}, (String[]) ExprEval.coerceListToArray(withNulls, false)); + + List withNumberNulls = new ArrayList<>(); + withNumberNulls.add(1L); + withNumberNulls.add(null); + withNumberNulls.add(3L); + + Assert.assertArrayEquals(new Long[]{1L, null, 3L}, (Long[]) ExprEval.coerceListToArray(withNumberNulls, false)); + + List withStringMix = ImmutableList.of(1L, "b", 3L); + Assert.assertArrayEquals( + new String[]{"1", "b", "3"}, + (String[]) ExprEval.coerceListToArray(withStringMix, false) + ); + + List withIntsAndLongs = ImmutableList.of(1, 2L, 3); + Assert.assertArrayEquals( + new Long[]{1L, 2L, 3L}, + (Long[]) ExprEval.coerceListToArray(withIntsAndLongs, false) + ); + + List withFloatsAndLongs = ImmutableList.of(1, 2L, 3.0f); + Assert.assertArrayEquals( + new Double[]{1.0, 2.0, 3.0}, + (Double[]) ExprEval.coerceListToArray(withFloatsAndLongs, false) + ); + + List withDoublesAndLongs = ImmutableList.of(1, 2L, 3.0); + Assert.assertArrayEquals( + new Double[]{1.0, 2.0, 3.0}, + (Double[]) ExprEval.coerceListToArray(withDoublesAndLongs, false) + ); + + List withFloatsAndDoubles = ImmutableList.of(1L, 2.0f, 3.0); + Assert.assertArrayEquals( + new Double[]{1.0, 2.0, 3.0}, + (Double[]) ExprEval.coerceListToArray(withFloatsAndDoubles, false) + ); + + List withAllNulls = new ArrayList<>(); + withAllNulls.add(null); + withAllNulls.add(null); + withAllNulls.add(null); + Assert.assertArrayEquals( + new String[]{null, null, null}, + (String[]) ExprEval.coerceListToArray(withAllNulls, false) + ); + } + + private void assertExpr(int position, Object expected) + { + assertExpr(position, ExprEval.bestEffortOf(expected)); + } + + private void assertExpr(int position, ExprEval expected) + { + assertExpr(position, expected, MAX_SIZE_BYTES); + } + + private void assertExpr(int position, ExprEval expected, int maxSizeBytes) + { + ExprEval.serialize(buffer, position, expected, maxSizeBytes); + if (ExprType.isArray(expected.type())) { + Assert.assertArrayEquals(expected.asArray(), ExprEval.deserialize(buffer, position).asArray()); + } else { + Assert.assertEquals(expected.value(), ExprEval.deserialize(buffer, position).value()); + } + } +} diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java index bd729fe89d7..1bd423f57fb 100644 --- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -290,6 +290,27 @@ public class FunctionTest extends InitializedNullHandlingTest assertArrayExpr("array_concat(0, 1)", new Long[]{0L, 1L}); } + @Test + public void testArraySetAdd() + { + assertArrayExpr("array_set_add([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L}); + assertArrayExpr("array_set_add([1, 2, 3], 'bar')", new Long[]{null, 1L, 2L, 3L}); + assertArrayExpr("array_set_add([1, 2, 2], 1)", new Long[]{1L, 2L}); + assertArrayExpr("array_set_add([], 1)", new String[]{"1"}); + assertArrayExpr("array_set_add([], 1)", new Long[]{1L}); + assertArrayExpr("array_set_add([], null)", new Long[]{null}); + } + + @Test + public void testArraySetAddAll() + { + assertArrayExpr("array_set_add_all([1, 2, 3], [2, 4, 6])", new Long[]{1L, 2L, 3L, 4L, 6L}); + assertArrayExpr("array_set_add_all([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L}); + assertArrayExpr("array_set_add_all(0, [1, 2, 3])", new Long[]{0L, 1L, 2L, 3L}); + assertArrayExpr("array_set_add_all(map(y -> y * 3, b), [1, 2, 3])", new Long[]{1L, 2L, 3L, 6L, 9L, 12L, 15L}); + assertArrayExpr("array_set_add_all(0, 1)", new Long[]{0L, 1L}); + } + @Test public void testArrayToString() { diff --git a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java index 6998a0e93d3..51f991f3f8f 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java @@ -528,6 +528,48 @@ public class ParserTest extends InitializedNullHandlingTest ); } + @Test + public void testFoldUnapplied() + { + validateFoldUnapplied("x + __acc", "(+ x __acc)", "(+ x __acc)", ImmutableList.of(), "__acc"); + validateFoldUnapplied("x + __acc", "(+ x __acc)", "(+ x __acc)", ImmutableList.of("z"), "__acc"); + validateFoldUnapplied( + "x + __acc", + "(+ x __acc)", + "(fold ([x, __acc] -> (+ x __acc)), [x, __acc])", + ImmutableList.of("x"), + "__acc" + ); + validateFoldUnapplied( + "x + y + __acc", + "(+ (+ x y) __acc)", + "(cartesian_fold ([x, y, __acc] -> (+ (+ x y) __acc)), [x, y, __acc])", + ImmutableList.of("x", "y"), + "__acc" + ); + validateFoldUnapplied( + "__acc + z + fold((x, acc) -> acc + x + y, x, 0)", + "(+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))", + "(fold ([z, __acc] -> (+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))), [z, __acc])", + ImmutableList.of("z"), + "__acc" + ); + validateFoldUnapplied( + "__acc + z + fold((x, acc) -> acc + x + y, x, 0)", + "(+ (+ __acc z) (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))", + "(fold ([z, __acc] -> (+ (+ __acc z) (cartesian_fold ([x, y, acc] -> (+ (+ acc x) y)), [x, y, 0]))), [z, __acc])", + ImmutableList.of("y", "z"), + "__acc" + ); + validateFoldUnapplied( + "__acc + fold((x, acc) -> x + y + acc, x, __acc)", + "(+ __acc (fold ([x, acc] -> (+ (+ x y) acc)), [x, __acc]))", + "(+ __acc (cartesian_fold ([x, y, acc] -> (+ (+ x y) acc)), [x, y, __acc]))", + ImmutableList.of("y"), + "__acc" + ); + } + @Test public void testUniquify() { @@ -666,6 +708,33 @@ public class ParserTest extends InitializedNullHandlingTest Assert.assertEquals(transformed.stringify(), transformedRoundTrip.stringify()); } + private void validateFoldUnapplied( + String expression, + String unapplied, + String applied, + List identifiers, + String accumulator + ) + { + final Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); + Expr.BindingAnalysis deets = parsed.analyzeInputs(); + Parser.validateExpr(parsed, deets); + final Expr transformed = Parser.foldUnappliedBindings(parsed, deets, identifiers, accumulator); + Assert.assertEquals(expression, unapplied, parsed.toString()); + Assert.assertEquals(applied, applied, transformed.toString()); + + final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false); + final Expr parsedRoundTrip = Parser.parse(parsedNoFlatten.stringify(), ExprMacroTable.nil()); + Expr.BindingAnalysis roundTripDeets = parsedRoundTrip.analyzeInputs(); + Parser.validateExpr(parsedRoundTrip, roundTripDeets); + final Expr transformedRoundTrip = Parser.foldUnappliedBindings(parsedRoundTrip, roundTripDeets, identifiers, accumulator); + Assert.assertEquals(expression, unapplied, parsedRoundTrip.toString()); + Assert.assertEquals(applied, applied, transformedRoundTrip.toString()); + + Assert.assertEquals(parsed.stringify(), parsedRoundTrip.stringify()); + Assert.assertEquals(transformed.stringify(), transformedRoundTrip.stringify()); + } + private void validateConstantExpression(String expression, Object expected) { Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); diff --git a/core/src/test/java/org/apache/druid/math/expr/VectorExprSanityTest.java b/core/src/test/java/org/apache/druid/math/expr/VectorExprSanityTest.java index 6d769e06713..7b005ee60d6 100644 --- a/core/src/test/java/org/apache/druid/math/expr/VectorExprSanityTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/VectorExprSanityTest.java @@ -415,29 +415,6 @@ public class VectorExprSanityTest extends InitializedNullHandlingTest .toArray(String[][]::new); } - static class SettableObjectBinding implements Expr.ObjectBinding - { - private final Map bindings; - - SettableObjectBinding() - { - this.bindings = new HashMap<>(); - } - - @Nullable - @Override - public Object get(String name) - { - return bindings.get(name); - } - - public SettableObjectBinding withBinding(String name, @Nullable Object value) - { - bindings.put(name, value); - return this; - } - } - static class SettableVectorInputBinding implements Expr.VectorInputBinding { private final Map nulls; diff --git a/docs/misc/math-expr.md b/docs/misc/math-expr.md index 0fbe40dcb96..174e7116b86 100644 --- a/docs/misc/math-expr.md +++ b/docs/misc/math-expr.md @@ -177,14 +177,17 @@ See javadoc of java.lang.Math for detailed explanation for each function. | array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `-1` or `null` if `druid.generic.useDefaultValueForNull=false`if no matching elements exist in the array. | | array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `-1` or `null` if `druid.generic.useDefaultValueForNull=false` if no matching elements exist in the array. | | array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array | -| array_append(arr1,expr) | appends expr to arr, the resulting array type determined by the type of the first array | +| array_append(arr,expr) | appends expr to arr, the resulting array type determined by the type of the first array | | array_concat(arr1,arr2) | concatenates 2 arrays, the resulting array type determined by the type of the first array | +| array_set_add(arr,expr) | adds expr to arr and converts the array to a new array composed of the unique set of elements. The resulting array type determined by the type of the array | +| array_set_add_all(arr1,arr2) | combines the unique set of elements of 2 arrays, the resulting array type determined by the type of the first array | | array_slice(arr,start,end) | return the subarray of arr from the 0 based index start(inclusive) to end(exclusive), or `null`, if start is less than 0, greater than length of arr or less than end| | array_to_string(arr,str) | joins all elements of arr by the delimiter specified by str | | string_to_array(str1,str2) | splits str1 into an array on the delimiter specified by str2 | ## Apply functions +Apply functions allow for special 'lambda' expressions to be defined and applied to array inputs to enable free-form transformations. | function | description | | --- | --- | @@ -197,6 +200,26 @@ See javadoc of java.lang.Math for detailed explanation for each function. | all(lambda,arr) | returns 1 if all elements in the array matches the lambda expression, else 0 | +### Lambda expressions syntax +Lambda expressions are a sort of function definition, where new identifiers can be defined and passed as input to the expression body +``` +(identifier1 ...) -> expr +``` +e.g. +``` +(x, y) -> x + y +``` +The identifier arguments of a lambda expression correspond to the elements of the array it is being applied to. For example: +``` +map((x) -> x + 1, some_multi_value_column) +``` +will map each element of `some_multi_value_column` to the identifier `x` so that the lambda expression body can be evaluated for each `x`. The scoping rules are that lambda arguments will override identifiers which are defined externally from the lambda expression body. Using the same example: + +``` +map((x) -> x + 1, x) +``` +in this case, the `x` when evaluating `x + 1` is the lambda argument, thus an element of the multi-valued column `x`, rather than the column `x` itself. + ## Reduction functions Reduction functions operate on zero or more expressions and return a single expression. If no expressions are passed as diff --git a/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java b/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java index 795ea5b6d31..155b8e7e7b7 100644 --- a/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java +++ b/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java @@ -27,6 +27,7 @@ import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory; import org.apache.druid.query.aggregation.FloatMinAggregatorFactory; @@ -120,7 +121,8 @@ public class AggregatorsModule extends SimpleModule @JsonSubTypes.Type(name = "floatAny", value = FloatAnyAggregatorFactory.class), @JsonSubTypes.Type(name = "doubleAny", value = DoubleAnyAggregatorFactory.class), @JsonSubTypes.Type(name = "stringAny", value = StringAnyAggregatorFactory.class), - @JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class) + @JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class), + @JsonSubTypes.Type(name = "expression", value = ExpressionLambdaAggregatorFactory.class) }) public interface AggregatorFactoryMixin { diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java index 3c7b8d47439..34a6eada499 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java @@ -137,6 +137,9 @@ public class AggregatorUtil // GROUPING aggregator public static final byte GROUPING_CACHE_TYPE_ID = 0x46; + // expression lambda aggregator + public static final byte EXPRESSION_LAMBDA_CACHE_TYPE_ID = 0x47; + /** * returns the list of dependent postAggregators that should be calculated in order to calculate given postAgg diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java new file mode 100644 index 00000000000..0305c8acaf5 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation; + +import org.apache.druid.math.expr.Expr; + +import javax.annotation.Nullable; + +public class ExpressionLambdaAggregator implements Aggregator +{ + private final Expr lambda; + private final ExpressionLambdaAggregatorInputBindings bindings; + + public ExpressionLambdaAggregator(Expr lambda, ExpressionLambdaAggregatorInputBindings bindings) + { + this.lambda = lambda; + this.bindings = bindings; + } + + @Override + public void aggregate() + { + bindings.accumulate(lambda.eval(bindings)); + } + + @Nullable + @Override + public Object get() + { + return bindings.getAccumulator().value(); + } + + @Override + public float getFloat() + { + return (float) bindings.getAccumulator().asDouble(); + } + + @Override + public long getLong() + { + return bindings.getAccumulator().asLong(); + } + + @Override + public double getDouble() + { + return bindings.getAccumulator().asDouble(); + } + + @Override + public boolean isNull() + { + return bindings.getAccumulator().isNumericNull(); + } + + @Override + public void close() + { + // nothing to close + } +} diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java new file mode 100644 index 00000000000..2da1abde8c0 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java @@ -0,0 +1,516 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation; + +import com.fasterxml.jackson.annotation.JacksonInject; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; +import com.google.common.collect.Iterables; +import org.apache.druid.java.util.common.HumanReadableBytes; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.guava.Comparators; +import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; +import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.SettableObjectBinding; +import org.apache.druid.query.cache.CacheKeyBuilder; +import org.apache.druid.query.expression.ExprUtils; +import org.apache.druid.segment.ColumnInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.virtual.ExpressionPlan; +import org.apache.druid.segment.virtual.ExpressionPlanner; +import org.apache.druid.segment.virtual.ExpressionSelectors; + +import javax.annotation.Nullable; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +public class ExpressionLambdaAggregatorFactory extends AggregatorFactory +{ + private static final String FINALIZE_IDENTIFIER = "o"; + private static final String COMPARE_O1 = "o1"; + private static final String COMPARE_O2 = "o2"; + private static final String DEFAULT_ACCUMULATOR_ID = "__acc"; + + // minimum permitted agg size is 10 bytes so it is at least large enough to hold primitive numerics (long, double) + // | expression type byte | is_null byte | primitive value (8 bytes) | + private static final int MIN_SIZE_BYTES = 10; + private static final HumanReadableBytes DEFAULT_MAX_SIZE_BYTES = new HumanReadableBytes(1L << 10); + + private final String name; + @Nullable + private final Set fields; + private final String accumulatorId; + private final String foldExpressionString; + private final String initialValueExpressionString; + private final String initialCombineValueExpressionString; + + private final String combineExpressionString; + @Nullable + private final String compareExpressionString; + @Nullable + private final String finalizeExpressionString; + + private final ExprMacroTable macroTable; + private final Supplier> initialValue; + private final Supplier> initialCombineValue; + private final Supplier foldExpression; + private final Supplier combineExpression; + private final Supplier compareExpression; + private final Supplier finalizeExpression; + private final HumanReadableBytes maxSizeBytes; + + private final Supplier compareBindings = + Suppliers.memoize(() -> new SettableObjectBinding(2)); + private final Supplier combineBindings = + Suppliers.memoize(() -> new SettableObjectBinding(2)); + private final Supplier finalizeBindings = + Suppliers.memoize(() -> new SettableObjectBinding(1)); + + @JsonCreator + public ExpressionLambdaAggregatorFactory( + @JsonProperty("name") String name, + @JsonProperty("fields") @Nullable final Set fields, + @JsonProperty("accumulatorIdentifier") @Nullable final String accumulatorIdentifier, + @JsonProperty("initialValue") final String initialValue, + @JsonProperty("initialCombineValue") @Nullable final String initialCombineValue, + @JsonProperty("fold") final String foldExpression, + @JsonProperty("combine") @Nullable final String combineExpression, + @JsonProperty("compare") @Nullable final String compareExpression, + @JsonProperty("finalize") @Nullable final String finalizeExpression, + @JsonProperty("maxSizeBytes") @Nullable final HumanReadableBytes maxSizeBytes, + @JacksonInject ExprMacroTable macroTable + ) + { + Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name"); + + this.name = name; + this.fields = fields; + this.accumulatorId = accumulatorIdentifier != null ? accumulatorIdentifier : DEFAULT_ACCUMULATOR_ID; + + this.initialValueExpressionString = initialValue; + this.initialCombineValueExpressionString = initialCombineValue == null ? initialValue : initialCombineValue; + this.foldExpressionString = foldExpression; + if (combineExpression != null) { + this.combineExpressionString = combineExpression; + } else { + // if the combine expression is null, allow single input aggregator expressions to be rewritten to replace the + // field with the aggregator name. Fields is null for the combining/merging aggregator, but the expression should + // already be set with the rewritten value at that point + Preconditions.checkArgument( + fields != null && fields.size() == 1, + "Must have a single input field if no combine expression is supplied" + ); + this.combineExpressionString = StringUtils.replace(foldExpression, Iterables.getOnlyElement(fields), name); + } + this.compareExpressionString = compareExpression; + this.finalizeExpressionString = finalizeExpression; + this.macroTable = macroTable; + + this.initialValue = Suppliers.memoize(() -> { + Expr parsed = Parser.parse(initialValue, macroTable); + Preconditions.checkArgument(parsed.isLiteral(), "initial value must be constant"); + return parsed.eval(ExprUtils.nilBindings()); + }); + this.initialCombineValue = Suppliers.memoize(() -> { + Expr parsed = Parser.parse(this.initialCombineValueExpressionString, macroTable); + Preconditions.checkArgument(parsed.isLiteral(), "initial combining value must be constant"); + return parsed.eval(ExprUtils.nilBindings()); + }); + this.foldExpression = Parser.lazyParse(foldExpressionString, macroTable); + this.combineExpression = Parser.lazyParse(combineExpressionString, macroTable); + this.compareExpression = Parser.lazyParse(compareExpressionString, macroTable); + this.finalizeExpression = Parser.lazyParse(finalizeExpressionString, macroTable); + this.maxSizeBytes = maxSizeBytes != null ? maxSizeBytes : DEFAULT_MAX_SIZE_BYTES; + Preconditions.checkArgument(this.maxSizeBytes.getBytesInInt() >= MIN_SIZE_BYTES); + } + + @JsonProperty + @Override + public String getName() + { + return name; + } + + @JsonProperty + @Nullable + public Set getFields() + { + return fields; + } + + @JsonProperty + @Nullable + public String getAccumulatorIdentifier() + { + return accumulatorId; + } + + @JsonProperty("initialValue") + public String getInitialValueExpressionString() + { + return initialValueExpressionString; + } + + @JsonProperty("initialCombineValue") + public String getInitialCombineValueExpressionString() + { + return initialCombineValueExpressionString; + } + + @JsonProperty("fold") + public String getFoldExpressionString() + { + return foldExpressionString; + } + + @JsonProperty("combine") + public String getCombineExpressionString() + { + return combineExpressionString; + } + + @JsonProperty("compare") + @Nullable + public String getCompareExpressionString() + { + return compareExpressionString; + } + + @JsonProperty("finalize") + @Nullable + public String getFinalizeExpressionString() + { + return finalizeExpressionString; + } + + @JsonProperty("maxSizeBytes") + public HumanReadableBytes getMaxSizeBytes() + { + return maxSizeBytes; + } + + @Override + public byte[] getCacheKey() + { + return new CacheKeyBuilder(AggregatorUtil.EXPRESSION_LAMBDA_CACHE_TYPE_ID) + .appendStrings(fields) + .appendString(initialValueExpressionString) + .appendString(initialCombineValueExpressionString) + .appendString(foldExpressionString) + .appendString(combineExpressionString) + .appendString(compareExpressionString) + .appendString(finalizeExpressionString) + .appendInt(maxSizeBytes.getBytesInInt()) + .build(); + } + + @Override + public Aggregator factorize(ColumnSelectorFactory metricFactory) + { + FactorizePlan thePlan = new FactorizePlan(metricFactory); + return new ExpressionLambdaAggregator( + thePlan.getExpression(), + thePlan.getBindings() + ); + } + + @Override + public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) + { + FactorizePlan thePlan = new FactorizePlan(metricFactory); + return new ExpressionLambdaBufferAggregator( + thePlan.getExpression(), + thePlan.getInitialValue(), + thePlan.getBindings(), + maxSizeBytes.getBytesInInt() + ); + } + + @Override + public Comparator getComparator() + { + Expr compareExpr = compareExpression.get(); + if (compareExpr != null) { + return (o1, o2) -> + compareExpr.eval(compareBindings.get().withBinding(COMPARE_O1, o1).withBinding(COMPARE_O2, o2)).asInt(); + } + switch (initialValue.get().type()) { + case LONG: + return LongSumAggregator.COMPARATOR; + case DOUBLE: + return DoubleSumAggregator.COMPARATOR; + default: + return Comparators.naturalNullsFirst(); + } + } + + @Nullable + @Override + public Object combine(@Nullable Object lhs, @Nullable Object rhs) + { + // arbitrarily assign lhs and rhs to accumulator and aggregator name inputs to re-use combine function + return combineExpression.get().eval( + combineBindings.get().withBinding(accumulatorId, lhs).withBinding(name, rhs) + ).value(); + } + + @Override + public Object deserialize(Object object) + { + return object; + } + + @Nullable + @Override + public Object finalizeComputation(@Nullable Object object) + { + Expr finalizeExpr; + finalizeExpr = finalizeExpression.get(); + if (finalizeExpr != null) { + return finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, object)).value(); + } + return object; + } + + @Override + public List requiredFields() + { + if (fields == null) { + return combineExpression.get().analyzeInputs().getRequiredBindingsList(); + } + return foldExpression.get().analyzeInputs().getRequiredBindingsList(); + } + + @Override + public AggregatorFactory getCombiningFactory() + { + return new ExpressionLambdaAggregatorFactory( + name, + null, + accumulatorId, + initialValueExpressionString, + initialCombineValueExpressionString, + foldExpressionString, + combineExpressionString, + compareExpressionString, + finalizeExpressionString, + maxSizeBytes, + macroTable + ); + } + + @Override + public List getRequiredColumns() + { + return Collections.singletonList( + new ExpressionLambdaAggregatorFactory( + name, + fields, + accumulatorId, + initialValueExpressionString, + initialCombineValueExpressionString, + foldExpressionString, + combineExpressionString, + compareExpressionString, + finalizeExpressionString, + maxSizeBytes, + macroTable + ) + ); + } + + @Override + public ValueType getType() + { + if (fields == null) { + return ExprType.toValueType(initialCombineValue.get().type()); + } + return ExprType.toValueType(initialValue.get().type()); + } + + @Override + public ValueType getFinalizedType() + { + Expr finalizeExpr = finalizeExpression.get(); + ExprEval initialVal = initialCombineValue.get(); + if (finalizeExpr != null) { + return ExprType.toValueType( + finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, initialVal)).type() + ); + } + return ExprType.toValueType(initialVal.type()); + } + + @Override + public int getMaxIntermediateSize() + { + // numeric expressions are either longs or doubles, with strings or arrays max size is unknown + // for numeric arguments, the first 2 bytes are used for expression type byte and is_null byte + return getType().isNumeric() ? 2 + Long.BYTES : maxSizeBytes.getBytesInInt(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ExpressionLambdaAggregatorFactory that = (ExpressionLambdaAggregatorFactory) o; + return maxSizeBytes.equals(that.maxSizeBytes) + && name.equals(that.name) + && Objects.equals(fields, that.fields) + && accumulatorId.equals(that.accumulatorId) + && foldExpressionString.equals(that.foldExpressionString) + && initialValueExpressionString.equals(that.initialValueExpressionString) + && initialCombineValueExpressionString.equals(that.initialCombineValueExpressionString) + && combineExpressionString.equals(that.combineExpressionString) + && Objects.equals(compareExpressionString, that.compareExpressionString) + && Objects.equals(finalizeExpressionString, that.finalizeExpressionString); + } + + @Override + public int hashCode() + { + return Objects.hash( + name, + fields, + accumulatorId, + foldExpressionString, + initialValueExpressionString, + initialCombineValueExpressionString, + combineExpressionString, + compareExpressionString, + finalizeExpressionString, + maxSizeBytes + ); + } + + @Override + public String toString() + { + return "ExpressionLambdaAggregatorFactory{" + + "name='" + name + '\'' + + ", fields=" + fields + + ", accumulatorId='" + accumulatorId + '\'' + + ", foldExpressionString='" + foldExpressionString + '\'' + + ", initialValueExpressionString='" + initialValueExpressionString + '\'' + + ", initialCombineValueExpressionString='" + initialCombineValueExpressionString + '\'' + + ", combineExpressionString='" + combineExpressionString + '\'' + + ", compareExpressionString='" + compareExpressionString + '\'' + + ", finalizeExpressionString='" + finalizeExpressionString + '\'' + + ", maxSizeBytes=" + maxSizeBytes + + '}'; + } + + /** + * Determine how to factorize the aggregator + */ + private class FactorizePlan + { + private final ExpressionPlan plan; + + private final ExprEval seed; + private final ExpressionLambdaAggregatorInputBindings bindings; + + FactorizePlan(ColumnSelectorFactory metricFactory) + { + final List columns; + + if (fields != null) { + // if fields are set, we are accumulating from raw inputs, use fold expression + plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory), foldExpression.get()); + seed = initialValue.get(); + columns = plan.getAnalysis().getRequiredBindingsList(); + } else { + // else we are merging intermediary results, use combine expression + plan = ExpressionPlanner.plan(inspectorWithAccumulator(metricFactory), combineExpression.get()); + seed = initialCombineValue.get(); + columns = plan.getAnalysis().getRequiredBindingsList(); + } + + bindings = new ExpressionLambdaAggregatorInputBindings( + ExpressionSelectors.createBindings(metricFactory, columns), + accumulatorId, + seed + ); + } + + public Expr getExpression() + { + if (fields == null) { + return plan.getExpression(); + } + // for fold expressions, check to see if it needs transformation due to scalar use of multi-valued or unknown + // inputs + return plan.getAppliedFoldExpression(accumulatorId); + } + + public ExprEval getInitialValue() + { + return seed; + } + + public ExpressionLambdaAggregatorInputBindings getBindings() + { + return bindings; + } + + private ColumnInspector inspectorWithAccumulator(ColumnInspector inspector) + { + return new ColumnInspector() + { + @Nullable + @Override + public ColumnCapabilities getColumnCapabilities(String column) + { + if (accumulatorId.equals(column)) { + return ColumnCapabilitiesImpl.createDefault().setType(ExprType.toValueType(initialValue.get().type())); + } + return inspector.getColumnCapabilities(column); + } + + @Nullable + @Override + public ExprType getType(String name) + { + if (accumulatorId.equals(name)) { + return initialValue.get().type(); + } + return inspector.getType(name); + } + }; + } + } +} diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorInputBindings.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorInputBindings.java new file mode 100644 index 00000000000..5e4864efc77 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorInputBindings.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation; + +import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.ExprEval; + +import javax.annotation.Nullable; + +/** + * Special {@link Expr.ObjectBinding} for use with {@link ExpressionLambdaAggregatorFactory}. + * This value binding holds a value for a special 'accumulator' variable, in addition to the 'normal' bindings to the + * underlying selector inputs for other identifiers, which allows for easy forward feeding of the results of an + * expression evaluation to use in the bindings of the next evaluation. + */ +public class ExpressionLambdaAggregatorInputBindings implements Expr.ObjectBinding +{ + private final Expr.ObjectBinding inputBindings; + private final String accumlatorIdentifier; + private ExprEval accumulator; + + public ExpressionLambdaAggregatorInputBindings( + Expr.ObjectBinding inputBindings, + String accumulatorIdentifier, + ExprEval initialValue + ) + { + this.accumlatorIdentifier = accumulatorIdentifier; + this.inputBindings = inputBindings; + this.accumulator = initialValue; + } + + @Nullable + @Override + public Object get(String name) + { + if (accumlatorIdentifier.equals(name)) { + return accumulator.value(); + } + return inputBindings.get(name); + } + + public void accumulate(ExprEval eval) + { + accumulator = eval; + } + + public ExprEval getAccumulator() + { + return accumulator; + } + + public void setAccumulator(ExprEval acc) + { + this.accumulator = acc; + } +} diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java new file mode 100644 index 00000000000..357dd4b7d6b --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation; + +import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.ExprEval; + +import javax.annotation.Nullable; +import java.nio.ByteBuffer; + +public class ExpressionLambdaBufferAggregator implements BufferAggregator +{ + private final Expr lambda; + private final ExprEval initialValue; + private final ExpressionLambdaAggregatorInputBindings bindings; + private final int maxSizeBytes; + + public ExpressionLambdaBufferAggregator( + Expr lambda, + ExprEval initialValue, + ExpressionLambdaAggregatorInputBindings bindings, + int maxSizeBytes + ) + { + this.lambda = lambda; + this.initialValue = initialValue; + this.bindings = bindings; + this.maxSizeBytes = maxSizeBytes; + } + + @Override + public void init(ByteBuffer buf, int position) + { + ExprEval.serialize(buf, position, initialValue, maxSizeBytes); + } + + @Override + public void aggregate(ByteBuffer buf, int position) + { + ExprEval acc = ExprEval.deserialize(buf, position); + bindings.setAccumulator(acc); + ExprEval newAcc = lambda.eval(bindings); + ExprEval.serialize(buf, position, newAcc, maxSizeBytes); + } + + @Nullable + @Override + public Object get(ByteBuffer buf, int position) + { + return ExprEval.deserialize(buf, position).value(); + } + + @Override + public float getFloat(ByteBuffer buf, int position) + { + return (float) ExprEval.deserialize(buf, position).asDouble(); + } + + @Override + public double getDouble(ByteBuffer buf, int position) + { + return ExprEval.deserialize(buf, position).asDouble(); + } + + @Override + public long getLong(ByteBuffer buf, int position) + { + return ExprEval.deserialize(buf, position).asLong(); + } + + @Override + public void close() + { + // nothing to close + } +} diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java index b8644329c71..c540018dc6c 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java @@ -23,7 +23,6 @@ package org.apache.druid.query.aggregation; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import com.google.common.base.Supplier; -import com.google.common.base.Suppliers; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.Parser; @@ -72,7 +71,7 @@ public abstract class SimpleDoubleAggregatorFactory extends NullableNumericAggre this.fieldName = fieldName; this.expression = expression; this.storeDoubleAsFloat = ColumnHolder.storeDoubleAsFloat(); - this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable)); + this.fieldExpression = Parser.lazyParse(expression, macroTable); Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name"); Preconditions.checkArgument( fieldName == null ^ expression == null, diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java index 380ceb14941..03b9f923da6 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java @@ -23,7 +23,6 @@ package org.apache.druid.query.aggregation; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import com.google.common.base.Supplier; -import com.google.common.base.Suppliers; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.Parser; @@ -63,7 +62,7 @@ public abstract class SimpleFloatAggregatorFactory extends NullableNumericAggreg this.name = name; this.fieldName = fieldName; this.expression = expression; - this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable)); + this.fieldExpression = Parser.lazyParse(expression, macroTable); Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name"); Preconditions.checkArgument( fieldName == null ^ expression == null, diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java index 7d148d5b6f0..bf297e1d4e9 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java @@ -23,7 +23,6 @@ package org.apache.druid.query.aggregation; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import com.google.common.base.Supplier; -import com.google.common.base.Suppliers; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.Parser; @@ -69,7 +68,7 @@ public abstract class SimpleLongAggregatorFactory extends NullableNumericAggrega this.name = name; this.fieldName = fieldName; this.expression = expression; - this.fieldExpression = Suppliers.memoize(() -> expression == null ? null : Parser.parse(expression, macroTable)); + this.fieldExpression = Parser.lazyParse(expression, macroTable); Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name"); Preconditions.checkArgument( fieldName == null ^ expression == null, diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java index 978bf3e04ab..34cbb16d163 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java @@ -92,7 +92,7 @@ public class ExpressionPostAggregator implements PostAggregator ordering, macroTable, ImmutableMap.of(), - Suppliers.memoize(() -> Parser.parse(expression, macroTable)) + Parser.lazyParse(expression, macroTable) ); } diff --git a/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java b/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java index 2f65cc64fda..66927332df2 100644 --- a/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java +++ b/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java @@ -25,7 +25,6 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; -import com.google.common.base.Suppliers; import com.google.common.collect.RangeSet; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprMacroTable; @@ -53,7 +52,7 @@ public class ExpressionDimFilter extends AbstractOptimizableDimFilter implements { this.expression = expression; this.filterTuning = filterTuning; - this.parsed = Suppliers.memoize(() -> Parser.parse(expression, macroTable)); + this.parsed = Parser.lazyParse(expression, macroTable); } @VisibleForTesting diff --git a/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java b/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java index 2ace9b06bf1..caa8dafe87a 100644 --- a/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java +++ b/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java @@ -26,6 +26,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Suppliers; import org.apache.druid.data.input.Row; import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.Parser; import org.apache.druid.segment.column.ColumnHolder; @@ -106,7 +107,7 @@ public class ExpressionTransform implements Transform } else { Object raw = row.getRaw(column); if (raw instanceof List) { - return ExpressionSelectors.coerceListToArray((List) raw); + return ExprEval.coerceListToArray((List) raw, true); } return raw; } diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlan.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlan.java index 38a3fc3812c..b1ab3a9010f 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlan.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlan.java @@ -122,6 +122,14 @@ public class ExpressionPlan return expression; } + public Expr getAppliedFoldExpression(String accumulatorId) + { + if (is(Trait.NEEDS_APPLIED)) { + return Parser.foldUnappliedBindings(expression, analysis, unappliedInputs, accumulatorId); + } + return expression; + } + public Expr.BindingAnalysis getAnalysis() { return analysis; diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java index 0ff00b68c05..d910d7b54f7 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -24,7 +24,6 @@ import com.google.common.base.Preconditions; import com.google.common.base.Supplier; import com.google.common.collect.Iterables; import org.apache.druid.common.config.NullHandling; -import org.apache.druid.java.util.common.UOE; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.Parser; @@ -242,13 +241,26 @@ public class ExpressionSelectors * provides the set of identifiers which need a binding (list of required columns), and context of whether or not they * are used as array or scalar inputs */ - private static Expr.ObjectBinding createBindings( + public static Expr.ObjectBinding createBindings( Expr.BindingAnalysis bindingAnalysis, ColumnSelectorFactory columnSelectorFactory ) { - final Map> suppliers = new HashMap<>(); final List columns = bindingAnalysis.getRequiredBindingsList(); + return createBindings(columnSelectorFactory, columns); + } + + /** + * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.BindingAnalysis} which + * provides the set of identifiers which need a binding (list of required columns), and context of whether or not they + * are used as array or scalar inputs + */ + public static Expr.ObjectBinding createBindings( + ColumnSelectorFactory columnSelectorFactory, + List columns + ) + { + final Map> suppliers = new HashMap<>(); for (String columnName : columns) { final ColumnCapabilities columnCapabilities = columnSelectorFactory.getColumnCapabilities(columnName); final ValueType nativeType = columnCapabilities != null ? columnCapabilities.getType() : null; @@ -269,8 +281,8 @@ public class ExpressionSelectors columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(columnName, columnName)), multiVal ); - } else if (nativeType == null) { - // Unknown ValueType. Try making an Object selector and see if that gives us anything useful. + } else if (nativeType == null || ValueType.isArray(nativeType)) { + // Unknown ValueType or array type. Try making an Object selector and see if that gives us anything useful. supplier = supplierFromObjectSelector(columnSelectorFactory.makeColumnValueSelector(columnName)); } else { // Unhandleable ValueType (COMPLEX). @@ -370,10 +382,10 @@ public class ExpressionSelectors // Might be Numbers and Strings. Use a selector that double-checks. return () -> { final Object val = selector.getObject(); - if (val instanceof Number || val instanceof String) { + if (val instanceof Number || val instanceof String || (val != null && val.getClass().isArray())) { return val; } else if (val instanceof List) { - return coerceListToArray((List) val); + return ExprEval.coerceListToArray((List) val, true); } else { return null; } @@ -382,7 +394,7 @@ public class ExpressionSelectors return () -> { final Object val = selector.getObject(); if (val != null) { - return coerceListToArray((List) val); + return ExprEval.coerceListToArray((List) val, true); } return null; }; @@ -392,70 +404,6 @@ public class ExpressionSelectors } } - /** - * Selectors are not consistent in treatment of null, [], and [null], so coerce [] to [null] - */ - public static Object coerceListToArray(@Nullable List val) - { - if (val != null && val.size() > 0) { - Class coercedType = null; - - for (Object elem : val) { - if (elem != null) { - coercedType = convertType(coercedType, elem.getClass()); - } - } - - if (coercedType == Long.class || coercedType == Integer.class) { - return val.stream().map(x -> x != null ? ((Number) x).longValue() : null).toArray(Long[]::new); - } - if (coercedType == Float.class || coercedType == Double.class) { - return val.stream().map(x -> x != null ? ((Number) x).doubleValue() : null).toArray(Double[]::new); - } - // default to string - return val.stream().map(x -> x != null ? x.toString() : null).toArray(String[]::new); - } - return new String[]{null}; - } - - private static Class convertType(@Nullable Class existing, Class next) - { - if (Number.class.isAssignableFrom(next) || next == String.class) { - if (existing == null) { - return next; - } - // string wins everything - if (existing == String.class) { - return existing; - } - if (next == String.class) { - return next; - } - // all numbers win over Integer - if (existing == Integer.class) { - return next; - } - if (existing == Float.class) { - // doubles win over floats - if (next == Double.class) { - return next; - } - return existing; - } - if (existing == Long.class) { - if (next == Integer.class) { - // long beats int - return existing; - } - // double and float win over longs - return next; - } - // otherwise double - return Double.class; - } - throw new UOE("Invalid array expression type: %s", next); - } - /** * Coerces {@link ExprEval} value back to selector friendly {@link List} if the evaluated expression result is an * array type diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java index 4a7635cad0f..343cd8a4978 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java @@ -72,7 +72,7 @@ public class ExpressionVirtualColumn implements VirtualColumn this.name = Preconditions.checkNotNull(name, "name"); this.expression = Preconditions.checkNotNull(expression, "expression"); this.outputType = outputType; - this.parsedExpression = Suppliers.memoize(() -> Parser.parse(expression, macroTable)); + this.parsedExpression = Parser.lazyParse(expression, macroTable); } /** diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java new file mode 100644 index 00000000000..5b5c2960af1 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java @@ -0,0 +1,570 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.java.util.common.HumanReadableBytes; +import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.query.Druids; +import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; +import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; +import org.apache.druid.query.expression.TestExprMacroTable; +import org.apache.druid.query.timeseries.TimeseriesQuery; +import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest; +import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.io.IOException; + +public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandlingTest +{ + private static ObjectMapper MAPPER = TestHelper.makeJsonMapper(); + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testSerde() throws IOException + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + "customAccumulator", + "0.0", + "10.0", + "customAccumulator + some_column + some_other_column", + "customAccumulator + expr_agg_name", + "if (o1 > o2, if (o1 == o2, 0, 1), -1)", + "o + 100", + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(agg, MAPPER.readValue(MAPPER.writeValueAsBytes(agg), ExpressionLambdaAggregatorFactory.class)); + } + + @Test + public void testEqualsAndHashCode() + { + EqualsVerifier.forClass(ExpressionLambdaAggregatorFactory.class) + .usingGetClass() + .withIgnoredFields( + "macroTable", + "initialValue", + "initialCombineValue", + "foldExpression", + "combineExpression", + "compareExpression", + "finalizeExpression", + "compareBindings", + "combineBindings", + "finalizeBindings" + ) + .verify(); + } + + @Test + public void testInitialValueMustBeConstant() + { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("initial value must be constant"); + + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "x + y", + null, + "__acc + some_column + some_other_column", + "__acc + expr_agg_name", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + agg.getType(); + } + + @Test + public void testInitialCombineValueMustBeConstant() + { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("initial combining value must be constant"); + + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0.0", + "x + y", + "__acc + some_column + some_other_column", + "__acc + expr_agg_name", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + agg.getFinalizedType(); + } + + @Test + public void testSingleInputCombineExpressionIsOptional() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("x"), + null, + "0", + null, + "__acc + x", + null, + null, + null, + null, + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(1L, agg.combine(0L, 1L)); + } + + @Test + public void testFinalizeCanDo() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("x"), + null, + "0", + null, + "__acc + x", + null, + null, + "o + 100", + null, + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(100L, agg.finalizeComputation(0L)); + } + + @Test + public void testFinalizeCanDoArrays() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("x"), + null, + "0", + null, + "array_set_add(__acc, x)", + "array_set_add_all(__acc, expr_agg_name)", + null, + "array_to_string(o, ',')", + null, + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals("a,b,c", agg.finalizeComputation(new String[]{"a", "b", "c"})); + Assert.assertEquals("a,b,c", agg.finalizeComputation(ImmutableList.of("a", "b", "c"))); + } + + @Test + public void testStringType() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "''", + "''", + "concat(__acc, some_column, some_other_column)", + "concat(__acc, expr_agg_name)", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(ValueType.STRING, agg.getType()); + Assert.assertEquals(ValueType.STRING, agg.getCombiningFactory().getType()); + Assert.assertEquals(ValueType.STRING, agg.getFinalizedType()); + } + + @Test + public void testLongType() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0", + null, + "__acc + some_column + some_other_column", + "__acc + expr_agg_name", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(ValueType.LONG, agg.getType()); + Assert.assertEquals(ValueType.LONG, agg.getCombiningFactory().getType()); + Assert.assertEquals(ValueType.LONG, agg.getFinalizedType()); + } + + @Test + public void testDoubleType() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0.0", + null, + "__acc + some_column + some_other_column", + "__acc + expr_agg_name", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(ValueType.DOUBLE, agg.getType()); + Assert.assertEquals(ValueType.DOUBLE, agg.getCombiningFactory().getType()); + Assert.assertEquals(ValueType.DOUBLE, agg.getFinalizedType()); + } + + @Test + public void testStringArrayType() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "''", + "[]", + "concat(__acc, some_column, some_other_column)", + "array_set_add(__acc, expr_agg_name)", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(ValueType.STRING, agg.getType()); + Assert.assertEquals(ValueType.STRING_ARRAY, agg.getCombiningFactory().getType()); + Assert.assertEquals(ValueType.STRING_ARRAY, agg.getFinalizedType()); + } + + @Test + public void testStringArrayTypeFinalized() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "''", + "[]", + "concat(__acc, some_column, some_other_column)", + "array_set_add(__acc, expr_agg_name)", + null, + "array_to_string(o, ';')", + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(ValueType.STRING, agg.getType()); + Assert.assertEquals(ValueType.STRING_ARRAY, agg.getCombiningFactory().getType()); + Assert.assertEquals(ValueType.STRING, agg.getFinalizedType()); + } + + @Test + public void testLongArrayType() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0", + "[]", + "__acc + some_column + some_other_column", + "array_set_add(__acc, expr_agg_name)", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(ValueType.LONG, agg.getType()); + Assert.assertEquals(ValueType.LONG_ARRAY, agg.getCombiningFactory().getType()); + Assert.assertEquals(ValueType.LONG_ARRAY, agg.getFinalizedType()); + } + + @Test + public void testLongArrayTypeFinalized() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0", + "[]", + "__acc + some_column + some_other_column", + "array_set_add(__acc, expr_agg_name)", + null, + "array_to_string(o, ';')", + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(ValueType.LONG, agg.getType()); + Assert.assertEquals(ValueType.LONG_ARRAY, agg.getCombiningFactory().getType()); + Assert.assertEquals(ValueType.STRING, agg.getFinalizedType()); + } + + @Test + public void testDoubleArrayType() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0.0", + "[]", + "__acc + some_column + some_other_column", + "array_set_add(__acc, expr_agg_name)", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(ValueType.DOUBLE, agg.getType()); + Assert.assertEquals(ValueType.DOUBLE_ARRAY, agg.getCombiningFactory().getType()); + Assert.assertEquals(ValueType.DOUBLE_ARRAY, agg.getFinalizedType()); + } + + @Test + public void testDoubleArrayTypeFinalized() + { + ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory( + "expr_agg_name", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0.0", + "[]", + "__acc + some_column + some_other_column", + "array_set_add(__acc, expr_agg_name)", + null, + "array_to_string(o, ';')", + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ); + + Assert.assertEquals(ValueType.DOUBLE, agg.getType()); + Assert.assertEquals(ValueType.DOUBLE_ARRAY, agg.getCombiningFactory().getType()); + Assert.assertEquals(ValueType.STRING, agg.getFinalizedType()); + } + + @Test + public void testResultArraySignature() + { + final TimeseriesQuery query = + Druids.newTimeseriesQueryBuilder() + .dataSource("dummy") + .intervals("2000/3000") + .granularity(Granularities.HOUR) + .aggregators( + new ExpressionLambdaAggregatorFactory( + "string_expr", + ImmutableSet.of("some_column", "some_other_column"), + null, + "''", + "''", + "concat(__acc, some_column, some_other_column)", + "concat(__acc, string_expr)", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "double_expr", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0.0", + null, + "__acc + some_column + some_other_column", + "__acc + double_expr", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "long_expr", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0", + null, + "__acc + some_column + some_other_column", + "__acc + long_expr", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "string_array_expr", + ImmutableSet.of("some_column", "some_other_column"), + null, + "[]", + "[]", + "array_set_add(__acc, concat(some_column, some_other_column))", + "array_set_add_all(__acc, string_array_expr)", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "double_array_expr", + ImmutableSet.of("some_column", "some_other_column_expr"), + null, + "0.0", + "[]", + "__acc + some_column + some_other_column", + "array_set_add(__acc, double_array)", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "long_array_expr", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0", + "[]", + "__acc + some_column + some_other_column", + "array_set_add(__acc, long_array_expr)", + null, + null, + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "string_array_expr_finalized", + ImmutableSet.of("some_column", "some_other_column"), + null, + "''", + "[]", + "concat(__acc, some_column, some_other_column)", + "array_set_add(__acc, string_array_expr)", + null, + "array_to_string(o, ';')", + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "double_array_expr_finalized", + ImmutableSet.of("some_column", "some_other_column_expr"), + null, + "0.0", + "[]", + "__acc + some_column + some_other_column", + "array_set_add(__acc, double_array)", + null, + "array_to_string(o, ';')", + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "long_array_expr_finalized", + ImmutableSet.of("some_column", "some_other_column"), + null, + "0", + "[]", + "__acc + some_column + some_other_column", + "array_set_add(__acc, long_array_expr)", + null, + "fold((x, acc) -> x + acc, o, 0)", + new HumanReadableBytes(2048), + TestExprMacroTable.INSTANCE + ) + ) + .postAggregators( + new FieldAccessPostAggregator("string-array-expr-access", "string_array_expr_finalized"), + new FinalizingFieldAccessPostAggregator("string-array-expr-finalize", "string_array_expr_finalized"), + new FieldAccessPostAggregator("double-array-expr-access", "double_array_expr_finalized"), + new FinalizingFieldAccessPostAggregator("double-array-expr-finalize", "double_array_expr_finalized"), + new FieldAccessPostAggregator("long-array-expr-access", "long_array_expr_finalized"), + new FinalizingFieldAccessPostAggregator("long-array-expr-finalize", "long_array_expr_finalized") + ) + .build(); + + Assert.assertEquals( + RowSignature.builder() + .addTimeColumn() + .add("string_expr", ValueType.STRING) + .add("double_expr", ValueType.DOUBLE) + .add("long_expr", ValueType.LONG) + .add("string_array_expr", ValueType.STRING_ARRAY) + // type does not equal finalized type. (combining factory type does equal finalized type, + // but this signature doesn't use combining factory) + .add("double_array_expr", null) + // type does not equal finalized type. (combining factory type does equal finalized type, + // but this signature doesn't use combining factory) + .add("long_array_expr", null) + // string because fold type equals finalized type, even though merge type is array + .add("string_array_expr_finalized", ValueType.STRING) + // type does not equal finalized type. (combining factory type does equal finalized type, + // but this signature doesn't use combining factory) + .add("double_array_expr_finalized", null) + // long because fold type equals finalized type, even though merge type is array + .add("long_array_expr_finalized", ValueType.LONG) + // fold type is string + .add("string-array-expr-access", ValueType.STRING) + // finalized type is string + .add("string-array-expr-finalize", ValueType.STRING) + // double because fold type is double + .add("double-array-expr-access", ValueType.DOUBLE) + // string because finalize type is string + .add("double-array-expr-finalize", ValueType.STRING) + // long because fold type is long + .add("long-array-expr-access", ValueType.LONG) + // finalized type is long + .add("long-array-expr-finalize", ValueType.LONG) + .build(), + new TimeseriesQueryQueryToolChest().resultArraySignature(query) + ); + } +} diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java index bc0f3b2ed7a..8d0c8648b0b 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java @@ -25,6 +25,7 @@ import com.google.common.base.Supplier; import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Ordering; @@ -67,6 +68,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.FloatSumAggregatorFactory; import org.apache.druid.query.aggregation.JavaScriptAggregatorFactory; @@ -11138,6 +11140,704 @@ public class GroupByQueryRunnerTest extends InitializedNullHandlingTest TestHelper.assertExpectedObjects(expectedResults, results, "groupBy"); } + @Test + public void testGroupByWithExpressionAggregator() + { + // expression agg not yet vectorized + cannotVectorize(); + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setDimensions(new DefaultDimensionSpec("quality", "alias")) + .setAggregatorSpecs( + new ExpressionLambdaAggregatorFactory( + "rows", + Collections.emptySet(), + null, + "0", + null, + "__acc + 1", + "__acc + rows", + null, + null, + null, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "idx", + ImmutableSet.of("index"), + null, + "0.0", + null, + "__acc + index", + null, + null, + null, + null, + TestExprMacroTable.INSTANCE + ) + ) + .setGranularity(QueryRunnerTestHelper.DAY_GRAN) + .build(); + + List expectedResults = Arrays.asList( + makeRow( + query, + "2011-04-01", + "alias", + "automotive", + "rows", + 1L, + "idx", + 135.88510131835938d + ), + makeRow( + query, + "2011-04-01", + "alias", + "business", + "rows", + 1L, + "idx", + 118.57034 + ), + makeRow( + query, + "2011-04-01", + "alias", + "entertainment", + "rows", + 1L, + "idx", + 158.747224 + ), + makeRow( + query, + "2011-04-01", + "alias", + "health", + "rows", + 1L, + "idx", + 120.134704 + ), + makeRow( + query, + "2011-04-01", + "alias", + "mezzanine", + "rows", + 3L, + "idx", + 2871.8866900000003d + ), + makeRow( + query, + "2011-04-01", + "alias", + "news", + "rows", + 1L, + "idx", + 121.58358d + ), + makeRow( + query, + "2011-04-01", + "alias", + "premium", + "rows", + 3L, + "idx", + 2900.798647d + ), + makeRow( + query, + "2011-04-01", + "alias", + "technology", + "rows", + 1L, + "idx", + 78.622547d + ), + makeRow( + query, + "2011-04-01", + "alias", + "travel", + "rows", + 1L, + "idx", + 119.922742d + ), + + makeRow( + query, + "2011-04-02", + "alias", + "automotive", + "rows", + 1L, + "idx", + 147.42593d + ), + makeRow( + query, + "2011-04-02", + "alias", + "business", + "rows", + 1L, + "idx", + 112.987027d + ), + makeRow( + query, + "2011-04-02", + "alias", + "entertainment", + "rows", + 1L, + "idx", + 166.016049d + ), + makeRow( + query, + "2011-04-02", + "alias", + "health", + "rows", + 1L, + "idx", + 113.446008d + ), + makeRow( + query, + "2011-04-02", + "alias", + "mezzanine", + "rows", + 3L, + "idx", + 2448.830613d + ), + makeRow( + query, + "2011-04-02", + "alias", + "news", + "rows", + 1L, + "idx", + 114.290141d + ), + makeRow( + query, + "2011-04-02", + "alias", + "premium", + "rows", + 3L, + "idx", + 2506.415148d + ), + makeRow( + query, + "2011-04-02", + "alias", + "technology", + "rows", + 1L, + "idx", + 97.387433d + ), + makeRow( + query, + "2011-04-02", + "alias", + "travel", + "rows", + 1L, + "idx", + 126.411364d + ) + ); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + TestHelper.assertExpectedObjects(expectedResults, results, "groupBy"); + } + + @Test + public void testGroupByWithExpressionAggregatorWithArrays() + { + // expression agg not yet vectorized + cannotVectorize(); + + // array types don't work with group by v1 + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("Unable to handle type[STRING_ARRAY] for AggregatorFactory[class org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory]"); + } + + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setDimensions(new DefaultDimensionSpec("quality", "alias")) + .setAggregatorSpecs( + new ExpressionLambdaAggregatorFactory( + "rows", + Collections.emptySet(), + null, + "0", + null, + "__acc + 1", + "__acc + rows", + null, + null, + null, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "idx", + ImmutableSet.of("index"), + null, + "0.0", + null, + "__acc + index", + null, + null, + null, + null, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "array_agg_distinct", + ImmutableSet.of(QueryRunnerTestHelper.MARKET_DIMENSION), + "acc", + "[]", + null, + "array_set_add(acc, market)", + "array_set_add_all(acc, array_agg_distinct)", + null, + null, + null, + TestExprMacroTable.INSTANCE + ) + ) + .setGranularity(QueryRunnerTestHelper.DAY_GRAN) + .build(); + + List expectedResults = Arrays.asList( + makeRow( + query, + "2011-04-01", + "alias", + "automotive", + "rows", + 1L, + "idx", + 135.88510131835938d, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "business", + "rows", + 1L, + "idx", + 118.57034, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "entertainment", + "rows", + 1L, + "idx", + 158.747224, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "health", + "rows", + 1L, + "idx", + 120.134704, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "mezzanine", + "rows", + 3L, + "idx", + 2871.8866900000003d, + "array_agg_distinct", + new String[] {"upfront", "spot", "total_market"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "news", + "rows", + 1L, + "idx", + 121.58358d, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "premium", + "rows", + 3L, + "idx", + 2900.798647d, + "array_agg_distinct", + new String[] {"upfront", "spot", "total_market"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "technology", + "rows", + 1L, + "idx", + 78.622547d, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "travel", + "rows", + 1L, + "idx", + 119.922742d, + "array_agg_distinct", + new String[] {"spot"} + ), + + makeRow( + query, + "2011-04-02", + "alias", + "automotive", + "rows", + 1L, + "idx", + 147.42593d, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "business", + "rows", + 1L, + "idx", + 112.987027d, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "entertainment", + "rows", + 1L, + "idx", + 166.016049d, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "health", + "rows", + 1L, + "idx", + 113.446008d, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "mezzanine", + "rows", + 3L, + "idx", + 2448.830613d, + "array_agg_distinct", + new String[] {"upfront", "spot", "total_market"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "news", + "rows", + 1L, + "idx", + 114.290141d, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "premium", + "rows", + 3L, + "idx", + 2506.415148d, + "array_agg_distinct", + new String[] {"upfront", "spot", "total_market"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "technology", + "rows", + 1L, + "idx", + 97.387433d, + "array_agg_distinct", + new String[] {"spot"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "travel", + "rows", + 1L, + "idx", + 126.411364d, + "array_agg_distinct", + new String[] {"spot"} + ) + ); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + TestHelper.assertExpectedObjects(expectedResults, results, "groupBy"); + } + + @Test + public void testGroupByExpressionAggregatorArrayMultiValue() + { + // expression agg not yet vectorized + cannotVectorize(); + + // array types don't work with group by v1 + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("Unable to handle type[STRING_ARRAY] for AggregatorFactory[class org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory]"); + } + + GroupByQuery query = makeQueryBuilder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setDimensions(new DefaultDimensionSpec("quality", "alias")) + .setAggregatorSpecs( + new ExpressionLambdaAggregatorFactory( + "array_agg_distinct", + ImmutableSet.of(QueryRunnerTestHelper.PLACEMENTISH_DIMENSION), + "acc", + "[]", + null, + "array_set_add(acc, placementish)", + "array_set_add_all(acc, array_agg_distinct)", + null, + null, + null, + TestExprMacroTable.INSTANCE + ) + ) + .setGranularity(QueryRunnerTestHelper.DAY_GRAN) + .build(); + + List expectedResults = Arrays.asList( + makeRow( + query, + "2011-04-01", + "alias", + "automotive", + "array_agg_distinct", + new String[] {"a", "preferred"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "business", + "array_agg_distinct", + new String[] {"b", "preferred"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "entertainment", + "array_agg_distinct", + new String[] {"e", "preferred"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "health", + "array_agg_distinct", + new String[] {"h", "preferred"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "mezzanine", + "array_agg_distinct", + new String[] {"m", "preferred"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "news", + "array_agg_distinct", + new String[] {"n", "preferred"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "premium", + "array_agg_distinct", + new String[] {"p", "preferred"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "technology", + "array_agg_distinct", + new String[] {"t", "preferred"} + ), + makeRow( + query, + "2011-04-01", + "alias", + "travel", + "array_agg_distinct", + new String[] {"t", "preferred"} + ), + + makeRow( + query, + "2011-04-02", + "alias", + "automotive", + "array_agg_distinct", + new String[] {"a", "preferred"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "business", + "array_agg_distinct", + new String[] {"b", "preferred"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "entertainment", + "array_agg_distinct", + new String[] {"e", "preferred"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "health", + "array_agg_distinct", + new String[] {"h", "preferred"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "mezzanine", + "array_agg_distinct", + new String[] {"m", "preferred"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "news", + "array_agg_distinct", + new String[] {"n", "preferred"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "premium", + "array_agg_distinct", + new String[] {"p", "preferred"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "technology", + "array_agg_distinct", + new String[] {"t", "preferred"} + ), + makeRow( + query, + "2011-04-02", + "alias", + "travel", + "array_agg_distinct", + new String[] {"t", "preferred"} + ) + ); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + TestHelper.assertExpectedObjects(expectedResults, results, "groupBy"); + } + private static ResultRow makeRow(final GroupByQuery query, final String timestamp, final Object... vals) { return GroupByQueryRunnerTestHelper.createExpectedRow(query, timestamp, vals); diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByTimeseriesQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByTimeseriesQueryRunnerTest.java index 8b80fc436ba..50810c00ac2 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByTimeseriesQueryRunnerTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByTimeseriesQueryRunnerTest.java @@ -303,32 +303,4 @@ public class GroupByTimeseriesQueryRunnerTest extends TimeseriesQueryRunnerTest // Skip this test because the timeseries test expects a day that doesn't have a filter match to be filled in, // but group by just doesn't return a value if the filter doesn't match. } - - @Override - public void testTimeseriesWithTimestampResultFieldContextForArrayResponse() - { - // Cannot vectorize with an expression virtual column - if (!vectorize) { - super.testTimeseriesWithTimestampResultFieldContextForArrayResponse(); - } - } - - @Override - public void testTimeseriesWithTimestampResultFieldContextForMapResponse() - { - // Cannot vectorize with an expression virtual column - if (!vectorize) { - super.testTimeseriesWithTimestampResultFieldContextForMapResponse(); - } - } - - @Override - @Test - public void testTimeseriesWithPostAggregatorReferencingTimestampResultField() - { - // Cannot vectorize with an expression virtual column - if (!vectorize) { - super.testTimeseriesWithPostAggregatorReferencingTimestampResultField(); - } - } } diff --git a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java index c1ae6d408c7..58b2a90596c 100644 --- a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java +++ b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java @@ -21,6 +21,7 @@ package org.apache.druid.query.timeseries; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.primitives.Doubles; @@ -44,6 +45,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory; @@ -2935,6 +2937,104 @@ public class TimeseriesQueryRunnerTest extends InitializedNullHandlingTest assertExpectedResults(expectedResults, results); } + @Test + public void testTimeseriesWithExpressionAggregator() + { + // expression agg cannot vectorize + cannotVectorize(); + TimeseriesQuery query = Druids.newTimeseriesQueryBuilder() + .dataSource(QueryRunnerTestHelper.DATA_SOURCE) + .granularity(QueryRunnerTestHelper.DAY_GRAN) + .intervals(QueryRunnerTestHelper.FIRST_TO_THIRD) + .aggregators( + Arrays.asList( + new ExpressionLambdaAggregatorFactory( + "diy_count", + ImmutableSet.of(), + null, + "0", + null, + "__acc + 1", + "__acc + diy_count", + null, + null, + null, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "diy_sum", + ImmutableSet.of("index"), + null, + "0.0", + null, + "__acc + index", + null, + null, + null, + null, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "diy_decomposed_sum", + ImmutableSet.of("index"), + null, + "0.0", + "[]", + "__acc + index", + "array_concat(__acc, diy_decomposed_sum)", + null, + "fold((x, acc) -> x + acc, o, 0.0)", + null, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "array_agg_distinct", + ImmutableSet.of(QueryRunnerTestHelper.MARKET_DIMENSION), + "acc", + "[]", + null, + "array_set_add(acc, market)", + "array_set_add_all(acc, array_agg_distinct)", + null, + null, + null, + TestExprMacroTable.INSTANCE + ) + ) + ) + .descending(descending) + .context(makeContext()) + .build(); + + List> expectedResults = Arrays.asList( + new Result<>( + DateTimes.of("2011-04-01"), + new TimeseriesResultValue( + ImmutableMap.of( + "diy_count", 13L, + "diy_sum", 6626.151569, + "diy_decomposed_sum", 6626.151569, + "array_agg_distinct", new String[] {"upfront", "spot", "total_market"} + ) + ) + ), + new Result<>( + DateTimes.of("2011-04-02"), + new TimeseriesResultValue( + ImmutableMap.of( + "diy_count", 13L, + "diy_sum", 5833.209718, + "diy_decomposed_sum", 5833.209718, + "array_agg_distinct", new String[] {"upfront", "spot", "total_market"} + ) + ) + ) + ); + + Iterable> results = runner.run(QueryPlus.wrap(query)).toList(); + assertExpectedResults(expectedResults, results); + } + private Map makeContext() { return makeContext(ImmutableMap.of()); diff --git a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java index f98160a1a39..eb8709d4822 100644 --- a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java +++ b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java @@ -22,6 +22,7 @@ package org.apache.druid.query.topn; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -52,6 +53,7 @@ import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory; import org.apache.druid.query.aggregation.FloatMinAggregatorFactory; @@ -5964,6 +5966,110 @@ public class TopNQueryRunnerTest extends InitializedNullHandlingTest assertExpectedResults(expectedResults, query); } + @Test + public void testExpressionAggregator() + { + // sorted by array length of array_agg_distinct + TopNQuery query = new TopNQueryBuilder() + .dataSource(QueryRunnerTestHelper.DATA_SOURCE) + .granularity(QueryRunnerTestHelper.ALL_GRAN) + .dimension(QueryRunnerTestHelper.MARKET_DIMENSION) + .metric("array_agg_distinct") + .threshold(4) + .intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC) + .aggregators( + Lists.newArrayList( + Arrays.asList( + new ExpressionLambdaAggregatorFactory( + "diy_count", + Collections.emptySet(), + null, + "0", + null, + "__acc + 1", + "__acc + diy_count", + null, + null, + null, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "diy_sum", + ImmutableSet.of("index"), + null, + "0.0", + null, + "__acc + index", + null, + null, + null, + null, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "diy_decomposed_sum", + ImmutableSet.of("index"), + null, + "0.0", + "[]", + "__acc + index", + "array_concat(__acc, diy_decomposed_sum)", + null, + "fold((x, acc) -> x + acc, o, 0.0)", + null, + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "array_agg_distinct", + ImmutableSet.of(QueryRunnerTestHelper.QUALITY_DIMENSION), + "acc", + "[]", + null, + "array_set_add(acc, quality)", + "array_set_add_all(acc, array_agg_distinct)", + "if(array_length(o1) > array_length(o2), 1, if (array_length(o1) == array_length(o2), 0, -1))", + null, + null, + TestExprMacroTable.INSTANCE + ) + ) + ) + ) + .build(); + + List> expectedResults = Collections.singletonList( + new Result<>( + DateTimes.of("2011-01-12T00:00:00.000Z"), + new TopNResultValue( + Arrays.>asList( + ImmutableMap.builder() + .put(QueryRunnerTestHelper.MARKET_DIMENSION, "spot") + .put("diy_count", 837L) + .put("diy_sum", 95606.57232284546D) + .put("diy_decomposed_sum", 95606.57232284546D) + .put("array_agg_distinct", new String[]{"mezzanine", "news", "premium", "business", "entertainment", "health", "technology", "automotive", "travel"}) + .build(), + ImmutableMap.builder() + .put(QueryRunnerTestHelper.MARKET_DIMENSION, "total_market") + .put("diy_count", 186L) + .put("diy_sum", 215679.82879638672D) + .put("diy_decomposed_sum", 215679.82879638672D) + .put("array_agg_distinct", new String[]{"mezzanine", "premium"}) + .build(), + ImmutableMap.builder() + .put(QueryRunnerTestHelper.MARKET_DIMENSION, "upfront") + .put("diy_count", 186L) + .put("diy_sum", 192046.1060180664D) + .put("diy_decomposed_sum", 192046.1060180664D) + .put("array_agg_distinct", new String[]{"mezzanine", "premium"}) + .build() + ) + ) + ) + ); + assertExpectedResults(expectedResults, query); + } + private static Map makeRowWithNulls( String dimName, @Nullable Object dimValue, diff --git a/processing/src/test/java/org/apache/druid/segment/TestHelper.java b/processing/src/test/java/org/apache/druid/segment/TestHelper.java index a2e409d9553..59f739aa8f0 100644 --- a/processing/src/test/java/org/apache/druid/segment/TestHelper.java +++ b/processing/src/test/java/org/apache/druid/segment/TestHelper.java @@ -32,6 +32,7 @@ import org.apache.druid.guice.GuiceAnnotationIntrospector; import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.Result; import org.apache.druid.query.expression.TestExprMacroTable; @@ -352,7 +353,9 @@ public class TestHelper final Object expectedValue = expectedMap.get(key); final Object actualValue = actualMap.get(key); - if (expectedValue instanceof Float || expectedValue instanceof Double) { + if (expectedValue != null && expectedValue.getClass().isArray()) { + Assert.assertArrayEquals((Object[]) expectedValue, (Object[]) actualValue); + } else if (expectedValue instanceof Float || expectedValue instanceof Double) { Assert.assertEquals( StringUtils.format("%s: key[%s]", msg, key), ((Number) expectedValue).doubleValue(), @@ -382,7 +385,23 @@ public class TestHelper final Object expectedValue = expected.get(i); final Object actualValue = actual.get(i); - if (expectedValue instanceof Float || expectedValue instanceof Double) { + + if (expectedValue != null && expectedValue.getClass().isArray()) { + // spilled results will materialize into lists, coerce them back to arrays if we expected arrays + if (actualValue instanceof List) { + Assert.assertEquals( + message, + (Object[]) expectedValue, + (Object[]) ExprEval.coerceListToArray((List) actualValue, true) + ); + } else { + Assert.assertArrayEquals( + message, + (Object[]) expectedValue, + (Object[]) actualValue + ); + } + } else if (expectedValue instanceof Float || expectedValue instanceof Double) { Assert.assertEquals( message, ((Number) expectedValue).doubleValue(), diff --git a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java index 64da13d8cc4..e1a9f27ded9 100644 --- a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java +++ b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionSelectorsTest.java @@ -243,77 +243,6 @@ public class ExpressionSelectorsTest extends InitializedNullHandlingTest } - @Test - public void test_coerceListToArray() - { - List longList = ImmutableList.of(1L, 2L, 3L); - Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExpressionSelectors.coerceListToArray(longList)); - - List intList = ImmutableList.of(1, 2, 3); - Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExpressionSelectors.coerceListToArray(intList)); - - List floatList = ImmutableList.of(1.0f, 2.0f, 3.0f); - Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExpressionSelectors.coerceListToArray(floatList)); - - List doubleList = ImmutableList.of(1.0, 2.0, 3.0); - Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExpressionSelectors.coerceListToArray(doubleList)); - - List stringList = ImmutableList.of("a", "b", "c"); - Assert.assertArrayEquals(new String[]{"a", "b", "c"}, (String[]) ExpressionSelectors.coerceListToArray(stringList)); - - List withNulls = new ArrayList<>(); - withNulls.add("a"); - withNulls.add(null); - withNulls.add("c"); - Assert.assertArrayEquals(new String[]{"a", null, "c"}, (String[]) ExpressionSelectors.coerceListToArray(withNulls)); - - List withNumberNulls = new ArrayList<>(); - withNumberNulls.add(1L); - withNumberNulls.add(null); - withNumberNulls.add(3L); - - Assert.assertArrayEquals(new Long[]{1L, null, 3L}, (Long[]) ExpressionSelectors.coerceListToArray(withNumberNulls)); - - List withStringMix = ImmutableList.of(1L, "b", 3L); - Assert.assertArrayEquals( - new String[]{"1", "b", "3"}, - (String[]) ExpressionSelectors.coerceListToArray(withStringMix) - ); - - List withIntsAndLongs = ImmutableList.of(1, 2L, 3); - Assert.assertArrayEquals( - new Long[]{1L, 2L, 3L}, - (Long[]) ExpressionSelectors.coerceListToArray(withIntsAndLongs) - ); - - List withFloatsAndLongs = ImmutableList.of(1, 2L, 3.0f); - Assert.assertArrayEquals( - new Double[]{1.0, 2.0, 3.0}, - (Double[]) ExpressionSelectors.coerceListToArray(withFloatsAndLongs) - ); - - List withDoublesAndLongs = ImmutableList.of(1, 2L, 3.0); - Assert.assertArrayEquals( - new Double[]{1.0, 2.0, 3.0}, - (Double[]) ExpressionSelectors.coerceListToArray(withDoublesAndLongs) - ); - - List withFloatsAndDoubles = ImmutableList.of(1L, 2.0f, 3.0); - Assert.assertArrayEquals( - new Double[]{1.0, 2.0, 3.0}, - (Double[]) ExpressionSelectors.coerceListToArray(withFloatsAndDoubles) - ); - - List withAllNulls = new ArrayList<>(); - withAllNulls.add(null); - withAllNulls.add(null); - withAllNulls.add(null); - Assert.assertArrayEquals( - new String[]{null, null, null}, - (String[]) ExpressionSelectors.coerceListToArray(withAllNulls) - ); - } - @Test public void test_coerceEvalToSelectorObject() { diff --git a/website/.spelling b/website/.spelling index 315e5ffc9aa..d717fbf9ecf 100644 --- a/website/.spelling +++ b/website/.spelling @@ -1100,6 +1100,8 @@ arr1 arr2 array_append array_concat +array_set_add +array_set_add_all array_contains array_length array_offset