complex typed expressions (#11853)

* complex typed expressions

* add built-in HLL collector expressions to get coverage on druid-processing, more types, better coverage

* rampage!!!

* more javadoc

* adjustments

* oops

* lol

* remove unused dependency

* contradiction?

* more tests
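For reference, a rough usage sketch of what this change enables, assuming the standard Parser entry points plus the nil bindings and typed array literals added in this patch (the class name here is hypothetical, for illustration only):

// Hedged sketch: parse and evaluate an explicitly typed array expression.
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;

public class TypedExpressionExample
{
  public static void main(String[] args)
  {
    // required outside of Guice-wired services, see ExpressionProcessing in the diff below
    ExpressionProcessing.initializeForTests(null);

    // the new ARRAY_TYPE grammar rule allows explicitly typed array literals
    Expr expr = Parser.parse("array_concat(ARRAY<LONG>[1, 2, null], ARRAY<LONG>[3])", ExprMacroTable.nil());
    ExprEval<?> eval = expr.eval(InputBindings.nilBindings());
    // eval.type() is ARRAY<LONG>; eval.asArray() is {1L, 2L, null, 3L}
  }
}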
Clint Wylie 2021-11-08 00:33:06 -08:00 committed by GitHub
parent 8e7e679984
commit 7237dc837c
96 changed files with 4589 additions and 1590 deletions


@ -187,10 +187,6 @@
<groupId>it.unimi.dsi</groupId>
<artifactId>fastutil-core</artifactId>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
<artifactId>fastutil-extra</artifactId>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-buffer</artifactId>


@ -36,6 +36,7 @@ expr : NULL # null
| '<LONG>' '[' (numericElement (',' numericElement)*)? ']' # explicitLongArray
| '<DOUBLE>'? '[' (numericElement (',' numericElement)*)? ']' # doubleArray
| '<STRING>' '[' (literalElement (',' literalElement)*)? ']' # explicitStringArray
| ARRAY_TYPE '[' (literalElement (',' literalElement)*)? ']' # explicitArray
;
lambda : (IDENTIFIER | '(' ')' | '(' IDENTIFIER (',' IDENTIFIER)* ')') '->' expr
@ -52,6 +53,8 @@ numericElement : (LONG | DOUBLE | NULL);
literalElement : (STRING | LONG | DOUBLE | NULL);
ARRAY_TYPE : 'ARRAY<' ( 'LONG' | 'DOUBLE' | 'STRING' | ('COMPLEX<' IDENTIFIER '>')| ARRAY_TYPE ) '>';
NULL : 'null';
LONG : [0-9]+;
EXP: [eE] [-]? LONG;
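With the new ARRAY_TYPE token, literals such as ARRAY<LONG>[1, 2, null], ARRAY<DOUBLE>[1.0, null, 2.0], and ARRAY<STRING>['foo', null] parse directly to typed array constants; the COMPLEX<...> and nested ARRAY<...> forms are accepted by the lexer, but element parsing is still limited to long, double, and string values (see exitExplicitArray further down).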


@ -25,7 +25,6 @@ import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
@ -138,54 +137,16 @@ public interface ApplyFunction
/**
* Evaluate {@link LambdaExpr} against every index position of an {@link IndexableMapLambdaObjectBinding}
*/
ExprEval applyMap(LambdaExpr expr, IndexableMapLambdaObjectBinding bindings)
ExprEval applyMap(@Nullable ExpressionType arrayType, LambdaExpr expr, IndexableMapLambdaObjectBinding bindings)
{
final int length = bindings.getLength();
String[] stringsOut = null;
Long[] longsOut = null;
Double[] doublesOut = null;
ExpressionType elementType = null;
Object[] out = new Object[length];
for (int i = 0; i < length; i++) {
ExprEval evaluated = expr.eval(bindings.withIndex(i));
if (elementType == null) {
elementType = evaluated.type();
switch (elementType.getType()) {
case STRING:
stringsOut = new String[length];
break;
case LONG:
longsOut = new Long[length];
break;
case DOUBLE:
doublesOut = new Double[length];
break;
default:
throw new RE("Unhandled map function output type [%s]", elementType);
}
}
Function.ArrayConstructorFunction.setArrayOutputElement(
stringsOut,
longsOut,
doublesOut,
elementType,
i,
evaluated
);
}
switch (elementType.getType()) {
case STRING:
return ExprEval.ofStringArray(stringsOut);
case LONG:
return ExprEval.ofLongArray(longsOut);
case DOUBLE:
return ExprEval.ofDoubleArray(doublesOut);
default:
throw new RE("Unhandled map function output type [%s]", elementType);
arrayType = Function.ArrayConstructorFunction.setArrayOutput(arrayType, out, i, evaluated);
}
return ExprEval.ofArray(arrayType, out);
}
}
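With the array type threaded through, an expression like map((x) -> x + 1, ARRAY<LONG>[1, 2, 3]) now produces an ARRAY<LONG> result built via setArrayOutput, using the lambda's inferred output type when it is available instead of switching on the type of the first evaluated element.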
@ -216,8 +177,9 @@ public interface ApplyFunction
return arrayEval;
}
MapLambdaBinding lambdaBinding = new MapLambdaBinding(array, lambdaExpr, bindings);
return applyMap(lambdaExpr, lambdaBinding);
MapLambdaBinding lambdaBinding = new MapLambdaBinding(arrayEval.elementType(), array, lambdaExpr, bindings);
ExpressionType lambdaType = lambdaExpr.getOutputType(lambdaBinding);
return applyMap(lambdaType == null ? null : ExpressionTypeFactory.getInstance().ofArray(lambdaType), lambdaExpr, lambdaBinding);
}
@Override
@ -261,6 +223,7 @@ public interface ApplyFunction
List<List<Object>> arrayInputs = new ArrayList<>();
boolean hadNull = false;
boolean hadEmpty = false;
ExpressionType elementType = null;
for (Expr expr : argsExpr) {
ExprEval arrayEval = expr.eval(bindings);
Object[] array = arrayEval.asArray();
@ -268,6 +231,7 @@ public interface ApplyFunction
hadNull = true;
continue;
}
elementType = arrayEval.elementType();
if (array.length == 0) {
hadEmpty = true;
continue;
@ -282,8 +246,9 @@ public interface ApplyFunction
}
List<List<Object>> product = CartesianList.create(arrayInputs);
CartesianMapLambdaBinding lambdaBinding = new CartesianMapLambdaBinding(product, lambdaExpr, bindings);
return applyMap(lambdaExpr, lambdaBinding);
CartesianMapLambdaBinding lambdaBinding = new CartesianMapLambdaBinding(elementType, product, lambdaExpr, bindings);
ExpressionType lambdaType = lambdaExpr.getOutputType(lambdaBinding);
return applyMap(ExpressionType.asArrayType(lambdaType), lambdaExpr, lambdaBinding);
}
@Override
@ -324,7 +289,7 @@ public interface ApplyFunction
if (accumulator instanceof Boolean) {
return ExprEval.ofLongBoolean((boolean) accumulator);
}
return ExprEval.bestEffortOf(accumulator);
return ExprEval.ofType(bindings.getAccumulatorType(), accumulator);
}
@Override
@ -372,7 +337,14 @@ public interface ApplyFunction
}
Object accumulator = accEval.value();
FoldLambdaBinding lambdaBinding = new FoldLambdaBinding(array, accumulator, lambdaExpr, bindings);
FoldLambdaBinding lambdaBinding = new FoldLambdaBinding(
arrayEval.elementType(),
array,
accEval.type(),
accumulator,
lambdaExpr,
bindings
);
return applyFold(lambdaExpr, accumulator, lambdaBinding);
}
@ -415,6 +387,7 @@ public interface ApplyFunction
List<List<Object>> arrayInputs = new ArrayList<>();
boolean hadNull = false;
boolean hadEmpty = false;
ExpressionType arrayElementType = null;
for (int i = 0; i < argsExpr.size() - 1; i++) {
Expr expr = argsExpr.get(i);
ExprEval arrayEval = expr.eval(bindings);
@ -423,6 +396,7 @@ public interface ApplyFunction
hadNull = true;
continue;
}
arrayElementType = arrayEval.elementType();
if (array.length == 0) {
hadEmpty = true;
continue;
@ -444,7 +418,7 @@ public interface ApplyFunction
Object accumulator = accEval.value();
CartesianFoldLambdaBinding lambdaBindings =
new CartesianFoldLambdaBinding(product, accumulator, lambdaExpr, bindings);
new CartesianFoldLambdaBinding(arrayElementType, product, accEval.type(), accumulator, lambdaExpr, bindings);
return applyFold(lambdaExpr, accumulator, lambdaBindings);
}
@ -495,23 +469,9 @@ public interface ApplyFunction
return ExprEval.of(null);
}
SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(lambdaExpr, bindings);
switch (arrayEval.elementType().getType()) {
case STRING:
String[] filteredString =
this.filter(arrayEval.asStringArray(), lambdaExpr, lambdaBinding).toArray(String[]::new);
return ExprEval.ofStringArray(filteredString);
case LONG:
Long[] filteredLong =
this.filter(arrayEval.asLongArray(), lambdaExpr, lambdaBinding).toArray(Long[]::new);
return ExprEval.ofLongArray(filteredLong);
case DOUBLE:
Double[] filteredDouble =
this.filter(arrayEval.asDoubleArray(), lambdaExpr, lambdaBinding).toArray(Double[]::new);
return ExprEval.ofDoubleArray(filteredDouble);
default:
throw new RE("Unhandled filter function input type [%s]", arrayEval.type());
}
SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(arrayEval.elementType(), lambdaExpr, bindings);
Object[] filtered = filter(arrayEval.asArray(), lambdaExpr, lambdaBinding).toArray();
return ExprEval.ofArray(arrayEval.asArrayType(), filtered);
}
@Override
@ -565,7 +525,7 @@ public interface ApplyFunction
return ExprEval.ofLongBoolean(false);
}
SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(lambdaExpr, bindings);
SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(arrayEval.elementType(), lambdaExpr, bindings);
return match(array, lambdaExpr, lambdaBinding);
}
@ -654,14 +614,16 @@ public interface ApplyFunction
{
private final Expr.ObjectBinding bindings;
private final Map<String, Object> lambdaBindings;
private final ExpressionType elementType;
SettableLambdaBinding(LambdaExpr expr, Expr.ObjectBinding bindings)
SettableLambdaBinding(ExpressionType elementType, LambdaExpr expr, Expr.ObjectBinding bindings)
{
this.elementType = elementType;
this.lambdaBindings = new HashMap<>();
for (String lambdaIdentifier : expr.getIdentifiers()) {
lambdaBindings.put(lambdaIdentifier, null);
}
this.bindings = bindings != null ? bindings : Collections.emptyMap()::get;
this.bindings = bindings != null ? bindings : InputBindings.nilBindings();
}
@Nullable
@ -679,6 +641,16 @@ public interface ApplyFunction
this.lambdaBindings.put(key, value);
return this;
}
@Nullable
@Override
public ExpressionType getType(String name)
{
if (lambdaBindings.containsKey(name)) {
return elementType;
}
return bindings.getType(name);
}
}
/**
@ -707,17 +679,19 @@ public interface ApplyFunction
class MapLambdaBinding implements IndexableMapLambdaObjectBinding
{
private final Expr.ObjectBinding bindings;
private final ExpressionType arrayElementType;
@Nullable
private final String lambdaIdentifier;
private final Object[] arrayValues;
private int index = 0;
private final boolean scoped;
MapLambdaBinding(Object[] arrayValues, LambdaExpr expr, Expr.ObjectBinding bindings)
MapLambdaBinding(ExpressionType elementType, Object[] arrayValues, LambdaExpr expr, Expr.ObjectBinding bindings)
{
this.lambdaIdentifier = expr.getIdentifier();
this.arrayElementType = elementType;
this.arrayValues = arrayValues;
this.bindings = bindings != null ? bindings : Collections.emptyMap()::get;
this.bindings = bindings != null ? bindings : InputBindings.nilBindings();
this.scoped = lambdaIdentifier != null;
}
@ -743,6 +717,16 @@ public interface ApplyFunction
this.index = index;
return this;
}
@Nullable
@Override
public ExpressionType getType(String name)
{
if (scoped && name.equals(lambdaIdentifier)) {
return arrayElementType;
}
return bindings.getType(name);
}
}
/**
@ -753,14 +737,16 @@ public interface ApplyFunction
class CartesianMapLambdaBinding implements IndexableMapLambdaObjectBinding
{
private final Expr.ObjectBinding bindings;
private final ExpressionType arrayElementType;
private final Object2IntMap<String> lambdaIdentifiers;
private final List<List<Object>> lambdaInputs;
private final boolean scoped;
private int index = 0;
CartesianMapLambdaBinding(List<List<Object>> inputs, LambdaExpr expr, Expr.ObjectBinding bindings)
CartesianMapLambdaBinding(ExpressionType arrayElementType, List<List<Object>> inputs, LambdaExpr expr, Expr.ObjectBinding bindings)
{
this.lambdaInputs = inputs;
this.arrayElementType = arrayElementType;
List<String> ids = expr.getIdentifiers();
this.scoped = ids.size() > 0;
this.lambdaIdentifiers = new Object2IntArrayMap<>(ids.size());
@ -768,7 +754,7 @@ public interface ApplyFunction
lambdaIdentifiers.put(ids.get(i), i);
}
this.bindings = bindings != null ? bindings : Collections.emptyMap()::get;
this.bindings = bindings != null ? bindings : InputBindings.nilBindings();
}
@Nullable
@ -793,6 +779,16 @@ public interface ApplyFunction
this.index = index;
return this;
}
@Nullable
@Override
public ExpressionType getType(String name)
{
if (scoped && lambdaIdentifiers.containsKey(name)) {
return arrayElementType;
}
return bindings.getType(name);
}
}
/**
@ -803,6 +799,8 @@ public interface ApplyFunction
*/
interface IndexableFoldLambdaBinding extends Expr.ObjectBinding
{
ExpressionType getAccumulatorType();
/**
* Total number of bindings in this binding
*/
@ -821,20 +819,31 @@ public interface ApplyFunction
class FoldLambdaBinding implements IndexableFoldLambdaBinding
{
private final Expr.ObjectBinding bindings;
private final ExpressionType arrayElementType;
private final ExpressionType accumulatorType;
private final String elementIdentifier;
private final Object[] arrayValues;
private final String accumulatorIdentifier;
private Object accumulatorValue;
private int index;
FoldLambdaBinding(Object[] arrayValues, Object initialAccumulator, LambdaExpr expr, Expr.ObjectBinding bindings)
FoldLambdaBinding(
ExpressionType arrayElementType,
Object[] arrayValues,
ExpressionType accumulatorType,
Object initialAccumulator,
LambdaExpr expr,
Expr.ObjectBinding bindings
)
{
List<String> ids = expr.getIdentifiers();
this.elementIdentifier = ids.get(0);
this.arrayElementType = arrayElementType;
this.accumulatorType = accumulatorType;
this.accumulatorIdentifier = ids.get(1);
this.arrayValues = arrayValues;
this.accumulatorValue = initialAccumulator;
this.bindings = bindings != null ? bindings : Collections.emptyMap()::get;
this.bindings = bindings != null ? bindings : InputBindings.nilBindings();
}
@Nullable
@ -849,6 +858,12 @@ public interface ApplyFunction
return bindings.get(name);
}
@Override
public ExpressionType getAccumulatorType()
{
return accumulatorType;
}
@Override
public int getLength()
{
@ -862,6 +877,18 @@ public interface ApplyFunction
this.accumulatorValue = acc;
return this;
}
@Nullable
@Override
public ExpressionType getType(String name)
{
if (name.equals(elementIdentifier)) {
return arrayElementType;
} else if (name.equals(accumulatorIdentifier)) {
return accumulatorType;
}
return bindings.getType(name);
}
}
/**
@ -871,14 +898,25 @@ public interface ApplyFunction
class CartesianFoldLambdaBinding implements IndexableFoldLambdaBinding
{
private final Expr.ObjectBinding bindings;
private final ExpressionType arrayElementType;
private final ExpressionType accumulatorType;
private final Object2IntMap<String> lambdaIdentifiers;
private final List<List<Object>> lambdaInputs;
private final String accumulatorIdentifier;
private Object accumulatorValue;
private int index = 0;
CartesianFoldLambdaBinding(List<List<Object>> inputs, Object accumulatorValue, LambdaExpr expr, Expr.ObjectBinding bindings)
CartesianFoldLambdaBinding(
@Nullable ExpressionType arrayElementType,
List<List<Object>> inputs,
ExpressionType accumulatorType,
Object accumulatorValue,
LambdaExpr expr,
Expr.ObjectBinding bindings
)
{
this.arrayElementType = arrayElementType;
this.accumulatorType = accumulatorType;
this.lambdaInputs = inputs;
List<String> ids = expr.getIdentifiers();
this.lambdaIdentifiers = new Object2IntArrayMap<>(ids.size());
@ -886,7 +924,7 @@ public interface ApplyFunction
lambdaIdentifiers.put(ids.get(i), i);
}
this.accumulatorIdentifier = ids.get(ids.size() - 1);
this.bindings = bindings != null ? bindings : Collections.emptyMap()::get;
this.bindings = bindings != null ? bindings : InputBindings.nilBindings();
this.accumulatorValue = accumulatorValue;
}
@ -902,6 +940,12 @@ public interface ApplyFunction
return bindings.get(name);
}
@Override
public ExpressionType getAccumulatorType()
{
return accumulatorType;
}
@Override
public int getLength()
{
@ -915,6 +959,18 @@ public interface ApplyFunction
this.accumulatorValue = acc;
return this;
}
@Nullable
@Override
public ExpressionType getType(String name)
{
if (lambdaIdentifiers.containsKey(name)) {
return arrayElementType;
} else if (accumulatorIdentifier.equals(name)) {
return accumulatorType;
}
return bindings.getType(name);
}
}
/**


@ -22,9 +22,12 @@ package org.apache.druid.math.expr;
import com.google.common.base.Preconditions;
import org.apache.commons.lang.StringEscapeUtils;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.math.expr.vector.VectorProcessors;
import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.column.Types;
import javax.annotation.Nullable;
import java.util.Arrays;
@ -82,7 +85,7 @@ abstract class ConstantExpr<T> implements Expr
@Override
public BindingAnalysis analyzeInputs()
{
return new BindingAnalysis();
return BindingAnalysis.EMTPY;
}
@Override
@ -181,60 +184,6 @@ class NullLongExpr extends ConstantExpr<Long>
}
}
class LongArrayExpr extends ConstantExpr<Long[]>
{
LongArrayExpr(@Nullable Long[] value)
{
super(ExpressionType.LONG_ARRAY, value);
}
@Override
public String toString()
{
return Arrays.toString(value);
}
@Override
public ExprEval eval(ObjectBinding bindings)
{
return ExprEval.ofLongArray(value);
}
@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return false;
}
@Override
public String stringify()
{
if (value.length == 0) {
return "<LONG>[]";
}
return StringUtils.format("<LONG>%s", toString());
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
LongArrayExpr that = (LongArrayExpr) o;
return Arrays.equals(value, that.value);
}
@Override
public int hashCode()
{
return Arrays.hashCode(value);
}
}
class DoubleExpr extends ConstantExpr<Double>
{
DoubleExpr(Double value)
@ -318,60 +267,6 @@ class NullDoubleExpr extends ConstantExpr<Double>
}
}
class DoubleArrayExpr extends ConstantExpr<Double[]>
{
DoubleArrayExpr(@Nullable Double[] value)
{
super(ExpressionType.DOUBLE_ARRAY, value);
}
@Override
public String toString()
{
return Arrays.toString(value);
}
@Override
public ExprEval eval(ObjectBinding bindings)
{
return ExprEval.ofDoubleArray(value);
}
@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return false;
}
@Override
public String stringify()
{
if (value.length == 0) {
return "<DOUBLE>[]";
}
return StringUtils.format("<DOUBLE>%s", toString());
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
DoubleArrayExpr that = (DoubleArrayExpr) o;
return Arrays.equals(value, that.value);
}
@Override
public int hashCode()
{
return Arrays.hashCode(value);
}
}
class StringExpr extends ConstantExpr<String>
{
StringExpr(@Nullable String value)
@ -424,23 +319,19 @@ class StringExpr extends ConstantExpr<String>
}
}
class StringArrayExpr extends ConstantExpr<String[]>
class ArrayExpr extends ConstantExpr<Object[]>
{
StringArrayExpr(@Nullable String[] value)
public ArrayExpr(ExpressionType outputType, @Nullable Object[] value)
{
super(ExpressionType.STRING_ARRAY, value);
}
@Override
public String toString()
{
return Arrays.toString(value);
super(outputType, value);
Preconditions.checkArgument(outputType.isArray());
ExpressionType.checkNestedArrayAllowed(outputType);
}
@Override
public ExprEval eval(ObjectBinding bindings)
{
return ExprEval.ofStringArray(value);
return ExprEval.ofArray(outputType, value);
}
@Override
@ -452,21 +343,105 @@ class StringArrayExpr extends ConstantExpr<String[]>
@Override
public String stringify()
{
if (value.length == 0) {
return "<STRING>[]";
if (value == null) {
return NULL_LITERAL;
}
if (value.length == 0) {
return outputType.asTypeString() + "[]";
}
if (outputType.getElementType().is(ExprType.STRING)) {
return StringUtils.format(
"%s[%s]",
outputType.asTypeString(),
ARG_JOINER.join(
Arrays.stream(value)
.map(s -> s == null
? NULL_LITERAL
// escape as javascript string since string literals are wrapped in single quotes
: StringUtils.format("'%s'", StringEscapeUtils.escapeJavaScript((String) s))
)
.iterator()
)
);
} else if (outputType.getElementType().isNumeric()) {
return outputType.asTypeString() + Arrays.toString(value);
} else if (outputType.getElementType().is(ExprType.COMPLEX)) {
Object[] stringified = new Object[value.length];
for (int i = 0; i < value.length; i++) {
stringified[i] = new ComplexExpr((ExpressionType) outputType.getElementType(), value[i]).stringify();
}
// use array function to rebuild since we can't stringify complex types directly
return StringUtils.format("array(%s)", Arrays.toString(stringified));
} else if (outputType.getElementType().isArray()) {
// use array function to rebuild since the parser can't yet recognize nested arrays e.g. [['foo', 'bar'],['baz']]
Object[] stringified = new Object[value.length];
for (int i = 0; i < value.length; i++) {
stringified[i] = new ArrayExpr((ExpressionType) outputType.getElementType(), (Object[]) value[i]).stringify();
}
return StringUtils.format("array(%s)", Arrays.toString(stringified));
}
throw new IAE("cannot stringify array type %s", outputType);
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ArrayExpr that = (ArrayExpr) o;
return outputType.equals(that.outputType) && Arrays.equals(value, that.value);
}
@Override
public int hashCode()
{
return Objects.hash(outputType, Arrays.hashCode(value));
}
@Override
public String toString()
{
return Arrays.toString(value);
}
}
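So, for example, string and numeric array constants now stringify to literals like ARRAY<STRING>['foo', null] and ARRAY<LONG>[1, 2] that the grammar above can re-parse, while arrays of complex or nested array elements are rebuilt through the array(...) function, since the parser cannot yet express those as literals.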
class ComplexExpr extends ConstantExpr<Object>
{
protected ComplexExpr(ExpressionType outputType, @Nullable Object value)
{
super(outputType, value);
}
@Override
public ExprEval eval(ObjectBinding bindings)
{
return ExprEval.ofComplex(outputType, value);
}
@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return false;
}
@Override
public String stringify()
{
if (value == null) {
return StringUtils.format("complex_decode_base64('%s', %s)", outputType.getComplexTypeName(), NULL_LITERAL);
}
ObjectByteStrategy strategy = Types.getStrategy(outputType.getComplexTypeName());
if (strategy == null) {
throw new IAE("Cannot stringify type[%s]", outputType.asTypeString());
}
return StringUtils.format(
"<STRING>[%s]",
ARG_JOINER.join(
Arrays.stream(value)
.map(s -> s == null
? NULL_LITERAL
// escape as javascript string since string literals are wrapped in single quotes
: StringUtils.format("'%s'", StringEscapeUtils.escapeJavaScript(s))
)
.iterator()
)
"complex_decode_base64('%s', '%s')",
outputType.getComplexTypeName(),
StringUtils.encodeBase64String(strategy.toBytes(value))
);
}
@ -479,13 +454,13 @@ class StringArrayExpr extends ConstantExpr<String[]>
if (o == null || getClass() != o.getClass()) {
return false;
}
StringArrayExpr that = (StringArrayExpr) o;
return Arrays.equals(value, that.value);
ComplexExpr that = (ComplexExpr) o;
return outputType.equals(that.outputType) && Objects.equals(value, that.value);
}
@Override
public int hashCode()
{
return Arrays.hashCode(value);
return Objects.hash(outputType, value);
}
}


@ -284,7 +284,7 @@ public interface Expr extends Cacheable
/**
* Mechanism to supply values to back {@link IdentifierExpr} during expression evaluation
*/
interface ObjectBinding
interface ObjectBinding extends InputBindingInspector
{
/**
* Get value binding for string identifier of {@link IdentifierExpr}
@ -364,13 +364,15 @@ public interface Expr extends Cacheable
@SuppressWarnings("JavadocReference")
class BindingAnalysis
{
public static final BindingAnalysis EMTPY = new BindingAnalysis();
private final ImmutableSet<IdentifierExpr> freeVariables;
private final ImmutableSet<IdentifierExpr> scalarVariables;
private final ImmutableSet<IdentifierExpr> arrayVariables;
private final boolean hasInputArrays;
private final boolean isOutputArray;
BindingAnalysis()
public BindingAnalysis()
{
this(ImmutableSet.of(), ImmutableSet.of(), ImmutableSet.of(), false, false);
}

File diff suppressed because it is too large


@ -390,7 +390,7 @@ public class ExprListenerImpl extends ExprBaseListener
@Override
public void exitDoubleArray(ExprParser.DoubleArrayContext ctx)
{
Double[] values = new Double[ctx.numericElement().size()];
Object[] values = new Object[ctx.numericElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.numericElement(i).NULL() != null) {
values[i] = null;
@ -402,13 +402,51 @@ public class ExprListenerImpl extends ExprBaseListener
throw new RE("Failed to parse array element %s as a double", ctx.numericElement(i).getText());
}
}
nodes.put(ctx, new DoubleArrayExpr(values));
nodes.put(ctx, new ArrayExpr(ExpressionType.DOUBLE_ARRAY, values));
}
@Override
public void exitExplicitArray(ExprParser.ExplicitArrayContext ctx)
{
ExpressionType type = ExpressionType.fromString(ctx.ARRAY_TYPE().getText());
if (type == null) {
throw new RE("Failed to convert array type %s to expression type", ctx.ARRAY_TYPE().getText());
}
Object[] values = new Object[ctx.literalElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.literalElement(i).NULL() != null) {
values[i] = null;
} else {
final ExprParser.LiteralElementContext elementContext = ctx.literalElement(i);
// if value is a string, escape quoting
final String toParse;
if (elementContext.STRING() != null) {
toParse = escapeStringLiteral(elementContext.STRING().getText());
} else {
toParse = elementContext.getText();
}
switch (type.getElementType().getType()) {
case LONG:
values[i] = Numbers.parseLongObject(toParse);
break;
case DOUBLE:
values[i] = Numbers.parseDoubleObject(toParse);
break;
case STRING:
values[i] = toParse;
break;
default:
throw new RE("Failed to parse array element %s as a %s", toParse, type.getElementType().asTypeString());
}
}
}
nodes.put(ctx, new ArrayExpr(type, values));
}
@Override
public void exitLongArray(ExprParser.LongArrayContext ctx)
{
Long[] values = new Long[ctx.longElement().size()];
Object[] values = new Object[ctx.longElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.longElement(i).NULL() != null) {
values[i] = null;
@ -418,13 +456,13 @@ public class ExprListenerImpl extends ExprBaseListener
throw new RE("Failed to parse array element %s as a long", ctx.longElement(i).getText());
}
}
nodes.put(ctx, new LongArrayExpr(values));
nodes.put(ctx, new ArrayExpr(ExpressionType.LONG_ARRAY, values));
}
@Override
public void exitExplicitLongArray(ExprParser.ExplicitLongArrayContext ctx)
{
Long[] values = new Long[ctx.numericElement().size()];
Object[] values = new Object[ctx.numericElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.numericElement(i).NULL() != null) {
values[i] = null;
@ -436,13 +474,13 @@ public class ExprListenerImpl extends ExprBaseListener
throw new RE("Failed to parse array element %s as a long", ctx.numericElement(i).getText());
}
}
nodes.put(ctx, new LongArrayExpr(values));
nodes.put(ctx, new ArrayExpr(ExpressionType.LONG_ARRAY, values));
}
@Override
public void exitStringArray(ExprParser.StringArrayContext ctx)
{
String[] values = new String[ctx.stringElement().size()];
Object[] values = new Object[ctx.stringElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.stringElement(i).NULL() != null) {
values[i] = null;
@ -452,13 +490,13 @@ public class ExprListenerImpl extends ExprBaseListener
throw new RE("Failed to parse array: element %s is not a string", ctx.stringElement(i).getText());
}
}
nodes.put(ctx, new StringArrayExpr(values));
nodes.put(ctx, new ArrayExpr(ExpressionType.STRING_ARRAY, values));
}
@Override
public void exitExplicitStringArray(ExprParser.ExplicitStringArrayContext ctx)
{
String[] values = new String[ctx.literalElement().size()];
Object[] values = new Object[ctx.literalElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.literalElement(i).NULL() != null) {
values[i] = null;
@ -472,7 +510,7 @@ public class ExprListenerImpl extends ExprBaseListener
throw new RE("Failed to parse array element %s as a string", ctx.literalElement(i).getText());
}
}
nodes.put(ctx, new StringArrayExpr(values));
nodes.put(ctx, new ArrayExpr(ExpressionType.STRING_ARRAY, values));
}
/**


@ -93,10 +93,18 @@ public class ExprMacroTable
Expr apply(List<Expr> args);
}
/**
* stub interface to allow {@link Parser#flatten(Expr)} to recognize macro functions that extend this
*/
public interface ExprMacroFunctionExpr extends Expr
{
List<Expr> getArgs();
}
/**
* Base class for single argument {@link ExprMacro} function {@link Expr}
*/
public abstract static class BaseScalarUnivariateMacroFunctionExpr implements Expr
public abstract static class BaseScalarUnivariateMacroFunctionExpr implements ExprMacroFunctionExpr
{
protected final String name;
protected final Expr arg;
@ -111,6 +119,12 @@ public class ExprMacroTable
analyzeInputsSupplier = Suppliers.memoize(this::supplyAnalyzeInputs);
}
@Override
public List<Expr> getArgs()
{
return Collections.singletonList(arg);
}
@Override
public BindingAnalysis analyzeInputs()
{
@ -147,12 +161,19 @@ public class ExprMacroTable
{
return arg.analyzeInputs().withScalarArguments(ImmutableSet.of(arg));
}
@Override
public String toString()
{
return StringUtils.format("(%s %s)", name, getArgs());
}
}
/**
* Base class for multi-argument {@link ExprMacro} function {@link Expr}
*/
public abstract static class BaseScalarMacroFunctionExpr implements Expr
public abstract static class BaseScalarMacroFunctionExpr implements ExprMacroFunctionExpr
{
protected final String name;
protected final List<Expr> args;
@ -167,6 +188,12 @@ public class ExprMacroTable
analyzeInputsSupplier = Suppliers.memoize(this::supplyAnalyzeInputs);
}
@Override
public List<Expr> getArgs()
{
return args;
}
@Override
public String stringify()
{
@ -213,5 +240,11 @@ public class ExprMacroTable
}
return accumulator.withScalarArguments(argSet);
}
@Override
public String toString()
{
return StringUtils.format("(%s %s)", name, getArgs());
}
}
}


@ -19,8 +19,6 @@
package org.apache.druid.math.expr;
import it.unimi.dsi.fastutil.bytes.Byte2ObjectArrayMap;
import it.unimi.dsi.fastutil.bytes.Byte2ObjectMap;
import org.apache.druid.segment.column.TypeDescriptor;
/**
@ -28,31 +26,11 @@ import org.apache.druid.segment.column.TypeDescriptor;
*/
public enum ExprType implements TypeDescriptor
{
DOUBLE((byte) 0x01),
LONG((byte) 0x02),
STRING((byte) 0x03),
ARRAY((byte) 0x04),
COMPLEX((byte) 0x05);
private static final Byte2ObjectMap<ExprType> TYPE_BYTES = new Byte2ObjectArrayMap<>(ExprType.values().length);
static {
for (ExprType type : ExprType.values()) {
TYPE_BYTES.put(type.getId(), type);
}
}
final byte id;
ExprType(byte id)
{
this.id = id;
}
public byte getId()
{
return id;
}
DOUBLE,
LONG,
STRING,
ARRAY,
COMPLEX;
@Override
public boolean isNumeric()
@ -71,9 +49,4 @@ public enum ExprType implements TypeDescriptor
{
return this == ExprType.ARRAY;
}
public static ExprType fromByte(byte id)
{
return TYPE_BYTES.get(id);
}
}


@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.math.expr;
import com.google.common.annotations.VisibleForTesting;
import com.google.inject.Inject;
import javax.annotation.Nullable;
/**
* Like {@link org.apache.druid.common.config.NullHandling}, except for expressions processing configs
*/
public class ExpressionProcessing
{
/**
* INSTANCE is injected using static injection to avoid adding JacksonInject annotations all over the code;
* see {@link ExpressionProcessingModule} for details.
*
* It is not injected in most unit tests, since they do not install the Guice modules; use {@link #initializeForTests}
* when modules are not available.
*/
@Inject
private static ExpressionProcessingConfig INSTANCE;
/**
* Many unit tests do not set up the modules needed for this value to be injected, so this method provides a manual
* way to initialize {@link #INSTANCE}.
*
* @param allowNestedArrays whether nested arrays are allowed; null falls back to the system property default
*/
@VisibleForTesting
public static void initializeForTests(@Nullable Boolean allowNestedArrays)
{
INSTANCE = new ExpressionProcessingConfig(allowNestedArrays);
}
/**
* Whether nested arrays (e.g. ARRAY<ARRAY<LONG>>) are allowed to be created by expression processing.
*/
public static boolean allowNestedArrays()
{
// this should only be null in a unit test context
// in production this will be injected by the expression processing module
if (INSTANCE == null) {
throw new IllegalStateException(
"Expressions module not initialized, call ExpressionProcessing.initializeForTests()"
);
}
return INSTANCE.allowNestedArrays();
}
}
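In practice this means test code exercising nested arrays has to seed the static instance itself. A minimal sketch, assuming JUnit 4 (the test class name is made up):

import org.apache.druid.math.expr.ExpressionProcessing;
import org.junit.BeforeClass;

public class NestedArrayExpressionTest
{
  @BeforeClass
  public static void setUpStatic()
  {
    // no Guice modules in plain unit tests, so initialize the static config manually;
    // passing true opts in to ARRAY<ARRAY<...>> creation, null uses the system property default
    ExpressionProcessing.initializeForTests(true);
  }
}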


@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.math.expr;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import javax.annotation.Nullable;
public class ExpressionProcessingConfig
{
public static final String NESTED_ARRAYS_CONFIG_STRING = "druid.expressions.allowNestedArrays";
@JsonProperty("allowNestedArrays")
private final boolean allowNestedArrays;
@JsonCreator
public ExpressionProcessingConfig(@JsonProperty("allowNestedArrays") @Nullable Boolean allowNestedArrays)
{
this.allowNestedArrays = allowNestedArrays == null
? Boolean.valueOf(System.getProperty(NESTED_ARRAYS_CONFIG_STRING, "false"))
: allowNestedArrays;
}
public boolean allowNestedArrays()
{
return allowNestedArrays;
}
}
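Operationally this is wired up through the druid.expressions.* properties (see ExpressionProcessingModule below), so setting druid.expressions.allowNestedArrays=true in runtime.properties, or the equivalent system property that the constructor falls back to, is what turns nested array creation on; it defaults to false.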


@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.math.expr;
import com.google.inject.Binder;
import com.google.inject.Module;
import org.apache.druid.guice.JsonConfigProvider;
public class ExpressionProcessingModule implements Module
{
@Override
public void configure(Binder binder)
{
JsonConfigProvider.bind(binder, "druid.expressions", ExpressionProcessingConfig.class);
binder.requestStaticInjection(ExpressionProcessing.class);
}
}


@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.column.BaseTypeSignature;
import org.apache.druid.segment.column.ColumnType;
@ -51,6 +52,8 @@ public class ExpressionType extends BaseTypeSignature<ExprType>
new ExpressionType(ExprType.ARRAY, null, LONG);
public static final ExpressionType DOUBLE_ARRAY =
new ExpressionType(ExprType.ARRAY, null, DOUBLE);
public static final ExpressionType UNKNOWN_COMPLEX =
new ExpressionType(ExprType.COMPLEX, null, null);
@JsonCreator
public ExpressionType(
@ -205,4 +208,11 @@ public class ExpressionType extends BaseTypeSignature<ExprType>
throw new ISE("Unsupported expression type[%s]", exprType);
}
}
public static void checkNestedArrayAllowed(ExpressionType outputType)
{
if (outputType.isArray() && outputType.getElementType().isArray() && !ExpressionProcessing.allowNestedArrays()) {
throw new IAE("Cannot create a nested array type [%s], 'druid.expressions.allowNestedArrays' must be set to true", outputType);
}
}
}


@ -24,7 +24,6 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.math.expr.vector.CastToTypeVectorProcessor;
@ -32,6 +31,8 @@ import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.math.expr.vector.VectorMathProcessors;
import org.apache.druid.math.expr.vector.VectorProcessors;
import org.apache.druid.math.expr.vector.VectorStringProcessors;
import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.column.Types;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.joda.time.format.DateTimeFormat;
@ -39,6 +40,7 @@ import org.joda.time.format.DateTimeFormat;
import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -501,26 +503,12 @@ public interface Function
@Override
ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
{
switch (arrayExpr.elementType().getType()) {
case STRING:
return ExprEval.ofStringArray(add(arrayExpr.asStringArray(), scalarExpr.asString()).toArray(String[]::new));
case LONG:
return ExprEval.ofLongArray(
add(
arrayExpr.asLongArray(),
scalarExpr.isNumericNull() ? null : scalarExpr.asLong()
).toArray(Long[]::new)
);
case DOUBLE:
return ExprEval.ofDoubleArray(
add(
arrayExpr.asDoubleArray(),
scalarExpr.isNumericNull() ? null : scalarExpr.asDouble()
).toArray(Double[]::new)
);
if (!scalarExpr.type().equals(arrayExpr.elementType())) {
// try to cast
ExprEval coerced = scalarExpr.castTo(arrayExpr.elementType());
return ExprEval.ofArray(arrayExpr.asArrayType(), add(arrayExpr.asArray(), coerced.value()).toArray());
}
throw new RE("Unable to add to unknown array type %s", arrayExpr.type());
return ExprEval.ofArray(arrayExpr.asArrayType(), add(arrayExpr.asArray(), scalarExpr.value()).toArray());
}
abstract <T> Stream<T> add(T[] array, @Nullable T val);
@ -564,21 +552,13 @@ public interface Function
return lhsExpr;
}
switch (lhsExpr.elementType().getType()) {
case STRING:
return ExprEval.ofStringArray(
merge(lhsExpr.asStringArray(), rhsExpr.asStringArray()).toArray(String[]::new)
);
case LONG:
return ExprEval.ofLongArray(
merge(lhsExpr.asLongArray(), rhsExpr.asLongArray()).toArray(Long[]::new)
);
case DOUBLE:
return ExprEval.ofDoubleArray(
merge(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray()).toArray(Double[]::new)
);
if (!lhsExpr.asArrayType().equals(rhsExpr.asArrayType())) {
// try to cast if the types don't match
ExprEval coerced = rhsExpr.castTo(lhsExpr.asArrayType());
return ExprEval.ofArray(lhsExpr.asArrayType(), merge(lhsExpr.asArray(), coerced.asArray()).toArray());
}
throw new RE("Unable to concatenate to unknown type %s", lhsExpr.type());
return ExprEval.ofArray(lhsExpr.asArrayType(), merge(lhsExpr.asArray(), rhsExpr.asArray()).toArray());
}
abstract <T> Stream<T> merge(T[] array1, T[] array2);
@ -2925,71 +2905,17 @@ public interface Function
// this is copied from 'BaseMapFunction.applyMap', need to find a better way to consolidate, or construct arrays,
// or.. something...
final int length = args.size();
String[] stringsOut = null;
Long[] longsOut = null;
Double[] doublesOut = null;
Object[] out = new Object[length];
ExpressionType elementType = null;
ExpressionType arrayType = null;
for (int i = 0; i < length; i++) {
ExprEval<?> evaluated = args.get(i).eval(bindings);
if (elementType == null) {
elementType = evaluated.type();
switch (elementType.getType()) {
case STRING:
stringsOut = new String[length];
break;
case LONG:
longsOut = new Long[length];
break;
case DOUBLE:
doublesOut = new Double[length];
break;
default:
throw new RE("Unhandled array constructor element type [%s]", elementType);
}
}
setArrayOutputElement(stringsOut, longsOut, doublesOut, elementType, i, evaluated);
arrayType = setArrayOutput(arrayType, out, i, evaluated);
}
// There should be always at least one argument and thus elementType is never null.
// See validateArguments().
//noinspection ConstantConditions
switch (elementType.getType()) {
case STRING:
return ExprEval.ofStringArray(stringsOut);
case LONG:
return ExprEval.ofLongArray(longsOut);
case DOUBLE:
return ExprEval.ofDoubleArray(doublesOut);
default:
throw new RE("Unhandled array constructor element type [%s]", elementType);
}
return ExprEval.ofArray(arrayType, out);
}
static void setArrayOutputElement(
String[] stringsOut,
Long[] longsOut,
Double[] doublesOut,
ExpressionType elementType,
int i,
ExprEval evaluated
)
{
switch (elementType.getType()) {
case STRING:
stringsOut[i] = evaluated.asString();
break;
case LONG:
longsOut[i] = evaluated.isNumericNull() ? null : evaluated.asLong();
break;
case DOUBLE:
doublesOut[i] = evaluated.isNumericNull() ? null : evaluated.asDouble();
break;
}
}
@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
@ -3026,6 +2952,29 @@ public interface Function
}
return ExpressionType.asArrayType(type);
}
/**
* Set an element of the output array, writing null if the array is numeric and the evaluated value is a numeric null.
* If the type of the evaluated element does not match the array element type, this method attempts to
* {@link ExprEval#castTo} the array element type; otherwise the element is set as is. If the array type is unknown,
* it is detected and defined from the first element. Returns the array type, which is identical to the input
* type unless the input type was null.
*/
static ExpressionType setArrayOutput(@Nullable ExpressionType arrayType, Object[] out, int i, ExprEval evaluated)
{
if (arrayType == null) {
arrayType = ExpressionTypeFactory.getInstance().ofArray(evaluated.type());
}
ExpressionType.checkNestedArrayAllowed(arrayType);
if (arrayType.getElementType().isNumeric() && evaluated.isNumericNull()) {
out[i] = null;
} else if (!evaluated.asArrayType().equals(arrayType)) {
out[i] = evaluated.castTo((ExpressionType) arrayType.getElementType()).value();
} else {
out[i] = evaluated.value();
}
return arrayType;
}
}
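As a concrete example of the coercion described above: array(1, 2.5) fixes the output type to ARRAY<LONG> from the first element, so the 2.5 is cast and stored as 2, while numeric nulls are written as null elements rather than going through the cast.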
class ArrayLengthFunction implements Function
@ -3186,7 +3135,7 @@ public interface Function
final int position = scalarExpr.asInt();
if (array.length > position) {
return ExprEval.bestEffortOf(array[position]);
return ExprEval.ofType(arrayExpr.elementType(), array[position]);
}
return ExprEval.of(null);
}
@ -3214,7 +3163,7 @@ public interface Function
final int position = scalarExpr.asInt() - 1;
if (array.length > position) {
return ExprEval.bestEffortOf(array[position]);
return ExprEval.ofType(arrayExpr.elementType(), array[position]);
}
return ExprEval.of(null);
}
@ -3521,15 +3470,7 @@ public interface Function
return ExprEval.of(null);
}
switch (expr.elementType().getType()) {
case STRING:
return ExprEval.ofStringArray(Arrays.copyOfRange(expr.asStringArray(), start, end));
case LONG:
return ExprEval.ofLongArray(Arrays.copyOfRange(expr.asLongArray(), start, end));
case DOUBLE:
return ExprEval.ofDoubleArray(Arrays.copyOfRange(expr.asDoubleArray(), start, end));
}
throw new RE("Unable to slice to unknown type %s", expr.type());
return ExprEval.ofArray(expr.asArrayType(), Arrays.copyOfRange(expr.asArray(), start, end));
}
}
@ -3631,4 +3572,78 @@ public interface Function
return HumanReadableBytes.UnitSystem.DECIMAL;
}
}
class ComplexDecodeBase64Function implements Function
{
@Override
public String name()
{
return "complex_decode_base64";
}
@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
ExprEval arg0 = args.get(0).eval(bindings);
if (!arg0.type().is(ExprType.STRING)) {
throw new IAE(
"Function[%s] first argument must be constant 'STRING' expression containing a valid complex type name",
name()
);
}
ExpressionType complexType = ExpressionTypeFactory.getInstance().ofComplex((String) args.get(0).getLiteralValue());
ObjectByteStrategy strategy = Types.getStrategy(complexType.getComplexTypeName());
if (strategy == null) {
throw new IAE(
"Function[%s] first argument must be a valid complex type name, unknown complex type [%s]",
name(),
complexType.asTypeString()
);
}
ExprEval base64String = args.get(1).eval(bindings);
if (!base64String.type().is(ExprType.STRING)) {
throw new IAE(
"Function[%s] second argument must be a base64 encoded 'STRING' value",
name()
);
}
if (base64String.value() == null) {
return ExprEval.ofComplex(complexType, null);
}
final byte[] base64 = StringUtils.decodeBase64String(base64String.asString());
return ExprEval.ofComplex(complexType, strategy.fromByteBuffer(ByteBuffer.wrap(base64), base64.length));
}
@Override
public void validateArguments(List<Expr> args)
{
if (args.size() != 2) {
throw new IAE("Function[%s] needs 2 arguments", name());
}
if (!args.get(0).isLiteral() || args.get(0).isNullLiteral()) {
throw new IAE(
"Function[%s] first argument must be constant 'STRING' expression containing a valid complex type name",
name()
);
}
}
@Nullable
@Override
public ExpressionType getOutputType(
Expr.InputBindingInspector inspector,
List<Expr> args
)
{
ExpressionType arg0Type = args.get(0).getOutputType(inspector);
if (arg0Type == null || !arg0Type.is(ExprType.STRING)) {
throw new IAE(
"Function[%s] first argument must be constant 'STRING' expression containing a valid complex type name",
name()
);
}
return ExpressionTypeFactory.getInstance().ofComplex((String) args.get(0).getLiteralValue());
}
}
}
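Together with ComplexExpr.stringify above, this gives complex constants a round trip: an expression like complex_decode_base64('hyperUnique', '<base64 blob>') re-creates the value through whatever ObjectByteStrategy is registered for that type name ('hyperUnique' is used here only as an illustration of a complex type name; decoding works only for names with a registered strategy).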


@ -122,7 +122,7 @@ class IdentifierExpr implements Expr
@Override
public ExprEval eval(ObjectBinding bindings)
{
return ExprEval.bestEffortOf(bindings.get(binding));
return ExprEval.ofType(bindings.getType(binding), bindings.get(binding));
}
@Override


@ -20,12 +20,36 @@
package org.apache.druid.math.expr;
import com.google.common.base.Supplier;
import org.apache.druid.java.util.common.Pair;
import javax.annotation.Nullable;
import java.util.Map;
import java.util.function.Function;
public class InputBindings
{
private static final Expr.ObjectBinding NIL_BINDINGS = new Expr.ObjectBinding()
{
@Nullable
@Override
public Object get(String name)
{
return null;
}
@Nullable
@Override
public ExpressionType getType(String name)
{
return null;
}
};
public static Expr.ObjectBinding nilBindings()
{
return NIL_BINDINGS;
}
/**
* Create an {@link Expr.InputBindingInspector} backed by a map of binding identifiers to their {@link ExprType}
*/
@ -42,23 +66,95 @@ public class InputBindings
};
}
public static Expr.ObjectBinding singleProvider(ExpressionType type, final Function<String, ?> valueFn)
{
return new Expr.ObjectBinding()
{
@Nullable
@Override
public Object get(String name)
{
return valueFn.apply(name);
}
@Nullable
@Override
public ExpressionType getType(String name)
{
return type;
}
};
}
public static Expr.ObjectBinding forFunction(final Function<String, ?> valueFn)
{
return new Expr.ObjectBinding()
{
@Nullable
@Override
public Object get(String name)
{
return valueFn.apply(name);
}
@Nullable
@Override
public ExpressionType getType(String name)
{
return ExprEval.bestEffortOf(valueFn.apply(name)).type();
}
};
}
/**
* Create {@link Expr.ObjectBinding} backed by {@link Map} to provide values for identifiers to evaluate {@link Expr}
*/
public static Expr.ObjectBinding withMap(final Map<String, ?> bindings)
{
return bindings::get;
return new Expr.ObjectBinding()
{
@Nullable
@Override
public Object get(String name)
{
return bindings.get(name);
}
@Nullable
@Override
public ExpressionType getType(String name)
{
return ExprEval.bestEffortOf(bindings.get(name)).type();
}
};
}
/**
* Create {@link Expr.ObjectBinding} backed by map of {@link Supplier} to provide values for identifiers to evaluate
* {@link Expr}
*/
public static Expr.ObjectBinding withSuppliers(final Map<String, Supplier<Object>> bindings)
public static Expr.ObjectBinding withTypedSuppliers(final Map<String, Pair<ExpressionType, Supplier<Object>>> bindings)
{
return (String name) -> {
Supplier<Object> supplier = bindings.get(name);
return supplier == null ? null : supplier.get();
return new Expr.ObjectBinding()
{
@Nullable
@Override
public Object get(String name)
{
Pair<ExpressionType, Supplier<Object>> binding = bindings.get(name);
return binding == null || binding.rhs == null ? null : binding.rhs.get();
}
@Nullable
@Override
public ExpressionType getType(String name)
{
Pair<ExpressionType, Supplier<Object>> binding = bindings.get(name);
if (binding == null) {
return null;
}
return binding.lhs;
}
};
}
}
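A quick sketch of the difference between the untyped and typed binding factories (the class and variable names are illustrative, and the imports assume Guava plus the druid-core expression package shown here):

import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;

import java.util.Map;

public class BindingExamples
{
  public static void main(String[] args)
  {
    // untyped: getType() is derived per lookup via ExprEval.bestEffortOf on the bound value
    Expr.ObjectBinding untyped = InputBindings.withMap(ImmutableMap.of("x", 2L));

    // typed: the declared ExpressionType is reported without inspecting the value
    Supplier<Object> xValue = () -> 2L;
    Map<String, Pair<ExpressionType, Supplier<Object>>> typedBindings =
        ImmutableMap.of("x", new Pair<>(ExpressionType.LONG, xValue));
    Expr.ObjectBinding typed = InputBindings.withTypedSuppliers(typedBindings);
  }
}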


@ -153,28 +153,33 @@ public class Parser
if (childExpr instanceof BinaryOpExprBase) {
BinaryOpExprBase binary = (BinaryOpExprBase) childExpr;
if (Evals.isAllConstants(binary.left, binary.right)) {
return childExpr.eval(null).toExpr();
return childExpr.eval(InputBindings.nilBindings()).toExpr();
}
} else if (childExpr instanceof UnaryExpr) {
UnaryExpr unary = (UnaryExpr) childExpr;
if (unary.expr instanceof ConstantExpr) {
return childExpr.eval(null).toExpr();
return childExpr.eval(InputBindings.nilBindings()).toExpr();
}
} else if (childExpr instanceof FunctionExpr) {
FunctionExpr functionExpr = (FunctionExpr) childExpr;
List<Expr> args = functionExpr.args;
if (Evals.isAllConstants(args)) {
return childExpr.eval(null).toExpr();
return childExpr.eval(InputBindings.nilBindings()).toExpr();
}
} else if (childExpr instanceof ApplyFunctionExpr) {
ApplyFunctionExpr applyFunctionExpr = (ApplyFunctionExpr) childExpr;
List<Expr> args = applyFunctionExpr.argsExpr;
if (Evals.isAllConstants(args)) {
if (applyFunctionExpr.analyzeInputs().getFreeVariables().size() == 0) {
return childExpr.eval(null).toExpr();
return childExpr.eval(InputBindings.nilBindings()).toExpr();
}
}
} else if (childExpr instanceof ExprMacroTable.ExprMacroFunctionExpr) {
ExprMacroTable.ExprMacroFunctionExpr macroFn = (ExprMacroTable.ExprMacroFunctionExpr) childExpr;
if (Evals.isAllConstants(macroFn.getArgs())) {
return childExpr.eval(InputBindings.nilBindings()).toExpr();
}
}
return childExpr;
});


@ -32,6 +32,7 @@ import java.util.Map;
public class SettableObjectBinding implements Expr.ObjectBinding
{
private final Map<String, Object> bindings;
private Expr.InputBindingInspector inspector = InputBindings.nilBindings();
public SettableObjectBinding()
{
@ -50,12 +51,25 @@ public class SettableObjectBinding implements Expr.ObjectBinding
return bindings.get(name);
}
@Nullable
@Override
public ExpressionType getType(String name)
{
return inspector.getType(name);
}
public SettableObjectBinding withBinding(String name, @Nullable Object value)
{
bindings.put(name, value);
return this;
}
public SettableObjectBinding withInspector(Expr.InputBindingInspector inspector)
{
this.inspector = inspector;
return this;
}
@VisibleForTesting
public Map<String, Object> asMap()
{


@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.segment.column;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.Comparator;
/**
* Naming is hard. This is the core interface extracted from another interface called ObjectStrategy that lives in
* 'druid-processing'. It provides basic methods for converting some type of object to a binary form, reading
* the binary form back into an object from a {@link ByteBuffer}, and a mechanism to perform comparisons between objects.
*
* Complex types register one of these in {@link Types#registerStrategy}; the strategy can then be retrieved by
* complex type name to convert values to and from binary format and to compare them.
*
* This could be recombined with 'ObjectStrategy' should these two modules be combined.
*/
public interface ObjectByteStrategy<T> extends Comparator<T>
{
Class<? extends T> getClazz();
/**
* Convert values from their underlying byte representation.
*
* Implementations of this method <i>may</i> change the given buffer's mark, or limit, and position.
*
* Implementations of this method <i>may not</i> store the given buffer in a field of the "deserialized" object;
* they need to use {@link ByteBuffer#slice()}, {@link ByteBuffer#asReadOnlyBuffer()} or {@link ByteBuffer#duplicate()}
* in this case.
*
* @param buffer buffer to read value from
* @param numBytes number of bytes used to store the value, starting at buffer.position()
* @return an object created from the given byte buffer representation
*/
@Nullable
T fromByteBuffer(ByteBuffer buffer, int numBytes);
@Nullable
byte[] toBytes(@Nullable T val);
}
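A minimal sketch of what implementing and registering one of these might look like, for a hypothetical complex type that just wraps a String (the type name, class, and encoding are invented for illustration; real strategies live in druid-processing):

import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.column.Types;

import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class ExampleStringWrapperStrategy implements ObjectByteStrategy<String>
{
  @Override
  public Class<? extends String> getClazz()
  {
    return String.class;
  }

  @Nullable
  @Override
  public String fromByteBuffer(ByteBuffer buffer, int numBytes)
  {
    // copy the value bytes out without retaining a reference to the supplied buffer
    final byte[] bytes = new byte[numBytes];
    buffer.duplicate().get(bytes, 0, numBytes);
    return new String(bytes, StandardCharsets.UTF_8);
  }

  @Nullable
  @Override
  public byte[] toBytes(@Nullable String val)
  {
    return val == null ? null : val.getBytes(StandardCharsets.UTF_8);
  }

  @Override
  public int compare(String o1, String o2)
  {
    return o1.compareTo(o2);
  }

  public static void register()
  {
    // makes the (made up) 'exampleString' type name usable with complex typed expressions
    Types.registerStrategy("exampleString", new ExampleStringWrapperStrategy());
  }
}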


@ -20,14 +20,23 @@
package org.apache.druid.segment.column;
import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.concurrent.ConcurrentHashMap;
public class Types
{
private static final String ARRAY_PREFIX = "ARRAY<";
private static final String COMPLEX_PREFIX = "COMPLEX<";
private static final int VALUE_OFFSET = Byte.BYTES;
private static final int NULLABLE_LONG_SIZE = Byte.BYTES + Long.BYTES;
private static final int NULLABLE_DOUBLE_SIZE = Byte.BYTES + Double.BYTES;
private static final int NULLABLE_FLOAT_SIZE = Byte.BYTES + Float.BYTES;
private static final ConcurrentHashMap<String, ObjectByteStrategy<?>> STRATEGIES = new ConcurrentHashMap<>();
/**
* Create a {@link TypeSignature} given the value of {@link TypeSignature#asTypeString()} and a {@link TypeFactory}
@ -112,4 +121,610 @@ public class Types
return (typeSignature1 != null && typeSignature1.is(typeDescriptor)) ||
(typeSignature2 != null && typeSignature2.is(typeDescriptor));
}
/**
* Get an {@link ObjectByteStrategy} registered to some {@link TypeSignature#getComplexTypeName()}.
*/
@Nullable
public static ObjectByteStrategy<?> getStrategy(String type)
{
return STRATEGIES.get(type);
}
/**
* hmm... this might look familiar... (see ComplexMetrics)
*
* Register a complex type name -> {@link ObjectByteStrategy} mapping.
*
* If the specified type name is already registered and the supplied {@link ObjectByteStrategy} is not of the
* same class as the existing value in the map for that key, an {@link ISE} is thrown.
*
* @param typeName the complex type name to associate with the strategy
* @param strategy The {@link ObjectByteStrategy} object to be associated with the 'type' in the map.
*/
public static void registerStrategy(String typeName, ObjectByteStrategy<?> strategy)
{
Preconditions.checkNotNull(typeName);
STRATEGIES.compute(typeName, (key, value) -> {
if (value == null) {
return strategy;
} else {
if (!value.getClass().getName().equals(strategy.getClass().getName())) {
throw new ISE(
"Incompatible strategy for type[%s] already exists. Expected [%s], found [%s].",
key,
strategy.getClass().getName(),
value.getClass().getName()
);
} else {
return value;
}
}
});
}
/**
* Clear and set the 'null' byte of a nullable value stored in a {@link ByteBuffer} at the supplied offset to
* {@link NullHandling#IS_NULL_BYTE}. This method does not change the buffer position, limit, or mark, because it does not expect
* to own the buffer given to it (i.e. buffer aggs)
*
* Nullable types are stored with a leading byte to indicate if the value is null, followed by the value bytes
* (if not null)
*
* layout: | null (byte) | value |
*
* @return number of bytes written (always 1)
*/
public static int writeNull(ByteBuffer buffer, int offset)
{
buffer.put(offset, NullHandling.IS_NULL_BYTE);
return 1;
}
/**
* Checks if a 'nullable' value's null byte is set to {@link NullHandling#IS_NULL_BYTE}. This method will mask the
* value of the null byte to only check if the null bit is set, meaning that the higher bits can be utilized for
* flags as necessary (e.g. using high bits to indicate if the value has been set or not for aggregators).
*
* Note that writing nullable values with the methods of {@link Types} will always clear and set the null byte to
* either {@link NullHandling#IS_NULL_BYTE} or {@link NullHandling#IS_NOT_NULL_BYTE}, losing any flag bits.
*
* layout: | null (byte) | value |
*/
public static boolean isNullableNull(ByteBuffer buffer, int offset)
{
// use & so that callers can use the high bits of the null byte to pack additional information if necessary
return (buffer.get(offset) & NullHandling.IS_NULL_BYTE) == NullHandling.IS_NULL_BYTE;
}
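
A small sketch of the flag-bit behavior described above; the 'updated' flag here is hypothetical:

import java.nio.ByteBuffer;
import org.apache.druid.segment.column.Types;

// illustration only: pack a hypothetical flag into the high bits of the null byte
class NullByteFlagSketch
{
  private static final byte UPDATED_FLAG = (byte) 0x80;   // high bit, free for callers per the javadoc above

  static void demo()
  {
    final ByteBuffer buffer = ByteBuffer.allocate(16);
    final int offset = 0;
    Types.writeNull(buffer, offset);                                  // byte is now IS_NULL_BYTE
    buffer.put(offset, (byte) (buffer.get(offset) | UPDATED_FLAG));   // caller sets its own flag bit
    // isNullableNull masks with IS_NULL_BYTE, so the extra flag bit does not change the answer
    final boolean isNull = Types.isNullableNull(buffer, offset);      // still true
  }
}
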
/**
* Write a non-null long value to a {@link ByteBuffer} at the supplied offset. The first byte is always cleared and
* set to {@link NullHandling#IS_NOT_NULL_BYTE}, and the long value is written in the next 8 bytes.
*
* layout: | null (byte) | long |
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*
* @return number of bytes written (always 9)
*/
public static int writeNullableLong(ByteBuffer buffer, int offset, long value)
{
buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
buffer.putLong(offset, value);
return NULLABLE_LONG_SIZE;
}
/**
* Reads a non-null long value from a {@link ByteBuffer} at the supplied offset. This method should be called only
* if {@link #isNullableNull} for the same offset returns false.
*
* layout: | null (byte) | long |
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*/
public static long readNullableLong(ByteBuffer buffer, int offset)
{
assert !isNullableNull(buffer, offset);
return buffer.getLong(offset + VALUE_OFFSET);
}
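
A minimal round-trip sketch using these methods (values and offsets are illustrative; doubles and floats follow the same pattern):

import java.nio.ByteBuffer;
import org.apache.druid.segment.column.Types;

class NullableLongRoundTripSketch
{
  static void demo()
  {
    final ByteBuffer buffer = ByteBuffer.allocate(64);
    final int offset = 10;                                           // any offset owned by the caller
    final int written = Types.writeNullableLong(buffer, offset, 1234L);   // 9 bytes: | not-null | long |
    final long value = Types.isNullableNull(buffer, offset)
                       ? -1L                                         // placeholder for the null case
                       : Types.readNullableLong(buffer, offset);     // 1234L
  }
}
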
/**
* Write a non-null double value to a {@link ByteBuffer} at the supplied offset. The first byte is always cleared and
* set to {@link NullHandling#IS_NOT_NULL_BYTE}, and the double value is written in the next 8 bytes.
*
* layout: | null (byte) | double |
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*
* @return number of bytes written (always 9)
*/
public static int writeNullableDouble(ByteBuffer buffer, int offset, double value)
{
buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
buffer.putDouble(offset, value);
return NULLABLE_DOUBLE_SIZE;
}
/**
* Reads a non-null double value from a {@link ByteBuffer} at the supplied offset. This method should be called only
* if {@link #isNullableNull} for the same offset returns false.
*
* layout: | null (byte) | double |
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*/
public static double readNullableDouble(ByteBuffer buffer, int offset)
{
assert !isNullableNull(buffer, offset);
return buffer.getDouble(offset + VALUE_OFFSET);
}
/**
* Write a non-null float value to a {@link ByteBuffer} at the supplied offset. The first byte is always cleared and
* set to {@link NullHandling#IS_NOT_NULL_BYTE}, and the float value is written in the next 4 bytes.
*
* layout: | null (byte) | float |
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*
* @return number of bytes written (always 5)
*/
public static int writeNullableFloat(ByteBuffer buffer, int offset, float value)
{
buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
buffer.putFloat(offset, value);
return NULLABLE_FLOAT_SIZE;
}
/**
* Reads a non-null float value from a {@link ByteBuffer} at the supplied offset. This method should be called only
* if {@link #isNullableNull} for the same offset returns false.
*
* layout: | null (byte) | float |
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*/
public static float readNullableFloat(ByteBuffer buffer, int offset)
{
assert !isNullableNull(buffer, offset);
return buffer.getFloat(offset + VALUE_OFFSET);
}
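
A sketch showing how the returned byte counts can be used to pack values back to back (offsets here are illustrative):

import java.nio.ByteBuffer;
import org.apache.druid.segment.column.Types;

class NullableNumericPackingSketch
{
  static void demo()
  {
    final ByteBuffer buffer = ByteBuffer.allocate(64);
    int offset = 0;
    offset += Types.writeNullableDouble(buffer, offset, 1.25);   // advances by 9
    offset += Types.writeNullableFloat(buffer, offset, 2.5f);    // advances by 5
    final double d = Types.readNullableDouble(buffer, 0);        // 1.25
    final float f = Types.readNullableFloat(buffer, 9);          // 2.5f, stored right after the double
  }
}
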
/**
* Write a variable-length byte[] value to a {@link ByteBuffer} at the supplied offset. The first byte is set to
* {@link NullHandling#IS_NULL_BYTE} or {@link NullHandling#IS_NOT_NULL_BYTE} as appropriate, and if the byte[] value
* is not null, the size in bytes is written as an integer in the next 4 bytes, followed by the byte[] value itself.
*
* layout: | null (byte) | size (int) | byte[] |
*
* This method checks that no more than the specified maximum number of bytes can be written to the buffer, and the
* proper function of this method requires that the buffer contains at least that many bytes free from the starting
* offset. See {@link #writeNullableVariableBlob(ByteBuffer, int, byte[])} if you do not need to check the length
* of the byte array, or wish to perform the check externally.
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*
* @return number of bytes written (1 if null, or 5 + size of byte[] if not)
*/
public static int writeNullableVariableBlob(
ByteBuffer buffer,
int offset,
@Nullable byte[] value,
TypeSignature<?> type,
int maxSizeBytes
)
{
if (value == null) {
return writeNull(buffer, offset);
}
// | null (byte) | length (int) | bytes |
checkMaxBytes(
type,
1 + Integer.BYTES + value.length,
maxSizeBytes
);
return writeNullableVariableBlob(buffer, offset, value);
}
/**
* Write a variable-length byte[] value to a {@link ByteBuffer} at the supplied offset. The first byte is set to
* {@link NullHandling#IS_NULL_BYTE} or {@link NullHandling#IS_NOT_NULL_BYTE} as appropriate, and if the byte[] value
* is not null, the size in bytes is written as an integer in the next 4 bytes, followed by the byte[] value itself.
*
* layout: | null (byte) | size (int) | byte[] |
*
* This method does not constrain the number of bytes written to the buffer, so either use
* {@link #writeNullableVariableBlob(ByteBuffer, int, byte[], TypeSignature, int)} or first check that the size
* of the byte array plus 5 bytes is available in the buffer before using this method.
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*
* @return number of bytes written (1 if null, or 5 + size of byte[] if not)
*/
public static int writeNullableVariableBlob(ByteBuffer buffer, int offset, @Nullable byte[] value)
{
// | null (byte) | length (int) | bytes |
final int size;
if (value == null) {
return writeNull(buffer, offset);
}
final int oldPosition = buffer.position();
buffer.position(offset);
buffer.put(NullHandling.IS_NOT_NULL_BYTE);
buffer.putInt(value.length);
buffer.put(value, 0, value.length);
size = buffer.position() - offset;
buffer.position(oldPosition);
return size;
}
/**
* Reads a variable-length byte[] value from a {@link ByteBuffer} at the supplied offset. This method does not itself
* check the null byte, so callers should first confirm that {@link #isNullableNull} returns false; it reads the next
* 4 bytes after the null byte to get the byte[] size, followed by that many bytes to extract the value.
*
* layout: | null (byte) | size (int) | byte[] |
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*/
@Nullable
public static byte[] readNullableVariableBlob(ByteBuffer buffer, int offset)
{
// | null (byte) | length (int) | bytes |
final int length = buffer.getInt(offset + VALUE_OFFSET);
final byte[] blob = new byte[length];
final int oldPosition = buffer.position();
buffer.position(offset + VALUE_OFFSET + Integer.BYTES);
buffer.get(blob, 0, length);
buffer.position(oldPosition);
return blob;
}
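
A minimal sketch of the checked blob write plus read-back (values are illustrative):

import java.nio.ByteBuffer;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;

class NullableBlobSketch
{
  static void demo()
  {
    final ByteBuffer buffer = ByteBuffer.allocate(1 << 10);
    final int offset = 0;
    final byte[] blob = StringUtils.toUtf8("hello");
    // checked overload throws ISE if 1 + 4 + blob.length exceeds the max size
    Types.writeNullableVariableBlob(buffer, offset, blob, ColumnType.STRING, buffer.limit());
    if (!Types.isNullableNull(buffer, offset)) {
      final byte[] roundTrip = Types.readNullableVariableBlob(buffer, offset);   // same bytes as 'blob'
    }
  }
}
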
/**
* Write a variable-length Long[] value to a {@link ByteBuffer} at the supplied offset. The first byte is set to
* {@link NullHandling#IS_NULL_BYTE} or {@link NullHandling#IS_NOT_NULL_BYTE} as appropriate, and if the Long[] value
* is not null, the number of array elements is written as an integer in the next 4 bytes. Elements of the array are
* each written out with {@link #writeNull} if null, or {@link #writeNullableLong} if not, taking either 1 or 9 bytes
* each. If the total byte size of serializing the array is larger than the max size parameter, this method will throw
* an {@link ISE} via a call to {@link #checkMaxBytes}.
*
* layout: | null (byte) | size (int) | {| null (byte) | long |, | null (byte) |, ... |null (byte) | long |} |
*
* This method checks that no more than the specified maximum number of bytes can be written to the buffer, and the
* proper function of this method requires that the buffer contains at least that many bytes free from the starting
* offset.
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*
* @return number of bytes written (1 if null, or 5 + size of Long[] if not)
*/
public static int writeNullableLongArray(ByteBuffer buffer, int offset, @Nullable Long[] array, int maxSizeBytes)
{
// | null (byte) | array length (int) | array bytes |
if (array == null) {
return writeNull(buffer, offset);
}
int sizeBytes = 1 + Integer.BYTES;
buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
buffer.putInt(offset + 1, array.length);
for (Long element : array) {
if (element != null) {
checkMaxBytes(
ColumnType.LONG_ARRAY,
sizeBytes + 1 + Long.BYTES,
maxSizeBytes
);
sizeBytes += writeNullableLong(buffer, offset + sizeBytes, element);
} else {
checkMaxBytes(
ColumnType.LONG_ARRAY,
sizeBytes + 1,
maxSizeBytes
);
sizeBytes += writeNull(buffer, offset + sizeBytes);
}
}
return sizeBytes;
}
/**
* Reads a nullable variable-length Long[] value from a {@link ByteBuffer} at the supplied offset. If the null byte
* is set to {@link NullHandling#IS_NULL_BYTE}, this method will return null, else it will read the size of the array
* from the next 4 bytes and then read that many elements with {@link #isNullableNull} and {@link #readNullableLong}.
*
* layout: | null (byte) | size (int) | {| null (byte) | long |, | null (byte) |, ... |null (byte) | long |} |
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*/
@Nullable
public static Long[] readNullableLongArray(ByteBuffer buffer, int offset)
{
// | null (byte) | array length (int) | array bytes |
if (isNullableNull(buffer, offset++)) {
return null;
}
final int longArrayLength = buffer.getInt(offset);
offset += Integer.BYTES;
final Long[] longs = new Long[longArrayLength];
for (int i = 0; i < longArrayLength; i++) {
if (isNullableNull(buffer, offset)) {
longs[i] = null;
} else {
longs[i] = readNullableLong(buffer, offset);
offset += Long.BYTES;
}
offset++;
}
return longs;
}
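
A minimal Long[] round-trip sketch (Double[] works the same way with the double variants):

import java.nio.ByteBuffer;
import org.apache.druid.segment.column.Types;

class NullableLongArraySketch
{
  static void demo()
  {
    final ByteBuffer buffer = ByteBuffer.allocate(1 << 10);
    final int offset = 0;
    final Long[] array = new Long[]{1L, null, 3L};
    // 1 (null byte) + 4 (length) + 9 + 1 + 9 = 24 bytes for this particular array
    final int written = Types.writeNullableLongArray(buffer, offset, array, buffer.limit());
    final Long[] roundTrip = Types.readNullableLongArray(buffer, offset);   // {1L, null, 3L}
  }
}
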
/**
* Write a variable-length Double[] value to a {@link ByteBuffer} at the supplied offset. The first byte is set to
* {@link NullHandling#IS_NULL_BYTE} or {@link NullHandling#IS_NOT_NULL_BYTE} as appropriate, and if the Double[] value
* is not null, the number of array elements is written as an integer in the next 4 bytes. Elements of the array are
* each written out with {@link #writeNull} if null, or {@link #writeNullableDouble} if not, taking either 1 or 9 bytes
* each. If the total byte size of serializing the array is larger than the max size parameter, this method will throw
* an {@link ISE} via a call to {@link #checkMaxBytes}.
*
* layout: | null (byte) | size (int) | {| null (byte) | double |, | null (byte) |, ... |null (byte) | double |} |
*
* This method checks that no more than the specified maximum number of bytes can be written to the buffer, and the
* proper function of this method requires that the buffer contains at least that many bytes free from the starting
* offset.
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*
* @return number of bytes written (1 if null, or 5 + size of Double[] if not)
*/
public static int writeNullableDoubleArray(ByteBuffer buffer, int offset, @Nullable Double[] array, int maxSizeBytes)
{
// | null (byte) | array length (int) | array bytes |
if (array == null) {
return writeNull(buffer, offset);
}
int sizeBytes = 1 + Integer.BYTES;
buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
buffer.putInt(offset + 1, array.length);
for (Double element : array) {
if (element != null) {
checkMaxBytes(
ColumnType.DOUBLE_ARRAY,
sizeBytes + 1 + Double.BYTES,
maxSizeBytes
);
sizeBytes += writeNullableDouble(buffer, offset + sizeBytes, element);
} else {
checkMaxBytes(
ColumnType.DOUBLE_ARRAY,
sizeBytes + 1,
maxSizeBytes
);
sizeBytes += writeNull(buffer, offset + sizeBytes);
}
}
return sizeBytes;
}
/**
* Reads a nullable variable-length Double[] value from a {@link ByteBuffer} at the supplied offset. If the null
* byte is set to {@link NullHandling#IS_NULL_BYTE}, this method will return null, else it will read the size of the
* array from the next 4 bytes and then read that many elements with {@link #isNullableNull} and
* {@link #readNullableDouble}.
*
* layout: | null (byte) | size (int) | {| null (byte) | double |, | null (byte) |, ... |null (byte) | double |} |
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*/
@Nullable
public static Double[] readNullableDoubleArray(ByteBuffer buffer, int offset)
{
// | null (byte) | array length (int) | array bytes |
if (isNullableNull(buffer, offset++)) {
return null;
}
final int doubleArrayLength = buffer.getInt(offset);
offset += Integer.BYTES;
final Double[] doubles = new Double[doubleArrayLength];
for (int i = 0; i < doubleArrayLength; i++) {
if (isNullableNull(buffer, offset)) {
doubles[i] = null;
} else {
doubles[i] = readNullableDouble(buffer, offset);
offset += Double.BYTES;
}
offset++;
}
return doubles;
}
/**
* Write a variable-length String[] value to a {@link ByteBuffer} at the supplied offset. The first byte is set to
* {@link NullHandling#IS_NULL_BYTE} or {@link NullHandling#IS_NOT_NULL_BYTE} as appropriate, and if the String[]
* value is not null, the number of array elements is written as an integer in the next 4 bytes. The Strings themselves
* are encoded with {@link StringUtils#toUtf8}. Elements of the array are each written out with {@link #writeNull} if
* null, or {@link #writeNullableVariableBlob} if not, taking either 1 or 5 + the size of the utf8 byte array each. If
* the total byte size of serializing the array is larger than the max size parameter, this method will throw an
* {@link ISE} via a call to {@link #checkMaxBytes}.
*
* layout: | null (byte) | size (int) | {| null (byte) | size (int) | byte[] |, | null (byte) |, ... } |
*
* This method checks that no more than the specified maximum number of bytes can be written to the buffer, and the
* proper function of this method requires that the buffer contains at least that many bytes free from the starting
* offset.
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*
* @return number of bytes written (1 if null, or 5 + size of String[] if not)
*/
public static int writeNullableStringArray(ByteBuffer buffer, int offset, @Nullable String[] array, int maxSizeBytes)
{
// | null (byte) | array length (int) | array bytes |
if (array == null) {
return writeNull(buffer, offset);
}
int sizeBytes = 1 + Integer.BYTES;
buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
buffer.putInt(offset + 1, array.length);
for (String element : array) {
if (element != null) {
final byte[] stringElementBytes = StringUtils.toUtf8(element);
checkMaxBytes(
ColumnType.STRING_ARRAY,
sizeBytes + 1 + Integer.BYTES + stringElementBytes.length,
maxSizeBytes
);
sizeBytes += writeNullableVariableBlob(buffer, offset + sizeBytes, stringElementBytes);
} else {
checkMaxBytes(
ColumnType.STRING_ARRAY,
sizeBytes + 1,
maxSizeBytes
);
sizeBytes += writeNull(buffer, offset + sizeBytes);
}
}
return sizeBytes;
}
/**
* Reads a nullable variable-length String[] value from a {@link ByteBuffer} at the supplied offset. If the null
* byte is set to {@link NullHandling#IS_NULL_BYTE}, this method will return null, else it will read the size of the
* array from the next 4 bytes and then read that many elements with {@link #readNullableVariableBlob} and decode them
* with {@link StringUtils#fromUtf8} to convert to string values.
*
* layout: | null (byte) | size (int) | {| null (byte) | size (int) | byte[] |, | null (byte) |, ... } |
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*/
@Nullable
public static String[] readNullableStringArray(ByteBuffer buffer, int offset)
{
// | null (byte) | array length (int) | array bytes |
if (isNullableNull(buffer, offset++)) {
return null;
}
final int stringArrayLength = buffer.getInt(offset);
offset += Integer.BYTES;
final String[] stringArray = new String[stringArrayLength];
for (int i = 0; i < stringArrayLength; i++) {
if (isNullableNull(buffer, offset)) {
stringArray[i] = null;
} else {
final byte[] stringElementBytes = readNullableVariableBlob(buffer, offset);
stringArray[i] = StringUtils.fromUtf8(stringElementBytes);
offset += Integer.BYTES + stringElementBytes.length;
}
offset++;
}
return stringArray;
}
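
A minimal String[] round-trip sketch (values are illustrative):

import java.nio.ByteBuffer;
import org.apache.druid.segment.column.Types;

class NullableStringArraySketch
{
  static void demo()
  {
    final ByteBuffer buffer = ByteBuffer.allocate(1 << 10);
    final int offset = 0;
    final String[] array = new String[]{"hello", null, "world"};
    // each non-null element is stored as a nullable UTF-8 blob
    final int written = Types.writeNullableStringArray(buffer, offset, array, buffer.limit());
    final String[] roundTrip = Types.readNullableStringArray(buffer, offset);   // {"hello", null, "world"}
  }
}
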
/**
* Write a variable-length byte[] value derived from some {@link ObjectByteStrategy} for a complex
* {@link TypeSignature} to a {@link ByteBuffer} at the supplied offset. The first byte is set to
* {@link NullHandling#IS_NULL_BYTE} or {@link NullHandling#IS_NOT_NULL_BYTE} as appropriate, and if the value
* is not null, the size in bytes is written as an integer in the next 4 bytes, followed by the byte[] value itself
* from {@link ObjectByteStrategy#toBytes}.
*
* layout: | null (byte) | size (int) | byte[] |
*
* Note that the {@link TypeSignature#getComplexTypeName()} MUST have registered an {@link ObjectByteStrategy} with
* {@link #registerStrategy} for this method to work, else a null pointer exception will be thrown.
*
* This method checks that no more than the specified maximum number of bytes can be written to the buffer, and the
* proper function of this method requires that the buffer contains at least that many bytes free from the starting
* offset.
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*
* @return number of bytes written (1 if null, or 5 + size of byte[] if not)
*/
public static <T> int writeNullableComplexType(
ByteBuffer buffer,
int offset,
TypeSignature<?> type,
@Nullable T value,
int maxSizeBytes
)
{
final ObjectByteStrategy strategy = Preconditions.checkNotNull(
getStrategy(type.getComplexTypeName()),
StringUtils.format(
"Type %s has not registered an ObjectByteStrategy and cannot be written",
type.asTypeString()
)
);
if (value == null) {
return writeNull(buffer, offset);
}
final byte[] complexBytes = strategy.toBytes(value);
return writeNullableVariableBlob(buffer, offset, complexBytes, type, maxSizeBytes);
}
/**
* Read a possibly null, variable-length byte[] value derived from some {@link ObjectByteStrategy} for a complex
* {@link TypeSignature} from a {@link ByteBuffer} at the supplied offset. If the first byte is set to
* {@link NullHandling#IS_NULL_BYTE}, this method will return null, and if the value is not null, the size in bytes
* is read as an integer from the next 4 bytes, followed by the byte[] value itself from
* {@link ObjectByteStrategy#fromByteBuffer}.
*
* layout: | null (byte) | size (int) | byte[] |
*
* Note that the {@link TypeSignature#getComplexTypeName()} MUST have registered an {@link ObjectByteStrategy} with
* {@link #registerStrategy} for this method to work, else a null pointer exception will be thrown.
*
* This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
* given to it (i.e. buffer aggs)
*/
@Nullable
public static Object readNullableComplexType(ByteBuffer buffer, int offset, TypeSignature<?> type)
{
if (isNullableNull(buffer, offset++)) {
return null;
}
final ObjectByteStrategy strategy = Preconditions.checkNotNull(
getStrategy(type.getComplexTypeName()),
StringUtils.format(
"Type %s has not registered an ObjectByteStrategy and cannot be read",
type.asTypeString()
)
);
final int complexLength = buffer.getInt(offset);
offset += Integer.BYTES;
ByteBuffer dupe = buffer.duplicate();
dupe.position(offset);
dupe.limit(offset + complexLength);
return strategy.fromByteBuffer(dupe, complexLength);
}
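
A sketch assuming a strategy has been registered for the hypothetical 'hypotheticalString' type from the earlier sketch:

import java.nio.ByteBuffer;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;

class NullableComplexSketch
{
  static void demo()
  {
    final ByteBuffer buffer = ByteBuffer.allocate(1 << 10);
    final int offset = 0;
    // hypothetical complex type name; its ObjectByteStrategy must already be registered
    final ColumnType complexType = ColumnType.ofComplex("hypotheticalString");
    Types.writeNullableComplexType(buffer, offset, complexType, "some value", buffer.limit());
    final Object roundTrip = Types.readNullableComplexType(buffer, offset, complexType);   // "some value"
  }
}
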
/**
* Throw an {@link ISE} for consistent error messaging if the size to be written is greater than the max size
*/
public static void checkMaxBytes(TypeSignature<?> type, int sizeBytes, int maxSizeBytes)
{
if (sizeBytes > maxSizeBytes) {
throw new ISE(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
type.asTypeString(),
sizeBytes,
maxSizeBytes
);
}
}
}

View File

@ -205,25 +205,25 @@ public class ApplyFunctionTest extends InitializedNullHandlingTest
private void assertExpr(final String expression, final Double[] expectedResult)
{
final Expr expr = Parser.parse(expression, ExprMacroTable.nil());
Double[] result = (Double[]) expr.eval(bindings).value();
Object[] result = expr.eval(bindings).asArray();
Assert.assertEquals(expectedResult.length, result.length);
for (int i = 0; i < result.length; i++) {
Assert.assertEquals(expression, expectedResult[i], result[i], 0.00001); // something is lame somewhere..
Assert.assertEquals(expression, expectedResult[i], (Double) result[i], 0.00001); // something is lame somewhere..
}
final Expr exprNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false);
final Expr roundTrip = Parser.parse(exprNoFlatten.stringify(), ExprMacroTable.nil());
Double[] resultRoundTrip = (Double[]) roundTrip.eval(bindings).value();
Object[] resultRoundTrip = (Object[]) roundTrip.eval(bindings).value();
Assert.assertEquals(expectedResult.length, resultRoundTrip.length);
for (int i = 0; i < resultRoundTrip.length; i++) {
Assert.assertEquals(expression, expectedResult[i], resultRoundTrip[i], 0.00001);
Assert.assertEquals(expression, expectedResult[i], (Double) resultRoundTrip[i], 0.00001);
}
final Expr roundTripFlatten = Parser.parse(expr.stringify(), ExprMacroTable.nil());
Double[] resultRoundTripFlatten = (Double[]) roundTripFlatten.eval(bindings).value();
Object[] resultRoundTripFlatten = (Object[]) roundTripFlatten.eval(bindings).value();
Assert.assertEquals(expectedResult.length, resultRoundTripFlatten.length);
for (int i = 0; i < resultRoundTripFlatten.length; i++) {
Assert.assertEquals(expression, expectedResult[i], resultRoundTripFlatten[i], 0.00001);
Assert.assertEquals(expression, expectedResult[i], (Double) resultRoundTripFlatten[i], 0.00001);
}
Assert.assertEquals(expr.stringify(), roundTrip.stringify());

View File

@ -20,11 +20,14 @@
package org.apache.druid.math.expr;
import com.google.common.collect.ImmutableList;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.TypesTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@ -42,6 +45,12 @@ public class ExprEvalTest extends InitializedNullHandlingTest
ByteBuffer buffer = ByteBuffer.allocate(1 << 16);
@BeforeClass
public static void setup()
{
Types.registerStrategy(TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new TypesTest.PairObjectByteStrategy());
}
@Test
public void testStringSerde()
{
@ -109,10 +118,10 @@ public class ExprEvalTest extends InitializedNullHandlingTest
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
ExpressionType.STRING_ARRAY,
28,
15,
10
));
assertEstimatedBytes(ExprEval.ofStringArray(new String[]{"hello", "hi", "hey"}), 10);
assertExpr(0, ExprEval.ofStringArray(new String[]{"hello", "hi", "hey"}), 10);
}
@Test
@ -130,7 +139,7 @@ public class ExprEvalTest extends InitializedNullHandlingTest
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
ExpressionType.LONG_ARRAY,
30,
14,
10
));
assertExpr(0, ExprEval.ofLongArray(new Long[]{1L, 2L, 3L}), 10);
@ -143,10 +152,10 @@ public class ExprEvalTest extends InitializedNullHandlingTest
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
ExpressionType.LONG_ARRAY,
NullHandling.sqlCompatible() ? 33 : 30,
14,
10
));
assertEstimatedBytes(ExprEval.ofLongArray(new Long[]{1L, 2L, 3L}), 10);
assertExpr(0, ExprEval.ofLongArray(new Long[]{1L, 2L, 3L}), 10);
}
@Test
@ -164,7 +173,7 @@ public class ExprEvalTest extends InitializedNullHandlingTest
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
ExpressionType.DOUBLE_ARRAY,
30,
14,
10
));
assertExpr(0, ExprEval.ofDoubleArray(new Double[]{1.1, 2.2, 3.3}), 10);
@ -177,86 +186,144 @@ public class ExprEvalTest extends InitializedNullHandlingTest
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
ExpressionType.DOUBLE_ARRAY,
NullHandling.sqlCompatible() ? 33 : 30,
14,
10
));
assertEstimatedBytes(ExprEval.ofDoubleArray(new Double[]{1.1, 2.2, 3.3}), 10);
assertExpr(0, ExprEval.ofDoubleArray(new Double[]{1.1, 2.2, 3.3}), 10);
}
@Test
public void testComplexEval()
{
final ExpressionType complexType = ExpressionType.fromColumnType(TypesTest.NULLABLE_TEST_PAIR_TYPE);
assertExpr(0, ExprEval.ofComplex(complexType, new TypesTest.NullableLongPair(1234L, 5678L)));
assertExpr(1024, ExprEval.ofComplex(complexType, new TypesTest.NullableLongPair(1234L, 5678L)));
}
@Test
public void testComplexEvalTooBig()
{
final ExpressionType complexType = ExpressionType.fromColumnType(TypesTest.NULLABLE_TEST_PAIR_TYPE);
expectedException.expect(ISE.class);
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
complexType.asTypeString(),
23,
10
));
assertExpr(0, ExprEval.ofComplex(complexType, new TypesTest.NullableLongPair(1234L, 5678L)), 10);
}
@Test
public void test_coerceListToArray()
{
Assert.assertNull(ExprEval.coerceListToArray(null, false));
Assert.assertArrayEquals(new Object[0], (Object[]) ExprEval.coerceListToArray(ImmutableList.of(), false));
Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(null, true));
Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(ImmutableList.of(), true));
NonnullPair<ExpressionType, Object[]> coerced = ExprEval.coerceListToArray(ImmutableList.of(), false);
Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
Assert.assertArrayEquals(new Object[0], coerced.rhs);
coerced = ExprEval.coerceListToArray(null, true);
Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
Assert.assertArrayEquals(new Object[]{null}, coerced.rhs);
coerced = ExprEval.coerceListToArray(ImmutableList.of(), true);
Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
Assert.assertArrayEquals(new Object[]{null}, coerced.rhs);
List<Long> longList = ImmutableList.of(1L, 2L, 3L);
Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(longList, false));
coerced = ExprEval.coerceListToArray(longList, false);
Assert.assertEquals(ExpressionType.LONG_ARRAY, coerced.lhs);
Assert.assertArrayEquals(new Object[]{1L, 2L, 3L}, coerced.rhs);
List<Integer> intList = ImmutableList.of(1, 2, 3);
Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(intList, false));
coerced = ExprEval.coerceListToArray(intList, false);
Assert.assertEquals(ExpressionType.LONG_ARRAY, coerced.lhs);
Assert.assertArrayEquals(new Object[]{1L, 2L, 3L}, coerced.rhs);
List<Float> floatList = ImmutableList.of(1.0f, 2.0f, 3.0f);
Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(floatList, false));
coerced = ExprEval.coerceListToArray(floatList, false);
Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, coerced.lhs);
Assert.assertArrayEquals(new Object[]{1.0, 2.0, 3.0}, coerced.rhs);
List<Double> doubleList = ImmutableList.of(1.0, 2.0, 3.0);
Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(doubleList, false));
coerced = ExprEval.coerceListToArray(doubleList, false);
Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, coerced.lhs);
Assert.assertArrayEquals(new Object[]{1.0, 2.0, 3.0}, coerced.rhs);
List<String> stringList = ImmutableList.of("a", "b", "c");
Assert.assertArrayEquals(new String[]{"a", "b", "c"}, (String[]) ExprEval.coerceListToArray(stringList, false));
coerced = ExprEval.coerceListToArray(stringList, false);
Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
Assert.assertArrayEquals(new Object[]{"a", "b", "c"}, coerced.rhs);
List<String> withNulls = new ArrayList<>();
withNulls.add("a");
withNulls.add(null);
withNulls.add("c");
Assert.assertArrayEquals(new String[]{"a", null, "c"}, (String[]) ExprEval.coerceListToArray(withNulls, false));
coerced = ExprEval.coerceListToArray(withNulls, false);
Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
Assert.assertArrayEquals(new Object[]{"a", null, "c"}, coerced.rhs);
List<Long> withNumberNulls = new ArrayList<>();
withNumberNulls.add(1L);
withNumberNulls.add(null);
withNumberNulls.add(3L);
Assert.assertArrayEquals(new Long[]{1L, null, 3L}, (Long[]) ExprEval.coerceListToArray(withNumberNulls, false));
coerced = ExprEval.coerceListToArray(withNumberNulls, false);
Assert.assertEquals(ExpressionType.LONG_ARRAY, coerced.lhs);
Assert.assertArrayEquals(new Object[]{1L, null, 3L}, coerced.rhs);
List<Object> withStringMix = ImmutableList.of(1L, "b", 3L);
coerced = ExprEval.coerceListToArray(withStringMix, false);
Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
new String[]{"1", "b", "3"},
(String[]) ExprEval.coerceListToArray(withStringMix, false)
new Object[]{"1", "b", "3"},
coerced.rhs
);
List<Number> withIntsAndLongs = ImmutableList.of(1, 2L, 3);
coerced = ExprEval.coerceListToArray(withIntsAndLongs, false);
Assert.assertEquals(ExpressionType.LONG_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
new Long[]{1L, 2L, 3L},
(Long[]) ExprEval.coerceListToArray(withIntsAndLongs, false)
new Object[]{1L, 2L, 3L},
coerced.rhs
);
List<Number> withFloatsAndLongs = ImmutableList.of(1, 2L, 3.0f);
coerced = ExprEval.coerceListToArray(withFloatsAndLongs, false);
Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
new Double[]{1.0, 2.0, 3.0},
(Double[]) ExprEval.coerceListToArray(withFloatsAndLongs, false)
new Object[]{1.0, 2.0, 3.0},
coerced.rhs
);
List<Number> withDoublesAndLongs = ImmutableList.of(1, 2L, 3.0);
coerced = ExprEval.coerceListToArray(withDoublesAndLongs, false);
Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
new Double[]{1.0, 2.0, 3.0},
(Double[]) ExprEval.coerceListToArray(withDoublesAndLongs, false)
new Object[]{1.0, 2.0, 3.0},
coerced.rhs
);
List<Number> withFloatsAndDoubles = ImmutableList.of(1L, 2.0f, 3.0);
coerced = ExprEval.coerceListToArray(withFloatsAndDoubles, false);
Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
new Double[]{1.0, 2.0, 3.0},
(Double[]) ExprEval.coerceListToArray(withFloatsAndDoubles, false)
new Object[]{1.0, 2.0, 3.0},
coerced.rhs
);
List<String> withAllNulls = new ArrayList<>();
withAllNulls.add(null);
withAllNulls.add(null);
withAllNulls.add(null);
coerced = ExprEval.coerceListToArray(withAllNulls, false);
Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
new String[]{null, null, null},
(String[]) ExprEval.coerceListToArray(withAllNulls, false)
new Object[]{null, null, null},
coerced.rhs
);
}
@Test
@ -299,27 +366,21 @@ public class ExprEvalTest extends InitializedNullHandlingTest
if (expected.type().isArray()) {
Assert.assertArrayEquals(
expected.asArray(),
ExprEval.deserialize(buffer, position + 1, ExprType.fromByte(buffer.get(position))).asArray()
ExprEval.deserialize(buffer, position, expected.type()).asArray()
);
Assert.assertArrayEquals(
expected.asArray(),
ExprEval.deserialize(buffer, position).asArray()
ExprEval.deserialize(buffer, position, expected.type()).asArray()
);
} else {
Assert.assertEquals(
expected.value(),
ExprEval.deserialize(buffer, position + 1, ExprType.fromByte(buffer.get(position))).value()
ExprEval.deserialize(buffer, position, expected.type()).value()
);
Assert.assertEquals(
expected.value(),
ExprEval.deserialize(buffer, position).value()
ExprEval.deserialize(buffer, position, expected.type()).value()
);
}
assertEstimatedBytes(expected, maxSizeBytes);
}
private void assertEstimatedBytes(ExprEval eval, int maxSizeBytes)
{
ExprEval.estimateAndCheckMaxBytes(eval, maxSizeBytes);
}
}

View File

@ -160,34 +160,27 @@ public class ExprTest
}
@Test
public void testEqualsContractForStringArrayExpr()
public void testEqualsContractForArrayExpr()
{
EqualsVerifier.forClass(StringArrayExpr.class)
.withIgnoredFields("outputType")
.withPrefabValues(Object.class, new String[]{"foo"}, new String[0])
.withPrefabValues(ExpressionType.class, ExpressionType.STRING_ARRAY, ExpressionType.LONG_ARRAY)
.usingGetClass()
.verify();
}
@Test
public void testEqualsContractForLongArrayExpr()
{
EqualsVerifier.forClass(LongArrayExpr.class)
.withIgnoredFields("outputType")
.withPrefabValues(Object.class, new Long[]{1L}, new Long[0])
EqualsVerifier.forClass(ArrayExpr.class)
.withPrefabValues(Object.class, new Object[]{1L}, new Object[0])
.withPrefabValues(ExpressionType.class, ExpressionType.LONG_ARRAY, ExpressionType.DOUBLE_ARRAY)
.withNonnullFields("outputType")
.usingGetClass()
.verify();
}
@Test
public void testEqualsContractForDoubleArrayExpr()
public void testEqualsContractForComplexExpr()
{
EqualsVerifier.forClass(DoubleArrayExpr.class)
.withIgnoredFields("outputType")
.withPrefabValues(Object.class, new Double[]{1.0}, new Double[0])
.withPrefabValues(ExpressionType.class, ExpressionType.DOUBLE_ARRAY, ExpressionType.STRING_ARRAY)
EqualsVerifier.forClass(ComplexExpr.class)
.withPrefabValues(Object.class, new Object[]{1L}, new Object[0])
.withPrefabValues(
ExpressionType.class,
ExpressionTypeFactory.getInstance().ofComplex("foo"),
ExpressionTypeFactory.getInstance().ofComplex("bar")
)
.withNonnullFields("outputType")
.usingGetClass()
.verify();
}

View File

@ -25,9 +25,13 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.TypesTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@ -45,6 +49,12 @@ public class FunctionTest extends InitializedNullHandlingTest
private Expr.ObjectBinding bindings;
@BeforeClass
public static void setupClass()
{
Types.registerStrategy(TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new TypesTest.PairObjectByteStrategy());
}
@Before
public void setup()
{
@ -64,7 +74,8 @@ public class FunctionTest extends InitializedNullHandlingTest
.put("of", 0F)
.put("a", new String[] {"foo", "bar", "baz", "foobar"})
.put("b", new Long[] {1L, 2L, 3L, 4L, 5L})
.put("c", new Double[] {3.1, 4.2, 5.3});
.put("c", new Double[] {3.1, 4.2, 5.3})
.put("someComplex", new TypesTest.NullableLongPair(1L, 2L));
bindings = InputBindings.withMap(builder.build());
}
@ -281,7 +292,7 @@ public class FunctionTest extends InitializedNullHandlingTest
public void testArrayAppend()
{
assertArrayExpr("array_append([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
assertArrayExpr("array_append([1, 2, 3], 'bar')", new Long[]{1L, 2L, 3L, null});
assertArrayExpr("array_append([1, 2, 3], 'bar')", new Long[]{1L, 2L, 3L, NullHandling.defaultLongValue()});
assertArrayExpr("array_append([], 1)", new String[]{"1"});
assertArrayExpr("array_append(<LONG>[], 1)", new Long[]{1L});
}
@ -300,11 +311,11 @@ public class FunctionTest extends InitializedNullHandlingTest
public void testArraySetAdd()
{
assertArrayExpr("array_set_add([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
assertArrayExpr("array_set_add([1, 2, 3], 'bar')", new Long[]{null, 1L, 2L, 3L});
assertArrayExpr("array_set_add([1, 2, 3], 'bar')", new Long[]{NullHandling.defaultLongValue(), 1L, 2L, 3L});
assertArrayExpr("array_set_add([1, 2, 2], 1)", new Long[]{1L, 2L});
assertArrayExpr("array_set_add([], 1)", new String[]{"1"});
assertArrayExpr("array_set_add(<LONG>[], 1)", new Long[]{1L});
assertArrayExpr("array_set_add(<LONG>[], null)", new Long[]{null});
assertArrayExpr("array_set_add(<LONG>[], null)", new Long[]{NullHandling.defaultLongValue()});
}
@Test
@ -358,7 +369,7 @@ public class FunctionTest extends InitializedNullHandlingTest
public void testArrayPrepend()
{
assertArrayExpr("array_prepend(4, [1, 2, 3])", new Long[]{4L, 1L, 2L, 3L});
assertArrayExpr("array_prepend('bar', [1, 2, 3])", new Long[]{null, 1L, 2L, 3L});
assertArrayExpr("array_prepend('bar', [1, 2, 3])", new Long[]{NullHandling.defaultLongValue(), 1L, 2L, 3L});
assertArrayExpr("array_prepend(1, [])", new String[]{"1"});
assertArrayExpr("array_prepend(1, <LONG>[])", new Long[]{1L});
assertArrayExpr("array_prepend(1, <DOUBLE>[])", new Double[]{1.0});
@ -792,6 +803,66 @@ public class FunctionTest extends InitializedNullHandlingTest
assertExpr("repeat(nonexistent, 10)", null);
}
@Test
public void testComplexDecode()
{
TypesTest.NullableLongPair expected = new TypesTest.NullableLongPair(1L, 2L);
ObjectByteStrategy strategy = Types.getStrategy(TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName());
assertExpr(
StringUtils.format(
"complex_decode_base64('%s', '%s')",
TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(),
StringUtils.encodeBase64String(strategy.toBytes(expected))
),
expected
);
}
@Test
public void testComplexDecodeNull()
{
assertExpr(
StringUtils.format(
"complex_decode_base64('%s', null)",
TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName()
),
null
);
}
@Test
public void testComplexDecodeBaseWrongArgCount()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[complex_decode_base64] needs 2 arguments");
assertExpr(
"complex_decode_base64(string)",
null
);
}
@Test
public void testComplexDecodeBaseArg0BadType()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[complex_decode_base64] first argument must be constant 'STRING' expression containing a valid complex type name");
assertExpr(
"complex_decode_base64(1, string)",
null
);
}
@Test
public void testComplexDecodeBaseArg0Unknown()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[complex_decode_base64] first argument must be a valid complex type name, unknown complex type [COMPLEX<unknown>]");
assertExpr(
"complex_decode_base64('unknown', string)",
null
);
}
private void assertExpr(final String expression, @Nullable final Object expectedResult)
{
final Expr expr = Parser.parse(expression, ExprMacroTable.nil());

View File

@ -22,9 +22,15 @@ package org.apache.druid.math.expr;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.TypesTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@ -43,6 +49,12 @@ public class ParserTest extends InitializedNullHandlingTest
VectorExprSanityTest.SettableVectorInputBinding emptyBinding = new VectorExprSanityTest.SettableVectorInputBinding(8);
@BeforeClass
public static void setup()
{
Types.registerStrategy(TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new TypesTest.PairObjectByteStrategy());
}
@Test
public void testSimple()
{
@ -222,76 +234,154 @@ public class ParserTest extends InitializedNullHandlingTest
@Test
public void testLiteralArraysHomogeneousElements()
{
validateConstantExpression("[1.0, 2.345]", new Double[]{1.0, 2.345});
validateConstantExpression("[1, 3]", new Long[]{1L, 3L});
validateConstantExpression("['hello', 'world']", new String[]{"hello", "world"});
validateConstantExpression("[1.0, 2.345]", new Object[]{1.0, 2.345});
validateConstantExpression("[1, 3]", new Object[]{1L, 3L});
validateConstantExpression("['hello', 'world']", new Object[]{"hello", "world"});
}
@Test
public void testLiteralArraysHomogeneousOrNullElements()
{
validateConstantExpression("[1.0, null, 2.345]", new Double[]{1.0, null, 2.345});
validateConstantExpression("[null, 1, 3]", new Long[]{null, 1L, 3L});
validateConstantExpression("['hello', 'world', null]", new String[]{"hello", "world", null});
validateConstantExpression("[1.0, null, 2.345]", new Object[]{1.0, null, 2.345});
validateConstantExpression("[null, 1, 3]", new Object[]{null, 1L, 3L});
validateConstantExpression("['hello', 'world', null]", new Object[]{"hello", "world", null});
}
@Test
public void testLiteralArraysEmptyAndAllNullImplicitAreString()
{
validateConstantExpression("[]", new String[0]);
validateConstantExpression("[null, null, null]", new String[]{null, null, null});
validateConstantExpression("[]", new Object[0]);
validateConstantExpression("[null, null, null]", new Object[]{null, null, null});
}
@Test
public void testLiteralArraysImplicitTypedNumericMixed()
{
// implicit typed numeric arrays with mixed elements are doubles
validateConstantExpression("[1, null, 2000.0]", new Double[]{1.0, null, 2000.0});
validateConstantExpression("[1.0, null, 2000]", new Double[]{1.0, null, 2000.0});
validateConstantExpression("[1, null, 2000.0]", new Object[]{1.0, null, 2000.0});
validateConstantExpression("[1.0, null, 2000]", new Object[]{1.0, null, 2000.0});
}
@Test
public void testLiteralArraysExplicitTypedEmpties()
{
validateConstantExpression("<STRING>[]", new String[0]);
validateConstantExpression("<DOUBLE>[]", new Double[0]);
validateConstantExpression("<LONG>[]", new Long[0]);
// legacy explicit array format
validateConstantExpression("<STRING>[]", new Object[0]);
validateConstantExpression("<DOUBLE>[]", new Object[0]);
validateConstantExpression("<LONG>[]", new Object[0]);
}
@Test
public void testLiteralArraysExplicitAllNull()
{
validateConstantExpression("<DOUBLE>[null, null, null]", new Double[]{null, null, null});
validateConstantExpression("<LONG>[null, null, null]", new Long[]{null, null, null});
validateConstantExpression("<STRING>[null, null, null]", new String[]{null, null, null});
// legacy explicit array format
validateConstantExpression("<DOUBLE>[null, null, null]", new Object[]{null, null, null});
validateConstantExpression("<LONG>[null, null, null]", new Object[]{null, null, null});
validateConstantExpression("<STRING>[null, null, null]", new Object[]{null, null, null});
}
@Test
public void testLiteralArraysExplicitTypes()
{
validateConstantExpression("<DOUBLE>[1.0, null, 2000.0]", new Double[]{1.0, null, 2000.0});
validateConstantExpression("<LONG>[3, null, 4]", new Long[]{3L, null, 4L});
validateConstantExpression("<STRING>['foo', 'bar', 'baz']", new String[]{"foo", "bar", "baz"});
// legacy explicit array format
validateConstantExpression("<DOUBLE>[1.0, null, 2000.0]", new Object[]{1.0, null, 2000.0});
validateConstantExpression("<LONG>[3, null, 4]", new Object[]{3L, null, 4L});
validateConstantExpression("<STRING>['foo', 'bar', 'baz']", new Object[]{"foo", "bar", "baz"});
}
@Test
public void testLiteralArraysExplicitTypesMixedElements()
{
// legacy explicit array format
// explicitly typed numeric arrays with mixed numeric types should coerce to the correct explicit type
validateConstantExpression("<DOUBLE>[3, null, 4, 2.345]", new Double[]{3.0, null, 4.0, 2.345});
validateConstantExpression("<LONG>[1.0, null, 2000.0]", new Long[]{1L, null, 2000L});
validateConstantExpression("<DOUBLE>[3, null, 4, 2.345]", new Object[]{3.0, null, 4.0, 2.345});
validateConstantExpression("<LONG>[1.0, null, 2000.0]", new Object[]{1L, null, 2000L});
// explicit typed string arrays should accept any literal and convert to string
validateConstantExpression("<STRING>['1', null, 2000, 1.1]", new String[]{"1", null, "2000", "1.1"});
validateConstantExpression("<STRING>['1', null, 2000, 1.1]", new Object[]{"1", null, "2000", "1.1"});
}
@Test
public void testLiteralExplicitTypedArrays()
{
ExpressionProcessing.initializeForTests(true);
validateConstantExpression("ARRAY<DOUBLE>[1.0, 2.0, null, 3.0]", new Object[]{1.0, 2.0, null, 3.0});
validateConstantExpression("ARRAY<LONG>[1, 2, null, 3]", new Object[]{1L, 2L, null, 3L});
validateConstantExpression("ARRAY<STRING>['1', '2', null, '3.0']", new Object[]{"1", "2", null, "3.0"});
// mixed type tests
validateConstantExpression("ARRAY<DOUBLE>[3, null, 4, 2.345]", new Object[]{3.0, null, 4.0, 2.345});
validateConstantExpression("ARRAY<LONG>[1.0, null, 2000.0]", new Object[]{1L, null, 2000L});
// explicit typed string arrays should accept any literal and convert
validateConstantExpression("ARRAY<STRING>['1', null, 2000, 1.1]", new Object[]{"1", null, "2000", "1.1"});
validateConstantExpression("ARRAY<LONG>['1', null, 2000, 1.1]", new Object[]{1L, null, 2000L, 1L});
validateConstantExpression("ARRAY<DOUBLE>['1', null, 2000, 1.1]", new Object[]{1.0, null, 2000.0, 1.1});
// the grammar isn't yet able to parse populated nested arrays or complex arrays, but empty ones can
// be defined...
validateConstantExpression("ARRAY<COMPLEX<nullableLongPair>>[]", new Object[]{});
validateConstantExpression("ARRAY<ARRAY<LONG>>[]", new Object[]{});
ExpressionProcessing.initializeForTests(null);
}
@Test
public void testConstantComplexAndNestedArrays()
{
ExpressionProcessing.initializeForTests(true);
// they can be built with array builder functions though...
validateConstantExpression(
"array(['foo', 'bar', 'baz'], ['baz','foo','bar'])",
new Object[]{new Object[]{"foo", "bar", "baz"}, new Object[]{"baz", "foo", "bar"}}
);
// nested arrays cannot mix element types; the first element determines the type for the rest
validateConstantExpression(
"array(['foo', 'bar', 'baz'], ARRAY<LONG>[1,2,3])",
new Object[]{new Object[]{"foo", "bar", "baz"}, new Object[]{"1", "2", "3"}}
);
// complex types too
TypesTest.NullableLongPair l1 = new TypesTest.NullableLongPair(1L, 2L);
TypesTest.NullableLongPair l2 = new TypesTest.NullableLongPair(2L, 3L);
ObjectByteStrategy byteStrategy = Types.getStrategy(TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName());
String l1String = StringUtils.format(
"complex_decode_base64('%s', '%s')",
TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(),
StringUtils.encodeBase64String(byteStrategy.toBytes(l1))
);
String l2String = StringUtils.format(
"complex_decode_base64('%s', '%s')",
TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(),
StringUtils.encodeBase64String(byteStrategy.toBytes(l2))
);
validateConstantExpression(
l1String,
l1
);
validateConstantExpression(
StringUtils.format("array(%s,%s)", l1String, l2String),
new Object[]{l1, l2}
);
ExpressionProcessing.initializeForTests(null);
}
@Test
public void nestedArraysExplodeIfNotEnabled()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Cannot create a nested array type [ARRAY<ARRAY<LONG>>], 'druid.expressions.allowNestedArrays' must be set to true");
validateConstantExpression("ARRAY<ARRAY<LONG>>[]", new Object[]{});
}
@Test
public void testLiteralArrayImplicitStringParseException()
{
// implicitly typed string arrays cannot handle literals that are not null or string
expectedException.expect(RE.class);
expectedException.expectMessage("Failed to parse array: element 2000 is not a string");
validateConstantExpression("['1', null, 2000, 1.1]", new String[]{"1", null, "2000", "1.1"});
validateConstantExpression("['1', null, 2000, 1.1]", new Object[]{"1", null, "2000", "1.1"});
}
@Test
@ -300,7 +390,7 @@ public class ParserTest extends InitializedNullHandlingTest
// explicit typed long arrays only handle numeric types
expectedException.expect(RE.class);
expectedException.expectMessage("Failed to parse array element '2000' as a long");
validateConstantExpression("<LONG>[1, null, '2000']", new Long[]{1L, null, 2000L});
validateConstantExpression("<LONG>[1, null, '2000']", new Object[]{1L, null, 2000L});
}
@Test
@ -309,7 +399,7 @@ public class ParserTest extends InitializedNullHandlingTest
// explicit typed double arrays only handle numeric types
expectedException.expect(RE.class);
expectedException.expectMessage("Failed to parse array element '2000.0' as a double");
validateConstantExpression("<DOUBLE>[1.0, null, '2000.0']", new Double[]{1.0, null, 2000.0});
validateConstantExpression("<DOUBLE>[1.0, null, '2000.0']", new Object[]{1.0, null, 2000.0});
}
@Test

View File

@ -0,0 +1,443 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.segment.column;
import com.google.common.primitives.Longs;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Comparators;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.Arrays;
public class TypesTest
{
ByteBuffer buffer = ByteBuffer.allocate(1 << 16);
public static ColumnType NULLABLE_TEST_PAIR_TYPE = ColumnType.ofComplex("nullableLongPair");
@Rule
public ExpectedException expectedException = ExpectedException.none();
@BeforeClass
public static void setup()
{
Types.registerStrategy(NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new PairObjectByteStrategy());
}
@Test
public void testIs()
{
Assert.assertTrue(Types.is(ColumnType.LONG, ValueType.LONG));
Assert.assertTrue(Types.is(ColumnType.DOUBLE, ValueType.DOUBLE));
Assert.assertTrue(Types.is(ColumnType.FLOAT, ValueType.FLOAT));
Assert.assertTrue(Types.is(ColumnType.STRING, ValueType.STRING));
Assert.assertTrue(Types.is(ColumnType.LONG_ARRAY, ValueType.ARRAY));
Assert.assertTrue(Types.is(ColumnType.LONG_ARRAY.getElementType(), ValueType.LONG));
Assert.assertTrue(Types.is(ColumnType.DOUBLE_ARRAY, ValueType.ARRAY));
Assert.assertTrue(Types.is(ColumnType.DOUBLE_ARRAY.getElementType(), ValueType.DOUBLE));
Assert.assertTrue(Types.is(ColumnType.STRING_ARRAY, ValueType.ARRAY));
Assert.assertTrue(Types.is(ColumnType.STRING_ARRAY.getElementType(), ValueType.STRING));
Assert.assertTrue(Types.is(NULLABLE_TEST_PAIR_TYPE, ValueType.COMPLEX));
Assert.assertFalse(Types.is(ColumnType.LONG, ValueType.DOUBLE));
Assert.assertFalse(Types.is(ColumnType.DOUBLE, ValueType.FLOAT));
Assert.assertFalse(Types.is(null, ValueType.STRING));
Assert.assertTrue(Types.isNullOr(null, ValueType.STRING));
}
@Test
public void testNullOrAnyOf()
{
Assert.assertTrue(Types.isNullOrAnyOf(ColumnType.LONG, ValueType.STRING, ValueType.LONG, ValueType.DOUBLE));
Assert.assertFalse(Types.isNullOrAnyOf(ColumnType.DOUBLE, ValueType.STRING, ValueType.LONG, ValueType.FLOAT));
Assert.assertTrue(Types.isNullOrAnyOf(null, ValueType.STRING, ValueType.LONG, ValueType.FLOAT));
}
@Test
public void testEither()
{
Assert.assertTrue(Types.either(ColumnType.LONG, ColumnType.DOUBLE, ValueType.DOUBLE));
Assert.assertFalse(Types.either(ColumnType.LONG, ColumnType.STRING, ValueType.DOUBLE));
}
@Test
public void testRegister()
{
ObjectByteStrategy<?> strategy = Types.getStrategy(NULLABLE_TEST_PAIR_TYPE.getComplexTypeName());
Assert.assertNotNull(strategy);
Assert.assertTrue(strategy instanceof PairObjectByteStrategy);
}
@Test
public void testRegisterDuplicate()
{
Types.registerStrategy(NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new PairObjectByteStrategy());
ObjectByteStrategy<?> strategy = Types.getStrategy(NULLABLE_TEST_PAIR_TYPE.getComplexTypeName());
Assert.assertNotNull(strategy);
Assert.assertTrue(strategy instanceof PairObjectByteStrategy);
}
@Test
public void testConflicting()
{
expectedException.expect(IllegalStateException.class);
expectedException.expectMessage(
"Incompatible strategy for type[nullableLongPair] already exists."
+ " Expected [org.apache.druid.segment.column.TypesTest$1],"
+ " found [org.apache.druid.segment.column.TypesTest$PairObjectByteStrategy]."
);
Types.registerStrategy(NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new ObjectByteStrategy<String>()
{
@Override
public int compare(String o1, String o2)
{
return 0;
}
@Override
public Class<? extends String> getClazz()
{
return null;
}
@Nullable
@Override
public String fromByteBuffer(ByteBuffer buffer, int numBytes)
{
return null;
}
@Nullable
@Override
public byte[] toBytes(@Nullable String val)
{
return new byte[0];
}
});
}
@Test
public void testNulls()
{
int offset = 0;
Types.writeNull(buffer, offset);
Assert.assertTrue(Types.isNullableNull(buffer, offset));
// test non-zero offset
offset = 128;
Types.writeNull(buffer, offset);
Assert.assertTrue(Types.isNullableNull(buffer, offset));
}
@Test
public void testNonNullNullableLongBinary()
{
final long someLong = 12345567L;
int offset = 0;
int bytesWritten = Types.writeNullableLong(buffer, offset, someLong);
Assert.assertEquals(1 + Long.BYTES, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertEquals(someLong, Types.readNullableLong(buffer, offset));
// test non-zero offset
offset = 1024;
bytesWritten = Types.writeNullableLong(buffer, offset, someLong);
Assert.assertEquals(1 + Long.BYTES, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertEquals(someLong, Types.readNullableLong(buffer, offset));
}
@Test
public void testNonNullNullableDoubleBinary()
{
final double someDouble = 1.234567;
int offset = 0;
int bytesWritten = Types.writeNullableDouble(buffer, offset, someDouble);
Assert.assertEquals(1 + Double.BYTES, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertEquals(someDouble, Types.readNullableDouble(buffer, offset), 0);
// test non-zero offset
offset = 1024;
bytesWritten = Types.writeNullableDouble(buffer, offset, someDouble);
Assert.assertEquals(1 + Double.BYTES, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertEquals(someDouble, Types.readNullableDouble(buffer, offset), 0);
}
@Test
public void testNonNullNullableFloatBinary()
{
final float someFloat = 12345567L;
int offset = 0;
int bytesWritten = Types.writeNullableFloat(buffer, offset, someFloat);
Assert.assertEquals(1 + Float.BYTES, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertEquals(someFloat, Types.readNullableFloat(buffer, offset), 0);
// test non-zero offset
offset = 1024;
bytesWritten = Types.writeNullableFloat(buffer, offset, someFloat);
Assert.assertEquals(1 + Float.BYTES, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertEquals(someFloat, Types.readNullableFloat(buffer, offset), 0);
}
@Test
public void testNullableVariableBlob()
{
String someString = "hello";
byte[] stringBytes = StringUtils.toUtf8(someString);
int offset = 0;
int bytesWritten = Types.writeNullableVariableBlob(buffer, offset, stringBytes);
Assert.assertEquals(1 + Integer.BYTES + stringBytes.length, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertArrayEquals(stringBytes, Types.readNullableVariableBlob(buffer, offset));
// test non-zero offset
offset = 1024;
bytesWritten = Types.writeNullableVariableBlob(buffer, offset, stringBytes);
Assert.assertEquals(1 + Integer.BYTES + stringBytes.length, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertArrayEquals(stringBytes, Types.readNullableVariableBlob(buffer, offset));
// test null
bytesWritten = Types.writeNullableVariableBlob(buffer, offset, null);
Assert.assertEquals(1, bytesWritten);
Assert.assertTrue(Types.isNullableNull(buffer, offset));
}
@Test
public void testNullableVariableBlobTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage("Unable to serialize [STRING], size [10] is larger than max [5]");
String someString = "hello";
byte[] stringBytes = StringUtils.toUtf8(someString);
int offset = 0;
Types.writeNullableVariableBlob(buffer, offset, stringBytes, ColumnType.STRING, stringBytes.length);
}
@Test
public void testArrays()
{
final Long[] longArray = new Long[]{1L, 1234567L, null, 10L};
final Double[] doubleArray = new Double[]{1.23, 4.567, null, 8.9};
final String[] stringArray = new String[]{"hello", "world", null, ""};
int bytesWritten;
int offset = 0;
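// 33 bytes per long/double array = 1 (not null) + 4 (length) + 3 non-null elements * (1 + 8) + 1 null marker
// 31 bytes for the string array = 1 (not null) + 4 (length) + nullable variable blobs for "hello", "world", null, ""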
bytesWritten = Types.writeNullableLongArray(buffer, offset, longArray, buffer.limit());
Assert.assertEquals(33, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertArrayEquals(longArray, Types.readNullableLongArray(buffer, offset));
bytesWritten = Types.writeNullableDoubleArray(buffer, offset, doubleArray, buffer.limit());
Assert.assertEquals(33, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertArrayEquals(doubleArray, Types.readNullableDoubleArray(buffer, offset));
bytesWritten = Types.writeNullableStringArray(buffer, offset, stringArray, buffer.limit());
Assert.assertEquals(31, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertArrayEquals(stringArray, Types.readNullableStringArray(buffer, offset));
offset = 1024;
bytesWritten = Types.writeNullableLongArray(buffer, offset, longArray, buffer.limit());
Assert.assertEquals(33, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertArrayEquals(longArray, Types.readNullableLongArray(buffer, offset));
bytesWritten = Types.writeNullableDoubleArray(buffer, offset, doubleArray, buffer.limit());
Assert.assertEquals(33, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertArrayEquals(doubleArray, Types.readNullableDoubleArray(buffer, offset));
bytesWritten = Types.writeNullableStringArray(buffer, offset, stringArray, buffer.limit());
Assert.assertEquals(31, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertArrayEquals(stringArray, Types.readNullableStringArray(buffer, offset));
// test nulls
bytesWritten = Types.writeNullableLongArray(buffer, offset, null, buffer.limit());
Assert.assertEquals(1, bytesWritten);
Assert.assertTrue(Types.isNullableNull(buffer, offset));
bytesWritten = Types.writeNullableDoubleArray(buffer, offset, null, buffer.limit());
Assert.assertEquals(1, bytesWritten);
Assert.assertTrue(Types.isNullableNull(buffer, offset));
bytesWritten = Types.writeNullableStringArray(buffer, offset, null, buffer.limit());
Assert.assertEquals(1, bytesWritten);
Assert.assertTrue(Types.isNullableNull(buffer, offset));
}
@Test
public void testLongArrayTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage("Unable to serialize [ARRAY<LONG>], size [14] is larger than max [10]");
final Long[] longArray = new Long[]{1L, 1234567L, null, 10L};
Types.writeNullableLongArray(buffer, 0, longArray, 10);
}
@Test
public void testDoubleArrayTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage("Unable to serialize [ARRAY<DOUBLE>], size [14] is larger than max [10]");
final Double[] doubleArray = new Double[]{1.23, 4.567, null, 8.9};
Types.writeNullableDoubleArray(buffer, 0, doubleArray, 10);
}
@Test
public void testStringArrayTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage("Unable to serialize [ARRAY<STRING>], size [15] is larger than max [10]");
final String[] stringArray = new String[]{"hello", "world", null, ""};
Types.writeNullableStringArray(buffer, 0, stringArray, 10);
}
@Test
public void testComplex()
{
NullableLongPair lp1 = new NullableLongPair(null, 1L);
NullableLongPair lp2 = new NullableLongPair(1234L, 5678L);
NullableLongPair lp3 = new NullableLongPair(1234L, null);
int bytesWritten;
int offset = 0;
bytesWritten = Types.writeNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE, lp1, buffer.limit());
// 1 (not null) + 4 (length) + 1 (null) + 0 (lhs) + 1 (not null) + 8 (rhs)
Assert.assertEquals(15, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertEquals(lp1, Types.readNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE));
// 1 (not null) + 4 (length) + 1 (not null) + 8 (lhs) + 1 (not null) + 8 (rhs)
bytesWritten = Types.writeNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE, lp2, buffer.limit());
Assert.assertEquals(23, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertEquals(lp2, Types.readNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE));
// 1 (not null) + 4 (length) + 1 (not null) + 8 (lhs) + 1 (null) + 0 (rhs)
bytesWritten = Types.writeNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE, lp3, buffer.limit());
Assert.assertEquals(15, bytesWritten);
Assert.assertFalse(Types.isNullableNull(buffer, offset));
Assert.assertEquals(lp3, Types.readNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE));
}
@Test
public void testComplexTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage("Unable to serialize [COMPLEX<nullableLongPair>], size [23] is larger than max [10]");
Types.writeNullableComplexType(
buffer,
0,
NULLABLE_TEST_PAIR_TYPE,
new NullableLongPair(1234L, 5678L),
10
);
}
public static class PairObjectByteStrategy implements ObjectByteStrategy<NullableLongPair>
{
@Override
public Class<? extends NullableLongPair> getClazz()
{
return NullableLongPair.class;
}
@Nullable
@Override
public NullableLongPair fromByteBuffer(ByteBuffer buffer, int numBytes)
{
int position = buffer.position();
Long lhs = null;
Long rhs = null;
if (!Types.isNullableNull(buffer, position)) {
lhs = Types.readNullableLong(buffer, position);
position += 1 + Long.BYTES;
} else {
position++;
}
if (!Types.isNullableNull(buffer, position)) {
rhs = Types.readNullableLong(buffer, position);
}
return new NullableLongPair(lhs, rhs);
}
@Nullable
@Override
public byte[] toBytes(@Nullable NullableLongPair val)
{
byte[] bytes = new byte[1 + Long.BYTES + 1 + Long.BYTES];
ByteBuffer buffer = ByteBuffer.wrap(bytes);
int position = 0;
if (val != null) {
if (val.lhs != null) {
position += Types.writeNullableLong(buffer, position, val.lhs);
} else {
position += Types.writeNull(buffer, position);
}
if (val.rhs != null) {
position += Types.writeNullableLong(buffer, position, val.rhs);
} else {
position += Types.writeNull(buffer, position);
}
return Arrays.copyOfRange(bytes, 0, position);
} else {
return null;
}
}
@Override
public int compare(NullableLongPair o1, NullableLongPair o2)
{
return Comparators.<NullableLongPair>naturalNullsFirst().compare(o1, o2);
}
}
public static class NullableLongPair extends Pair<Long, Long> implements Comparable<NullableLongPair>
{
public NullableLongPair(@Nullable Long lhs, @Nullable Long rhs)
{
super(lhs, rhs);
}
@Override
public int compareTo(NullableLongPair o)
{
return Comparators.<Long>naturalNullsFirst().thenComparing(Longs::compare).compare(this.lhs, o.lhs);
}
}
}

View File

@ -20,10 +20,12 @@
package org.apache.druid.testing;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.ExpressionProcessing;
public class InitializedNullHandlingTest
{
static {
NullHandling.initializeForTests();
ExpressionProcessing.initializeForTests(null);
}
}

View File

@ -23,7 +23,7 @@ import com.fasterxml.jackson.databind.Module;
import com.google.inject.Binder;
import org.apache.druid.initialization.DruidModule;
import org.apache.druid.query.aggregation.bloom.sql.BloomFilterSqlAggregator;
import org.apache.druid.query.expressions.BloomFilterExprMacro;
import org.apache.druid.query.expressions.BloomFilterExpressions;
import org.apache.druid.query.filter.sql.BloomFilterOperatorConversion;
import org.apache.druid.sql.guice.SqlBindings;
@ -44,6 +44,8 @@ public class BloomFilterExtensionModule implements DruidModule
{
SqlBindings.addOperatorConversion(binder, BloomFilterOperatorConversion.class);
SqlBindings.addAggregator(binder, BloomFilterSqlAggregator.class);
ExpressionModule.addExprMacro(binder, BloomFilterExprMacro.class);
ExpressionModule.addExprMacro(binder, BloomFilterExpressions.CreateExprMacro.class);
ExpressionModule.addExprMacro(binder, BloomFilterExpressions.AddExprMacro.class);
ExpressionModule.addExprMacro(binder, BloomFilterExpressions.TestExprMacro.class);
}
}

View File

@ -52,7 +52,7 @@ public class BloomFilterAggregatorFactory extends AggregatorFactory
public static final ColumnType TYPE = ColumnType.ofComplex(BloomFilterSerializersModule.BLOOM_FILTER_TYPE_NAME);
private static final int DEFAULT_NUM_ENTRIES = 1500;
private static final Comparator COMPARATOR = Comparator.nullsFirst((o1, o2) -> {
public static final Comparator COMPARATOR = Comparator.nullsFirst((o1, o2) -> {
if (o1 instanceof ByteBuffer && o2 instanceof ByteBuffer) {
ByteBuffer buf1 = (ByteBuffer) o1;
ByteBuffer buf2 = (ByteBuffer) o2;
@ -60,6 +60,13 @@ public class BloomFilterAggregatorFactory extends AggregatorFactory
BloomKFilter.getNumSetBits(buf1, buf1.position()),
BloomKFilter.getNumSetBits(buf2, buf2.position())
);
} else if (o1 instanceof BloomKFilter && o2 instanceof BloomKFilter) {
BloomKFilter f1 = (BloomKFilter) o1;
BloomKFilter f2 = (BloomKFilter) o2;
return Integer.compare(
f1.getNumSetBits(),
f2.getNumSetBits()
);
} else {
throw new RE("Unable to compare unexpected types [%s]", o1.getClass().getName());
}

View File

@ -28,6 +28,8 @@ import org.apache.druid.segment.serde.ComplexMetricExtractor;
import org.apache.druid.segment.serde.ComplexMetricSerde;
import org.apache.druid.segment.writeout.SegmentWriteOutMedium;
import javax.annotation.Nullable;
import java.io.IOException;
import java.nio.ByteBuffer;
/**
@ -37,6 +39,8 @@ import java.nio.ByteBuffer;
*/
public class BloomFilterSerde extends ComplexMetricSerde
{
private static final BloomFilterObjectStrategy STRATEGY = new BloomFilterObjectStrategy();
@Override
public String getTypeName()
{
@ -64,6 +68,45 @@ public class BloomFilterSerde extends ComplexMetricSerde
@Override
public ObjectStrategy<BloomKFilter> getObjectStrategy()
{
throw new UnsupportedOperationException("Bloom filter aggregators are query-time only");
return STRATEGY;
}
private static class BloomFilterObjectStrategy implements ObjectStrategy<BloomKFilter>
{
@Override
public Class<? extends BloomKFilter> getClazz()
{
return BloomKFilter.class;
}
@Nullable
@Override
public BloomKFilter fromByteBuffer(ByteBuffer buffer, int numBytes)
{
try {
return BloomKFilter.deserialize(buffer, buffer.position());
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
@Nullable
@Override
public byte[] toBytes(@Nullable BloomKFilter val)
{
try {
return BloomFilterSerializersModule.bloomKFilterToBytes(val);
}
catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public int compare(BloomKFilter o1, BloomKFilter o2)
{
return BloomFilterAggregatorFactory.COMPARATOR.compare(o1, o2);
}
}
}
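A hedged round-trip sketch, not part of this diff, of what the new object strategy enables; the expected-entries value is illustrative and the package names are assumed from the surrounding modules:

import java.nio.ByteBuffer;
import org.apache.druid.query.aggregation.bloom.BloomFilterSerde;
import org.apache.druid.query.filter.BloomKFilter;
import org.apache.druid.segment.data.ObjectStrategy;

class BloomSerdeSketch
{
  static BloomKFilter roundTrip()
  {
    BloomKFilter filter = new BloomKFilter(100);
    filter.addString("foo");
    ObjectStrategy<BloomKFilter> strategy = new BloomFilterSerde().getObjectStrategy();
    byte[] bytes = strategy.toBytes(filter);
    // deserialization reads from the buffer's current position
    return strategy.fromByteBuffer(ByteBuffer.wrap(bytes), bytes.length);
  }
}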

View File

@ -1,138 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.expressions;
import org.apache.druid.guice.BloomFilterSerializersModule;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.filter.BloomKFilter;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.List;
public class BloomFilterExprMacro implements ExprMacroTable.ExprMacro
{
public static final String FN_NAME = "bloom_filter_test";
@Override
public String name()
{
return FN_NAME;
}
@Override
public Expr apply(List<Expr> args)
{
if (args.size() != 2) {
throw new IAE("Function[%s] must have 2 arguments", name());
}
final Expr arg = args.get(0);
final Expr filterExpr = args.get(1);
if (!filterExpr.isLiteral() || filterExpr.getLiteralValue() == null) {
throw new IAE("Function[%s] second argument must be a base64 serialized bloom filter", name());
}
final String serializedFilter = filterExpr.getLiteralValue().toString();
final byte[] decoded = StringUtils.decodeBase64String(serializedFilter);
BloomKFilter filter;
try {
filter = BloomFilterSerializersModule.bloomKFilterFromBytes(decoded);
}
catch (IOException ioe) {
throw new RuntimeException("Failed to deserialize bloom filter", ioe);
}
class BloomExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
{
private BloomExpr(Expr arg)
{
super(FN_NAME, arg);
}
@Nonnull
@Override
public ExprEval eval(final ObjectBinding bindings)
{
ExprEval evaluated = arg.eval(bindings);
boolean matches = false;
switch (evaluated.type().getType()) {
case STRING:
String stringVal = (String) evaluated.value();
if (stringVal == null) {
matches = nullMatch();
} else {
matches = filter.testString(stringVal);
}
break;
case DOUBLE:
Double doubleVal = (Double) evaluated.value();
if (doubleVal == null) {
matches = nullMatch();
} else {
matches = filter.testDouble(doubleVal);
}
break;
case LONG:
Long longVal = (Long) evaluated.value();
if (longVal == null) {
matches = nullMatch();
} else {
matches = filter.testLong(longVal);
}
break;
}
return ExprEval.ofLongBoolean(matches);
}
private boolean nullMatch()
{
return filter.testBytes(null, 0, 0);
}
@Override
public Expr visit(Shuttle shuttle)
{
Expr newArg = arg.visit(shuttle);
return shuttle.visit(new BloomExpr(newArg));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return ExpressionType.LONG;
}
}
return new BloomExpr(arg);
}
}

View File

@ -0,0 +1,366 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.expressions;
import org.apache.druid.guice.BloomFilterSerializersModule;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.query.aggregation.bloom.BloomFilterAggregatorFactory;
import org.apache.druid.query.filter.BloomKFilter;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
public class BloomFilterExpressions
{
public static final ExpressionType BLOOM_FILTER_TYPE = ExpressionType.fromColumnTypeStrict(
BloomFilterAggregatorFactory.TYPE
);
public static class CreateExprMacro implements ExprMacroTable.ExprMacro
{
public static final String FN_NAME = "bloom_filter";
@Override
public String name()
{
return FN_NAME;
}
@Override
public Expr apply(List<Expr> args)
{
if (args.size() != 1) {
throw new IAE("Function[%s] must have 1 argument", name());
}
final Expr expectedSizeArg = args.get(0);
if (!expectedSizeArg.isLiteral() || expectedSizeArg.getLiteralValue() == null) {
throw new IAE("Function[%s] argument must be an LONG constant", name());
}
class BloomExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
{
final int expectedSize;
public BloomExpr(Expr arg)
{
super(FN_NAME, arg);
this.expectedSize = arg.eval(InputBindings.nilBindings()).asInt();
}
@Override
public ExprEval eval(ObjectBinding bindings)
{
return ExprEval.ofComplex(
BLOOM_FILTER_TYPE,
new BloomKFilter(expectedSize)
);
}
@Override
public Expr visit(Shuttle shuttle)
{
return shuttle.visit(this);
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return BLOOM_FILTER_TYPE;
}
}
return new BloomExpr(expectedSizeArg);
}
}
public static class AddExprMacro implements ExprMacroTable.ExprMacro
{
public static final String FN_NAME = "bloom_filter_add";
@Override
public String name()
{
return FN_NAME;
}
@Override
public Expr apply(List<Expr> args)
{
if (args.size() != 2) {
throw new IAE("Function[%s] must have 2 arguments", name());
}
class BloomExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
{
private BloomExpr(List<Expr> args)
{
super(FN_NAME, args);
}
@Override
public ExprEval eval(final ObjectBinding bindings)
{
ExprEval bloomy = args.get(1).eval(bindings);
// be permissive for now, we can count more on this later when we are better at retaining complete complex
// type information everywhere
if (!bloomy.type().equals(BLOOM_FILTER_TYPE) &&
!(bloomy.type().is(ExprType.COMPLEX) && bloomy.value() instanceof BloomKFilter)) {
throw new IAE("Function[%s] must take a bloom filter as the second argument", FN_NAME);
}
BloomKFilter filter = (BloomKFilter) bloomy.value();
assert filter != null;
ExprEval input = args.get(0).eval(bindings);
if (input.value() == null) {
filter.addBytes(null, 0, 0);
} else {
switch (input.type().getType()) {
case STRING:
filter.addString(input.asString());
break;
case DOUBLE:
filter.addDouble(input.asDouble());
break;
case LONG:
filter.addLong(input.asLong());
break;
case COMPLEX:
if (BLOOM_FILTER_TYPE.equals(input.type()) || (input.type().is(ExprType.COMPLEX) && input.value() instanceof BloomKFilter)) {
filter.merge((BloomKFilter) input.value());
break;
}
default:
throw new IAE("Function[%s] cannot add [%s] to a bloom filter", FN_NAME, input.type());
}
}
return ExprEval.ofComplex(BLOOM_FILTER_TYPE, filter);
}
@Override
public Expr visit(Shuttle shuttle)
{
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList());
return shuttle.visit(new BloomExpr(newArgs));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return BLOOM_FILTER_TYPE;
}
}
return new BloomExpr(args);
}
}
public static class TestExprMacro implements ExprMacroTable.ExprMacro
{
public static final String FN_NAME = "bloom_filter_test";
@Override
public String name()
{
return FN_NAME;
}
@Override
public Expr apply(List<Expr> args)
{
if (args.size() != 2) {
throw new IAE("Function[%s] must have 2 arguments", name());
}
class BloomExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
{
private final BloomKFilter filter;
private BloomExpr(BloomKFilter filter, Expr arg)
{
super(FN_NAME, arg);
this.filter = filter;
}
@Nonnull
@Override
public ExprEval eval(final ObjectBinding bindings)
{
ExprEval evaluated = arg.eval(bindings);
boolean matches = false;
switch (evaluated.type().getType()) {
case STRING:
String stringVal = (String) evaluated.value();
if (stringVal == null) {
matches = nullMatch();
} else {
matches = filter.testString(stringVal);
}
break;
case DOUBLE:
Double doubleVal = (Double) evaluated.value();
if (doubleVal == null) {
matches = nullMatch();
} else {
matches = filter.testDouble(doubleVal);
}
break;
case LONG:
Long longVal = (Long) evaluated.value();
if (longVal == null) {
matches = nullMatch();
} else {
matches = filter.testLong(longVal);
}
break;
}
return ExprEval.ofLongBoolean(matches);
}
private boolean nullMatch()
{
return filter.testBytes(null, 0, 0);
}
@Override
public Expr visit(Shuttle shuttle)
{
Expr newArg = arg.visit(shuttle);
return shuttle.visit(new BloomExpr(filter, newArg));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return ExpressionType.LONG;
}
}
class DynamicBloomExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
{
public DynamicBloomExpr(List<Expr> args)
{
super(FN_NAME, args);
}
@Override
public ExprEval eval(final ObjectBinding bindings)
{
ExprEval bloomy = args.get(1).eval(bindings);
// be permissive for now, we can count more on this later when we are better at retaining complete complex
// type information everywhere
if (!bloomy.type().equals(BLOOM_FILTER_TYPE) &&
!(bloomy.type().is(ExprType.COMPLEX) && bloomy.value() instanceof BloomKFilter)) {
throw new IAE("Function[%s] must take a bloom filter as the second argument", FN_NAME);
}
BloomKFilter filter = (BloomKFilter) bloomy.value();
assert filter != null;
ExprEval input = args.get(0).eval(bindings);
boolean matches = false;
switch (input.type().getType()) {
case STRING:
String stringVal = (String) input.value();
if (stringVal == null) {
matches = nullMatch(filter);
} else {
matches = filter.testString(stringVal);
}
break;
case DOUBLE:
Double doubleVal = (Double) input.value();
if (doubleVal == null) {
matches = nullMatch(filter);
} else {
matches = filter.testDouble(doubleVal);
}
break;
case LONG:
Long longVal = (Long) input.value();
if (longVal == null) {
matches = nullMatch(filter);
} else {
matches = filter.testLong(longVal);
}
break;
}
return ExprEval.ofLongBoolean(matches);
}
private boolean nullMatch(BloomKFilter filter)
{
return filter.testBytes(null, 0, 0);
}
@Override
public Expr visit(Shuttle shuttle)
{
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList());
return shuttle.visit(new DynamicBloomExpr(newArgs));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return ExpressionType.LONG;
}
}
final Expr arg = args.get(0);
final Expr filterExpr = args.get(1);
if (filterExpr.isLiteral() && filterExpr.getLiteralValue() instanceof String) {
final String serializedFilter = (String) filterExpr.getLiteralValue();
final byte[] decoded = StringUtils.decodeBase64String(serializedFilter);
BloomKFilter filter;
try {
filter = BloomFilterSerializersModule.bloomKFilterFromBytes(decoded);
}
catch (IOException ioe) {
throw new RuntimeException("Failed to deserialize bloom filter", ioe);
}
return new BloomExpr(filter, arg);
} else {
return new DynamicBloomExpr(args);
}
}
}
}

View File

@ -28,7 +28,7 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.guice.BloomFilterSerializersModule;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.expressions.BloomFilterExprMacro;
import org.apache.druid.query.expressions.BloomFilterExpressions;
import org.apache.druid.query.filter.BloomDimFilter;
import org.apache.druid.query.filter.BloomKFilter;
import org.apache.druid.query.filter.BloomKFilterHolder;
@ -49,14 +49,14 @@ import java.util.List;
public class BloomFilterOperatorConversion extends DirectOperatorConversion
{
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(BloomFilterExprMacro.FN_NAME))
.operatorBuilder(StringUtils.toUpperCase(BloomFilterExpressions.TestExprMacro.FN_NAME))
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER)
.returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE)
.build();
public BloomFilterOperatorConversion()
{
super(SQL_FUNCTION, BloomFilterExprMacro.FN_NAME);
super(SQL_FUNCTION, BloomFilterExpressions.TestExprMacro.FN_NAME);
}
@Override

View File

@ -0,0 +1,226 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.expressions;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.query.filter.BloomKFilter;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
public class BloomFilterExpressionsTest extends InitializedNullHandlingTest
{
private static final String SOME_STRING = "foo";
private static final long SOME_LONG = 1234L;
private static final double SOME_DOUBLE = 1.234;
private static final String[] SOME_STRING_ARRAY = new String[]{"hello", "world"};
private static final Long[] SOME_LONG_ARRAY = new Long[]{1L, 2L, 3L, 4L};
private static final Double[] SOME_DOUBLE_ARRAY = new Double[]{1.2, 3.4};
BloomFilterExpressions.CreateExprMacro createMacro = new BloomFilterExpressions.CreateExprMacro();
BloomFilterExpressions.AddExprMacro addMacro = new BloomFilterExpressions.AddExprMacro();
BloomFilterExpressions.TestExprMacro testMacro = new BloomFilterExpressions.TestExprMacro();
ExprMacroTable macroTable = new ExprMacroTable(ImmutableList.of(createMacro, addMacro, testMacro));
Expr.ObjectBinding inputBindings = InputBindings.withTypedSuppliers(
new ImmutableMap.Builder<String, Pair<ExpressionType, Supplier<Object>>>()
.put("bloomy", new Pair<>(BloomFilterExpressions.BLOOM_FILTER_TYPE, () -> new BloomKFilter(100)))
.put("string", new Pair<>(ExpressionType.STRING, () -> SOME_STRING))
.put("long", new Pair<>(ExpressionType.LONG, () -> SOME_LONG))
.put("double", new Pair<>(ExpressionType.DOUBLE, () -> SOME_DOUBLE))
.put("string_array", new Pair<>(ExpressionType.STRING_ARRAY, () -> SOME_STRING_ARRAY))
.put("long_array", new Pair<>(ExpressionType.LONG_ARRAY, () -> SOME_LONG_ARRAY))
.put("double_array", new Pair<>(ExpressionType.DOUBLE_ARRAY, () -> SOME_DOUBLE_ARRAY))
.build()
);
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Test
public void testCreate()
{
Expr expr = Parser.parse("bloom_filter(100)", macroTable);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof BloomKFilter);
Assert.assertEquals(1024, ((BloomKFilter) eval.value()).getBitSize());
}
@Test
public void testAddString()
{
Expr expr = Parser.parse("bloom_filter_add('foo', bloomy)", macroTable);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof BloomKFilter);
Assert.assertTrue(((BloomKFilter) eval.value()).testString(SOME_STRING));
expr = Parser.parse("bloom_filter_add(string, bloomy)", macroTable);
eval = expr.eval(inputBindings);
Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof BloomKFilter);
Assert.assertTrue(((BloomKFilter) eval.value()).testString(SOME_STRING));
}
@Test
public void testAddLong()
{
Expr expr = Parser.parse("bloom_filter_add(1234, bloomy)", macroTable);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof BloomKFilter);
Assert.assertTrue(((BloomKFilter) eval.value()).testLong(SOME_LONG));
expr = Parser.parse("bloom_filter_add(long, bloomy)", macroTable);
eval = expr.eval(inputBindings);
Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof BloomKFilter);
Assert.assertTrue(((BloomKFilter) eval.value()).testLong(SOME_LONG));
}
@Test
public void testAddDouble()
{
Expr expr = Parser.parse("bloom_filter_add(1.234, bloomy)", macroTable);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof BloomKFilter);
Assert.assertTrue(((BloomKFilter) eval.value()).testDouble(SOME_DOUBLE));
expr = Parser.parse("bloom_filter_add(double, bloomy)", macroTable);
eval = expr.eval(inputBindings);
Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof BloomKFilter);
Assert.assertTrue(((BloomKFilter) eval.value()).testDouble(SOME_DOUBLE));
}
@Test
public void testFilter()
{
Expr expr = Parser.parse("bloom_filter_test(1.234, bloom_filter_add(1.234, bloomy))", macroTable);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(ExpressionType.LONG, eval.type());
Assert.assertTrue(eval.asBoolean());
expr = Parser.parse("bloom_filter_test(1234, bloom_filter_add(1234, bloomy))", macroTable);
eval = expr.eval(inputBindings);
Assert.assertTrue(eval.asBoolean());
expr = Parser.parse("bloom_filter_test('foo', bloom_filter_add('foo', bloomy))", macroTable);
eval = expr.eval(inputBindings);
Assert.assertTrue(eval.asBoolean());
expr = Parser.parse("bloom_filter_test('bar', bloom_filter_add('foo', bloomy))", macroTable);
eval = expr.eval(inputBindings);
Assert.assertFalse(eval.asBoolean());
expr = Parser.parse("bloom_filter_test(1234, bloom_filter_add('foo', bloomy))", macroTable);
eval = expr.eval(inputBindings);
Assert.assertFalse(eval.asBoolean());
expr = Parser.parse("bloom_filter_test(1.23, bloom_filter_add('foo', bloomy))", macroTable);
eval = expr.eval(inputBindings);
Assert.assertFalse(eval.asBoolean());
expr = Parser.parse("bloom_filter_test(1234, bloom_filter_add(1234, bloom_filter(100)))", macroTable);
eval = expr.eval(inputBindings);
Assert.assertTrue(eval.asBoolean());
expr = Parser.parse("bloom_filter_test(4321, bloom_filter_add(1234, bloom_filter(100)))", macroTable);
eval = expr.eval(inputBindings);
Assert.assertFalse(eval.asBoolean());
expr = Parser.parse("bloom_filter_test(4321, bloom_filter_add(bloom_filter_add(1234, bloom_filter(100)), bloom_filter_add(4321, bloom_filter(100))))", macroTable);
eval = expr.eval(inputBindings);
Assert.assertTrue(eval.asBoolean());
}
@Test
public void testCreateWrongArgsCount()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[bloom_filter] must have 1 argument");
Parser.parse("bloom_filter()", macroTable);
}
@Test
public void testAddWrongArgsCount()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[bloom_filter_add] must have 2 arguments");
Parser.parse("bloom_filter_add(1)", macroTable);
}
@Test
public void testAddWrongArgType()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[bloom_filter_add] must take a bloom filter as the second argument");
Parser.parse("bloom_filter_add(1, 2)", macroTable);
}
@Test
public void testAddWrongArgType2()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[bloom_filter_add] cannot add [ARRAY<LONG>] to a bloom filter");
Expr expr = Parser.parse("bloom_filter_add(ARRAY<LONG>[], bloomy)", macroTable);
expr.eval(inputBindings);
}
@Test
public void testTestWrongArgsCount()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[bloom_filter_test] must have 2 arguments");
Parser.parse("bloom_filter_test(1)", macroTable);
}
@Test
public void testTestWrongArgsType()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[bloom_filter_test] must take a bloom filter as the second argument");
Parser.parse("bloom_filter_test(1, 2)", macroTable);
}
}

View File

@ -31,7 +31,7 @@ import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.Druids;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.expression.LookupExprMacro;
import org.apache.druid.query.expressions.BloomFilterExprMacro;
import org.apache.druid.query.expressions.BloomFilterExpressions;
import org.apache.druid.query.filter.BloomDimFilter;
import org.apache.druid.query.filter.BloomKFilter;
import org.apache.druid.query.filter.BloomKFilterHolder;
@ -74,7 +74,7 @@ public class BloomDimFilterSqlTest extends BaseCalciteQueryTest
exprMacros.add(CalciteTests.INJECTOR.getInstance(clazz));
}
exprMacros.add(CalciteTests.INJECTOR.getInstance(LookupExprMacro.class));
exprMacros.add(new BloomFilterExprMacro());
exprMacros.add(new BloomFilterExpressions.TestExprMacro());
return new ExprMacroTable(exprMacros);
}

View File

@ -21,28 +21,17 @@ package org.apache.druid.query.expressions;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.Expr.ObjectBinding;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
import javax.annotation.Nullable;
import java.util.Collections;
public class SleepExprTest extends InitializedNullHandlingTest
{
private final ObjectBinding bindings = new ObjectBinding()
{
@Nullable
@Override
public Object get(String name)
{
return null;
}
};
private final ExprMacroTable exprMacroTable = new ExprMacroTable(Collections.singletonList(new SleepExprMacro()));
@Test
@ -66,7 +55,7 @@ public class SleepExprTest extends InitializedNullHandlingTest
final long detla = 50;
final long before = System.currentTimeMillis();
final Expr expr = Parser.parse(expression, exprMacroTable);
expr.eval(bindings).value();
expr.eval(InputBindings.nilBindings()).value();
final long after = System.currentTimeMillis();
final long elapsed = after - before;
Assert.assertTrue(
@ -79,14 +68,14 @@ public class SleepExprTest extends InitializedNullHandlingTest
private void assertExpr(final String expression)
{
final Expr expr = Parser.parse(expression, exprMacroTable);
Assert.assertNull(expression, expr.eval(bindings).value());
Assert.assertNull(expression, expr.eval(InputBindings.nilBindings()).value());
final Expr exprNoFlatten = Parser.parse(expression, exprMacroTable, false);
final Expr roundTrip = Parser.parse(exprNoFlatten.stringify(), exprMacroTable);
Assert.assertNull(expr.stringify(), roundTrip.eval(bindings).value());
Assert.assertNull(expr.stringify(), roundTrip.eval(InputBindings.nilBindings()).value());
final Expr roundTripFlatten = Parser.parse(expr.stringify(), exprMacroTable);
Assert.assertNull(expr.stringify(), roundTripFlatten.eval(bindings).value());
Assert.assertNull(expr.stringify(), roundTripFlatten.eval(InputBindings.nilBindings()).value());
Assert.assertEquals(expr.stringify(), roundTrip.stringify());
Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify());

View File

@ -39,6 +39,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.segment.DimensionHandlerUtils;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndex;
import org.apache.druid.segment.serde.ComplexMetricSerde;
@ -330,7 +331,7 @@ public class InputRowSerde
writeString(k, out);
try (Aggregator agg = aggFactory.factorize(
IncrementalIndex.makeColumnSelectorFactory(VirtualColumns.EMPTY, aggFactory, supplier, true)
IncrementalIndex.makeColumnSelectorFactory(RowSignature::empty, VirtualColumns.EMPTY, aggFactory, supplier, true)
)) {
try {
agg.aggregate();

View File

@ -770,11 +770,6 @@
<groupId>it.unimi.dsi</groupId>
<artifactId>fastutil-core</artifactId>
<version>${fastutil.version}</version>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
<artifactId>fastutil-extra</artifactId>
<version>${fastutil.version}</version>
</dependency>
<dependency>
<groupId>com.opencsv</groupId>

View File

@ -24,6 +24,7 @@ import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Module;
import org.apache.druid.jackson.JacksonModule;
import org.apache.druid.math.expr.ExpressionProcessingModule;
import java.util.ArrayList;
import java.util.Arrays;
@ -43,6 +44,7 @@ public class GuiceInjectors
new RuntimeInfoModule(),
new ConfigModule(),
new NullHandlingModule(),
new ExpressionProcessingModule(),
binder -> {
binder.bind(DruidSecondaryModule.class);
JsonConfigProvider.bind(binder, "druid.extensions", ExtensionsConfig.class);

View File

@ -19,10 +19,17 @@
package org.apache.druid.query.aggregation;
import com.google.common.annotations.VisibleForTesting;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.column.Types;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.Objects;
public class ExpressionLambdaAggregator implements Aggregator
{
@ -48,7 +55,7 @@ public class ExpressionLambdaAggregator implements Aggregator
public void aggregate()
{
final ExprEval<?> eval = lambda.eval(bindings);
ExprEval.estimateAndCheckMaxBytes(eval, maxSizeBytes);
estimateAndCheckMaxBytes(eval, maxSizeBytes);
bindings.accumulate(eval);
hasValue = true;
}
@ -89,4 +96,88 @@ public class ExpressionLambdaAggregator implements Aggregator
{
// nothing to close
}
/**
* Tries to mimic the byte serialization that the {@link Types} binary methods use to write expression values for the
* {@link ExpressionLambdaBufferAggregator}, in an attempt to provide consistent size limits when using the
* heap-based algorithm.
*/
@VisibleForTesting
public static void estimateAndCheckMaxBytes(ExprEval eval, int maxSizeBytes)
{
final int estimated;
switch (eval.type().getType()) {
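// estimates mirror the nullable binary layout written by Types; the leading null byte is added once at the end of this method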
case STRING:
String stringValue = eval.asString();
estimated = Integer.BYTES + (stringValue == null ? 0 : StringUtils.estimatedBinaryLengthAsUTF8(stringValue));
break;
case LONG:
case DOUBLE:
estimated = Long.BYTES;
break;
case ARRAY:
switch (eval.type().getElementType().getType()) {
case STRING:
String[] stringArray = eval.asStringArray();
if (stringArray == null) {
estimated = Integer.BYTES;
} else {
final int elementsSize = Arrays.stream(stringArray)
.filter(Objects::nonNull)
.mapToInt(StringUtils::estimatedBinaryLengthAsUTF8)
.sum();
// since each value is variably sized, there is a null byte, and an integer length per element
estimated = Integer.BYTES + (Integer.BYTES * stringArray.length) + elementsSize;
}
break;
case LONG:
Long[] longArray = eval.asLongArray();
if (longArray == null) {
estimated = Integer.BYTES;
} else {
final int elementsSize = Arrays.stream(longArray)
.filter(Objects::nonNull)
.mapToInt(x -> Long.BYTES)
.sum();
// null byte + length int + byte per element + size per element
estimated = Integer.BYTES + longArray.length + elementsSize;
}
break;
case DOUBLE:
Double[] doubleArray = eval.asDoubleArray();
if (doubleArray == null) {
estimated = Integer.BYTES;
} else {
final int elementsSize = Arrays.stream(doubleArray)
.filter(Objects::nonNull)
.mapToInt(x -> Long.BYTES)
.sum();
// null byte + length int + byte per element + size per element
estimated = Integer.BYTES + doubleArray.length + elementsSize;
}
break;
default:
throw new ISE("Unsupported array type: %s", eval.type());
}
break;
case COMPLEX:
final ObjectByteStrategy strategy = Types.getStrategy(eval.type().getComplexTypeName());
if (strategy != null) {
if (eval.value() != null) {
// | null (byte) | length (int) | complex type bytes |
final byte[] complexBytes = strategy.toBytes(eval.value());
estimated = Integer.BYTES + complexBytes.length;
} else {
estimated = Integer.BYTES;
}
} else {
throw new ISE("Unsupported type: %s", eval.type());
}
break;
default:
throw new ISE("Unsupported type: %s", eval.type());
}
// +1 for the null byte
Types.checkMaxBytes(eval.type(), 1 + estimated, maxSizeBytes);
}
}
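A minimal sketch, not part of this change, of how the new heap-side estimate is expected to behave; the value and size limits here are illustrative:

import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.query.aggregation.ExpressionLambdaAggregator;

class EstimateSketch
{
  static void demo()
  {
    // string estimate: 1 (null byte) + 4 (length) + estimated utf8 bytes
    ExprEval<?> eval = ExprEval.of("hello world");
    // a generous limit passes quietly
    ExpressionLambdaAggregator.estimateAndCheckMaxBytes(eval, 1024);
    // a 4 byte limit should trip the ISE thrown by Types.checkMaxBytes
    ExpressionLambdaAggregator.estimateAndCheckMaxBytes(eval, 4);
  }
}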

View File

@ -39,12 +39,12 @@ import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.math.expr.SettableObjectBinding;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.virtual.ExpressionPlan;
import org.apache.druid.segment.virtual.ExpressionPlanner;
import org.apache.druid.segment.virtual.ExpressionSelectors;
@ -92,12 +92,9 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
private final Supplier<Expr> finalizeExpression;
private final HumanReadableBytes maxSizeBytes;
private final Supplier<SettableObjectBinding> compareBindings =
Suppliers.memoize(() -> new SettableObjectBinding(2));
private final Supplier<SettableObjectBinding> combineBindings =
Suppliers.memoize(() -> new SettableObjectBinding(2));
private final Supplier<SettableObjectBinding> finalizeBindings =
Suppliers.memoize(() -> new SettableObjectBinding(1));
private final Supplier<SettableObjectBinding> compareBindings;
private final Supplier<SettableObjectBinding> combineBindings;
private final Supplier<SettableObjectBinding> finalizeBindings;
private final Supplier<Expr.InputBindingInspector> finalizeInspector;
@JsonCreator
@ -145,12 +142,12 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
this.initialValue = Suppliers.memoize(() -> {
Expr parsed = Parser.parse(initialValue, macroTable);
Preconditions.checkArgument(parsed.isLiteral(), "initial value must be constant");
return parsed.eval(ExprUtils.nilBindings());
return parsed.eval(InputBindings.nilBindings());
});
this.initialCombineValue = Suppliers.memoize(() -> {
Expr parsed = Parser.parse(this.initialCombineValueExpressionString, macroTable);
Preconditions.checkArgument(parsed.isLiteral(), "initial combining value must be constant");
return parsed.eval(ExprUtils.nilBindings());
return parsed.eval(InputBindings.nilBindings());
});
this.foldExpression = Parser.lazyParse(foldExpressionString, macroTable);
this.combineExpression = Parser.lazyParse(combineExpressionString, macroTable);
@ -160,6 +157,29 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
ImmutableMap.of(FINALIZE_IDENTIFIER, this.initialCombineValue.get().type())
)
);
this.compareBindings = Suppliers.memoize(
() -> new SettableObjectBinding(2).withInspector(
InputBindings.inspectorFromTypeMap(
ImmutableMap.of(
COMPARE_O1, this.initialCombineValue.get().type(),
COMPARE_O2, this.initialCombineValue.get().type()
)
)
)
);
this.combineBindings = Suppliers.memoize(
() -> new SettableObjectBinding(2).withInspector(
InputBindings.inspectorFromTypeMap(
ImmutableMap.of(
accumulatorId, this.initialCombineValue.get().type(),
name, this.initialCombineValue.get().type()
)
)
)
);
this.finalizeBindings = Suppliers.memoize(
() -> new SettableObjectBinding(1).withInspector(finalizeInspector.get())
);
this.finalizeExpression = Parser.lazyParse(finalizeExpressionString, macroTable);
this.maxSizeBytes = maxSizeBytes != null ? maxSizeBytes : DEFAULT_MAX_SIZE_BYTES;
Preconditions.checkArgument(this.maxSizeBytes.getBytesInInt() >= MIN_SIZE_BYTES);
@ -285,11 +305,13 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
return (o1, o2) ->
compareExpr.eval(compareBindings.get().withBinding(COMPARE_O1, o1).withBinding(COMPARE_O2, o2)).asInt();
}
switch (initialValue.get().type().getType()) {
switch (initialCombineValue.get().type().getType()) {
case LONG:
return LongSumAggregator.COMPARATOR;
case DOUBLE:
return DoubleSumAggregator.COMPARATOR;
case COMPLEX:
return Types.getStrategy(initialCombineValue.get().type().getComplexTypeName());
default:
return Comparators.naturalNullsFirst();
}

View File

@ -21,6 +21,7 @@ package org.apache.druid.query.aggregation;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import javax.annotation.Nullable;
@ -57,6 +58,16 @@ public class ExpressionLambdaAggregatorInputBindings implements Expr.ObjectBindi
return inputBindings.get(name);
}
@Nullable
@Override
public ExpressionType getType(String name)
{
if (accumlatorIdentifier.equals(name)) {
return accumulator.type();
}
return inputBindings.getType(name);
}
public void accumulate(ExprEval<?> eval)
{
accumulator = eval;

View File

@ -21,7 +21,7 @@ package org.apache.druid.query.aggregation;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionType;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -35,6 +35,7 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator
private final ExpressionLambdaAggregatorInputBindings bindings;
private final int maxSizeBytes;
private final boolean isNullUnlessAggregated;
private final ExpressionType outputType;
public ExpressionLambdaBufferAggregator(
Expr lambda,
@ -46,6 +47,7 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator
{
this.lambda = lambda;
this.initialValue = initialValue;
this.outputType = initialValue.type();
this.bindings = bindings;
this.isNullUnlessAggregated = isNullUnlessAggregated;
this.maxSizeBytes = maxSizeBytes;
@ -64,7 +66,7 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator
@Override
public void aggregate(ByteBuffer buf, int position)
{
ExprEval<?> acc = ExprEval.deserialize(buf, position + 1, getType(buf, position));
ExprEval<?> acc = ExprEval.deserialize(buf, position, outputType);
bindings.setAccumulator(acc);
ExprEval<?> newAcc = lambda.eval(bindings);
ExprEval.serialize(buf, position, newAcc, maxSizeBytes);
@ -79,25 +81,25 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator
if (isNullUnlessAggregated && (buf.get(position) & NOT_AGGREGATED_BIT) != 0) {
return null;
}
return ExprEval.deserialize(buf, position + 1, getType(buf, position)).value();
return ExprEval.deserialize(buf, position, outputType).value();
}
@Override
public float getFloat(ByteBuffer buf, int position)
{
return (float) ExprEval.deserialize(buf, position + 1, getType(buf, position)).asDouble();
return (float) ExprEval.deserialize(buf, position, outputType).asDouble();
}
@Override
public double getDouble(ByteBuffer buf, int position)
{
return ExprEval.deserialize(buf, position + 1, getType(buf, position)).asDouble();
return ExprEval.deserialize(buf, position, outputType).asDouble();
}
@Override
public long getLong(ByteBuffer buf, int position)
{
return ExprEval.deserialize(buf, position + 1, getType(buf, position)).asLong();
return ExprEval.deserialize(buf, position, outputType).asLong();
}
@Override
@ -105,9 +107,4 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator
{
// nothing to close
}
private static ExprType getType(ByteBuffer buf, int position)
{
return ExprType.fromByte((byte) (buf.get(position) & IS_AGGREGATED_MASK));
}
}

View File

@ -35,13 +35,6 @@ import javax.annotation.Nullable;
public class ExprUtils
{
private static final Expr.ObjectBinding NIL_BINDINGS = name -> null;
public static Expr.ObjectBinding nilBindings()
{
return NIL_BINDINGS;
}
static DateTimeZone toTimeZone(final Expr timeZoneArg)
{
if (!timeZoneArg.isLiteral()) {
@ -110,19 +103,6 @@ public class ExprUtils
static boolean isStringLiteral(final Expr expr)
{
return (expr.isLiteral() && expr.getLiteralValue() instanceof String)
|| (NullHandling.replaceWithDefault() && isNullLiteral(expr));
}
/**
* True if Expr is a null literal.
*
* In non-SQL-compliant null handling mode, this method will return true for either a null literal or an empty string
* literal (they are treated equivalently and we cannot tell the difference).
*
* In SQL-compliant null handling mode, this method will only return true for an actual null literal.
*/
static boolean isNullLiteral(final Expr expr)
{
return expr.isLiteral() && expr.getLiteralValue() == null;
|| (NullHandling.replaceWithDefault() && expr.isNullLiteral());
}
}

View File

@ -0,0 +1,333 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.expression;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.hll.HyperLogLogCollector;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.aggregation.cardinality.CardinalityAggregator;
import org.apache.druid.query.aggregation.cardinality.types.StringCardinalityAggregatorColumnSelectorStrategy;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
public class HyperUniqueExpressions
{
public static final ExpressionType TYPE = ExpressionType.fromColumnType(HyperUniquesAggregatorFactory.TYPE);
public static class HllCreateExprMacro implements ExprMacroTable.ExprMacro
{
private static final String NAME = "hyper_unique";
@Override
public String name()
{
return NAME;
}
@Override
public Expr apply(List<Expr> args)
{
if (args.size() > 0) {
throw new IAE("Function[%s] must have no arguments", name());
}
final HyperLogLogCollector collector = HyperLogLogCollector.makeLatestCollector();
class HllExpression implements ExprMacroTable.ExprMacroFunctionExpr
{
@Override
public ExprEval eval(ObjectBinding bindings)
{
return ExprEval.ofComplex(TYPE, collector);
}
@Override
public String stringify()
{
return StringUtils.format("%s()", NAME);
}
@Override
public Expr visit(Shuttle shuttle)
{
return shuttle.visit(this);
}
@Override
public BindingAnalysis analyzeInputs()
{
return BindingAnalysis.EMTPY;
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return TYPE;
}
@Override
public List<Expr> getArgs()
{
return Collections.emptyList();
}
@Override
public int hashCode()
{
return Objects.hashCode(NAME);
}
@Override
public boolean equals(Object obj)
{
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
return true;
}
@Override
public String toString()
{
return StringUtils.format("(%s)", NAME);
}
}
return new HllExpression();
}
}
public static class HllAddExprMacro implements ExprMacroTable.ExprMacro
{
private static final String NAME = "hyper_unique_add";
@Override
public String name()
{
return NAME;
}
@Override
public Expr apply(List<Expr> args)
{
if (args.size() != 2) {
throw new IAE("Function[%s] must have 2 arguments", name());
}
class HllExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
{
public HllExpr(List<Expr> args)
{
super(NAME, args);
}
@Override
public ExprEval eval(ObjectBinding bindings)
{
ExprEval hllCollector = args.get(1).eval(bindings);
ExpressionType hllType = hllCollector.type();
// be permissive for now, we can count more on this later when we are better at retaining complete complex
// type information everywhere
if (!TYPE.equals(hllType) &&
!(hllType.is(ExprType.COMPLEX) && hllCollector.value() instanceof HyperLogLogCollector)
) {
throw new IAE("Function[%s] must take a hyper-log-log collector as the second argument", NAME);
}
HyperLogLogCollector collector = (HyperLogLogCollector) hllCollector.value();
assert collector != null;
ExprEval input = args.get(0).eval(bindings);
switch (input.type().getType()) {
case STRING:
if (input.value() == null) {
if (NullHandling.replaceWithDefault()) {
collector.add(
CardinalityAggregator.HASH_FUNCTION.hashUnencodedChars(
StringCardinalityAggregatorColumnSelectorStrategy.CARDINALITY_AGG_NULL_STRING
).asBytes()
);
}
} else {
collector.add(CardinalityAggregator.HASH_FUNCTION.hashUnencodedChars(input.asString()).asBytes());
}
break;
case DOUBLE:
if (NullHandling.replaceWithDefault() || !input.isNumericNull()) {
collector.add(CardinalityAggregator.HASH_FUNCTION.hashLong(Double.doubleToLongBits(input.asDouble())).asBytes());
}
break;
case LONG:
if (NullHandling.replaceWithDefault() || !input.isNumericNull()) {
collector.add(CardinalityAggregator.HASH_FUNCTION.hashLong(input.asLong()).asBytes());
}
break;
case COMPLEX:
if (TYPE.equals(input.type()) || (input.type().is(ExprType.COMPLEX) && input.value() instanceof HyperLogLogCollector)) {
collector.fold((HyperLogLogCollector) input.value());
break;
}
default:
throw new IAE("Function[%s] cannot add [%s] to hyper-log-log collector", NAME, input.type());
}
return ExprEval.ofComplex(TYPE, collector);
}
@Override
public Expr visit(Shuttle shuttle)
{
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList());
return shuttle.visit(new HllExpr(newArgs));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return TYPE;
}
}
return new HllExpr(args);
}
}
public static class HllEstimateExprMacro implements ExprMacroTable.ExprMacro
{
public static final String NAME = "hyper_unique_estimate";
@Override
public String name()
{
return NAME;
}
@Override
public Expr apply(List<Expr> args)
{
if (args.size() != 1) {
throw new IAE("Function[%s] must have 1 argument", name());
}
class HllExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
{
public HllExpr(Expr arg)
{
super(NAME, arg);
}
@Override
public ExprEval eval(ObjectBinding bindings)
{
ExprEval hllCollector = args.get(0).eval(bindings);
// be permissive for now, we can count more on this later when we are better at retaining complete complex
// type information everywhere
if (!(TYPE.equals(hllCollector.type()) ||
    (hllCollector.type().is(ExprType.COMPLEX) && hllCollector.value() instanceof HyperLogLogCollector))
) {
throw new IAE("Function[%s] must take a hyper-log-log collector as input", NAME);
}
HyperLogLogCollector collector = (HyperLogLogCollector) hllCollector.value();
assert collector != null;
return ExprEval.ofDouble(collector.estimateCardinality());
}
@Override
public Expr visit(Shuttle shuttle)
{
Expr newArg = arg.visit(shuttle);
return shuttle.visit(new HllExpr(newArg));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return ExpressionType.DOUBLE;
}
}
return new HllExpr(args.get(0));
}
}
public static class HllRoundEstimateExprMacro implements ExprMacroTable.ExprMacro
{
public static final String NAME = "hyper_unique_round_estimate";
@Override
public String name()
{
return NAME;
}
@Override
public Expr apply(List<Expr> args)
{
if (args.size() != 1) {
throw new IAE("Function[%s] must have 1 argument", name());
}
class HllExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
{
public HllExpr(Expr arg)
{
super(NAME, arg);
}
@Override
public ExprEval eval(ObjectBinding bindings)
{
ExprEval hllCollector = args.get(0).eval(bindings);
if (!hllCollector.type().equals(TYPE)) {
throw new IAE("Function[%s] must take a hyper-log-log collector as input", NAME);
}
HyperLogLogCollector collector = (HyperLogLogCollector) hllCollector.value();
assert collector != null;
return ExprEval.ofLong(collector.estimateCardinalityRound());
}
@Override
public Expr visit(Shuttle shuttle)
{
Expr newArg = arg.visit(shuttle);
return shuttle.visit(new HllExpr(newArg));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return ExpressionType.LONG;
}
}
return new HllExpr(args.get(0));
}
}
}
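For orientation, here is a minimal sketch (not part of the diff) of how the four new macros compose once registered in an ExprMacroTable. It mirrors the setup exercised by HyperUniqueExpressionsTest further down; the local variable names are illustrative and the imports are assumed to match that test file.
// minimal sketch, assuming the macro/parser setup shown in the tests below
ExprMacroTable macros = new ExprMacroTable(
    ImmutableList.of(
        new HyperUniqueExpressions.HllCreateExprMacro(),
        new HyperUniqueExpressions.HllAddExprMacro(),
        new HyperUniqueExpressions.HllEstimateExprMacro(),
        new HyperUniqueExpressions.HllRoundEstimateExprMacro()
    )
);
// hyper_unique() creates a collector, hyper_unique_add folds a value into it,
// hyper_unique_round_estimate reads the estimate back out as a long
Expr expr = Parser.parse(
    "hyper_unique_round_estimate(hyper_unique_add('foo', hyper_unique()))",
    macros
);
long estimate = expr.eval(InputBindings.nilBindings()).asLong(); // roughly 1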

View File

@ -27,6 +27,7 @@ import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@ -66,7 +67,7 @@ public class TimestampCeilExprMacro implements ExprMacroTable.ExprMacro
TimestampCeilExpr(final List<Expr> args)
{
super(FN_NAME, args);
this.granularity = getGranularity(args, ExprUtils.nilBindings());
this.granularity = getGranularity(args, InputBindings.nilBindings());
}
@Nonnull

View File

@ -25,6 +25,7 @@ import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.vector.CastToTypeVectorProcessor;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.math.expr.vector.LongOutLongInFunctionVectorProcessor;
@ -76,7 +77,7 @@ public class TimestampFloorExprMacro implements ExprMacroTable.ExprMacro
TimestampFloorExpr(final List<Expr> args)
{
super(FN_NAME, args);
this.granularity = computeGranularity(args, ExprUtils.nilBindings());
this.granularity = computeGranularity(args, InputBindings.nilBindings());
}
/**

View File

@ -25,6 +25,7 @@ import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.joda.time.Chronology;
import org.joda.time.Period;
import org.joda.time.chrono.ISOChronology;
@ -90,9 +91,9 @@ public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro
TimestampShiftExpr(final List<Expr> args)
{
super(FN_NAME, args);
period = getPeriod(args, ExprUtils.nilBindings());
chronology = getTimeZone(args, ExprUtils.nilBindings());
step = getStep(args, ExprUtils.nilBindings());
period = getPeriod(args, InputBindings.nilBindings());
chronology = getTimeZone(args, InputBindings.nilBindings());
step = getStep(args, InputBindings.nilBindings());
}
@Nonnull

View File

@ -27,6 +27,7 @@ import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@ -97,7 +98,7 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro
} else {
final Expr charsArg = args.get(1);
if (charsArg.isLiteral()) {
final String charsString = charsArg.eval(ExprUtils.nilBindings()).asString();
final String charsString = charsArg.eval(InputBindings.nilBindings()).asString();
final char[] chars = charsString == null ? EMPTY_CHARS : charsString.toCharArray();
return new TrimStaticCharsExpr(mode, args.get(0), chars, charsArg);
} else {

View File

@ -378,7 +378,7 @@ public class RowBasedGrouperHelper
return RowBasedColumnSelectorFactory.create(
adapter,
supplier::get,
query.getResultRowSignature(),
() -> query.getResultRowSignature(),
false
);
}

View File

@ -230,7 +230,7 @@ public class TimeseriesQueryQueryToolChest extends QueryToolChest<Result<Timeser
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
() -> new MapBasedRow(null, null),
RowSignature.empty(),
() -> RowSignature.builder().addAggregators(aggregatorSpecs).build(),
false
)
);

View File

@ -49,19 +49,19 @@ public class RowBasedColumnSelectorFactory<T> implements ColumnSelectorFactory
{
private final Supplier<T> supplier;
private final RowAdapter<T> adapter;
private final RowSignature rowSignature;
private final Supplier<RowSignature> rowSignatureSupplier;
private final boolean throwParseExceptions;
private RowBasedColumnSelectorFactory(
final Supplier<T> supplier,
final RowAdapter<T> adapter,
final RowSignature rowSignature,
final Supplier<RowSignature> rowSignatureSupplier,
final boolean throwParseExceptions
)
{
this.supplier = supplier;
this.adapter = adapter;
this.rowSignature = Preconditions.checkNotNull(rowSignature, "rowSignature must be nonnull");
this.rowSignatureSupplier = Preconditions.checkNotNull(rowSignatureSupplier, "rowSignature must be nonnull");
this.throwParseExceptions = throwParseExceptions;
}
@ -70,7 +70,7 @@ public class RowBasedColumnSelectorFactory<T> implements ColumnSelectorFactory
*
* @param adapter adapter for these row objects
* @param supplier supplier of row objects
* @param signature will be used for reporting available columns and their capabilities. Note that the this
* @param signatureSupplier will be used for reporting available columns and their capabilities. Note that this
* factory will still allow creation of selectors on any named field in the rows, even if
* it doesn't appear in "rowSignature". (It only needs to be accessible via
* {@link RowAdapter#columnFunction}.) As a result, you can achieve an untyped mode by
@ -81,11 +81,11 @@ public class RowBasedColumnSelectorFactory<T> implements ColumnSelectorFactory
public static <RowType> RowBasedColumnSelectorFactory<RowType> create(
final RowAdapter<RowType> adapter,
final Supplier<RowType> supplier,
final RowSignature signature,
final Supplier<RowSignature> signatureSupplier,
final boolean throwParseExceptions
)
{
return new RowBasedColumnSelectorFactory<>(supplier, adapter, signature, throwParseExceptions);
return new RowBasedColumnSelectorFactory<>(supplier, adapter, signatureSupplier, throwParseExceptions);
}
@Nullable
@ -452,6 +452,6 @@ public class RowBasedColumnSelectorFactory<T> implements ColumnSelectorFactory
@Override
public ColumnCapabilities getColumnCapabilities(String columnName)
{
return getColumnCapabilities(rowSignature, columnName);
return getColumnCapabilities(rowSignatureSupplier.get(), columnName);
}
}
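A brief illustration (not part of the diff) of the new factory signature: the row signature is now read lazily through a supplier, so callers whose schema can change between rows, such as IncrementalIndex later in this patch, always report up-to-date capabilities. The call shape below matches InDimFilterTest further down; "rowMap" is a hypothetical Map of column values.
// minimal sketch of the new create() signature; "rowMap" is hypothetical
RowBasedColumnSelectorFactory<MapBasedRow> factory = RowBasedColumnSelectorFactory.create(
    RowAdapters.standardRow(),
    () -> new MapBasedRow(0, rowMap),
    () -> RowSignature.builder().add("dim", ColumnType.STRING).build(),
    false
);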

View File

@ -66,7 +66,7 @@ public class RowBasedCursor<RowType> implements Cursor
RowBasedColumnSelectorFactory.create(
rowAdapter,
rowWalker::currentRow,
rowSignature,
() -> rowSignature,
false
)
);

View File

@ -20,37 +20,15 @@
package org.apache.druid.segment.data;
import org.apache.druid.guice.annotations.ExtensionPoint;
import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.writeout.WriteOutBytes;
import javax.annotation.Nullable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Comparator;
@ExtensionPoint
public interface ObjectStrategy<T> extends Comparator<T>
public interface ObjectStrategy<T> extends ObjectByteStrategy<T>
{
Class<? extends T> getClazz();
/**
* Convert values from their underlying byte representation.
*
* Implementations of this method <i>may</i> change the given buffer's mark, or limit, and position.
*
* Implementations of this method <i>may not</i> store the given buffer in a field of the "deserialized" object,
* need to use {@link ByteBuffer#slice()}, {@link ByteBuffer#asReadOnlyBuffer()} or {@link ByteBuffer#duplicate()} in
* this case.
*
* @param buffer buffer to read value from
* @param numBytes number of bytes used to store the value, starting at buffer.position()
* @return an object created from the given byte buffer representation
*/
@Nullable
T fromByteBuffer(ByteBuffer buffer, int numBytes);
@Nullable
byte[] toBytes(@Nullable T val);
/**
* Reads 4-bytes numBytes from the given buffer, and then delegates to {@link #fromByteBuffer(ByteBuffer, int)}.
*/

View File

@ -29,8 +29,8 @@ import org.apache.druid.math.expr.Evals;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.query.BitmapResultFactory;
import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.query.filter.BitmapIndexSelector;
import org.apache.druid.query.filter.DruidDoublePredicate;
import org.apache.druid.query.filter.DruidFloatPredicate;
@ -143,7 +143,7 @@ public class ExpressionFilter implements Filter
// or not.
return BooleanVectorValueMatcher.of(
factory.getReadableVectorInspector(),
theExpr.eval(ExprUtils.nilBindings()).asBoolean()
theExpr.eval(InputBindings.nilBindings()).asBoolean()
);
}
@ -181,14 +181,14 @@ public class ExpressionFilter implements Filter
final ExprEval eval = selector.getObject();
if (eval.type().isArray()) {
switch (eval.type().getElementType().getType()) {
switch (eval.elementType().getType()) {
case LONG:
final Long[] lResult = eval.asLongArray();
if (lResult == null) {
return false;
}
return Arrays.stream(lResult).anyMatch(Evals::asBoolean);
return Arrays.stream(lResult).filter(Objects::nonNull).anyMatch(Evals::asBoolean);
case STRING:
final String[] sResult = eval.asStringArray();
if (sResult == null) {
@ -202,7 +202,7 @@ public class ExpressionFilter implements Filter
return false;
}
return Arrays.stream(dResult).anyMatch(Evals::asBoolean);
return Arrays.stream(dResult).filter(Objects::nonNull).anyMatch(Evals::asBoolean);
}
}
return eval.asBoolean();
@ -248,7 +248,7 @@ public class ExpressionFilter implements Filter
{
if (bindingDetails.get().getRequiredBindings().isEmpty()) {
// Constant expression.
if (expr.get().eval(ExprUtils.nilBindings()).asBoolean()) {
if (expr.get().eval(InputBindings.nilBindings()).asBoolean()) {
return bitmapResultFactory.wrapAllTrue(Filters.allTrue(selector));
} else {
return bitmapResultFactory.wrapAllFalse(Filters.allFalse(selector));
@ -263,12 +263,12 @@ public class ExpressionFilter implements Filter
column,
selector,
bitmapResultFactory,
value -> expr.get().eval(identifierName -> {
value -> expr.get().eval(InputBindings.forFunction(identifierName -> {
// There's only one binding, and it must be the single column, so it can safely be ignored in production.
assert column.equals(identifierName);
// convert null to Empty before passing to expressions if needed.
return NullHandling.nullToEmptyIfNeeded(value);
}).asBoolean()
})).asBoolean()
);
}
}

View File

@ -109,6 +109,7 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
* @return column selector factory
*/
public static ColumnSelectorFactory makeColumnSelectorFactory(
final Supplier<RowSignature> rowSignatureSupplier,
final VirtualColumns virtualColumns,
final AggregatorFactory agg,
final Supplier<InputRow> in,
@ -118,7 +119,7 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
final RowBasedColumnSelectorFactory<InputRow> baseSelectorFactory = RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
in::get,
RowSignature.empty(),
rowSignatureSupplier::get,
true
);
@ -264,6 +265,8 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
this.deserializeComplexMetrics = deserializeComplexMetrics;
this.timeAndMetricsColumnCapabilities = new HashMap<>();
this.metricDescs = Maps.newLinkedHashMap();
this.dimensionDescs = Maps.newLinkedHashMap();
this.metadata = new Metadata(
null,
getCombiningAggregators(metrics),
@ -274,7 +277,6 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
initAggs(metrics, rowSupplier, deserializeComplexMetrics, concurrentEventAdd);
this.metricDescs = Maps.newLinkedHashMap();
for (AggregatorFactory metric : metrics) {
MetricDesc metricDesc = new MetricDesc(metricDescs.size(), metric);
metricDescs.put(metricDesc.getName(), metricDesc);
@ -282,7 +284,6 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
}
DimensionsSpec dimensionsSpec = incrementalIndexSchema.getDimensionsSpec();
this.dimensionDescs = Maps.newLinkedHashMap();
this.dimensionDescsList = new ArrayList<>();
for (DimensionSchema dimSchema : dimensionsSpec.getDimensions()) {
@ -986,7 +987,15 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
final boolean deserializeComplexMetrics
)
{
return makeColumnSelectorFactory(virtualColumns, agg, in, deserializeComplexMetrics);
Supplier<RowSignature> signatureSupplier = () -> {
Map<String, ColumnCapabilities> capabilitiesMap = getColumnCapabilities();
RowSignature.Builder bob = RowSignature.builder();
for (Map.Entry<String, ColumnCapabilities> capabilitiesEntry : capabilitiesMap.entrySet()) {
bob.add(capabilitiesEntry.getKey(), capabilitiesEntry.getValue().toColumnType());
}
return bob.build();
};
return makeColumnSelectorFactory(signatureSupplier, virtualColumns, agg, in, deserializeComplexMetrics);
}
protected final Comparator<IncrementalIndexRow> dimsComparator()

View File

@ -24,8 +24,8 @@ import org.apache.druid.java.util.common.Pair;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Exprs;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.query.expression.ExprUtils;
import java.util.ArrayList;
import java.util.Collections;
@ -74,12 +74,12 @@ public class JoinConditionAnalysis
this.nonEquiConditions = Collections.unmodifiableList(nonEquiConditions);
// if any nonEquiCondition is an expression and it evaluates to false
isAlwaysFalse = nonEquiConditions.stream()
.anyMatch(expr -> expr.isLiteral() && !expr.eval(ExprUtils.nilBindings())
.anyMatch(expr -> expr.isLiteral() && !expr.eval(InputBindings.nilBindings())
.asBoolean());
// if there are no equiConditions and all nonEquiConditions are literals and they evaluate to true
isAlwaysTrue = equiConditions.isEmpty() && nonEquiConditions.stream()
.allMatch(expr -> expr.isLiteral() && expr.eval(
ExprUtils.nilBindings()).asBoolean());
InputBindings.nilBindings()).asBoolean());
canHashJoin = nonEquiConditions.stream().allMatch(Expr::isLiteral);
rightKeyColumns = getEquiConditions().stream().map(Equality::getRightColumn).collect(Collectors.toSet());
requiredColumns = computeRequiredColumns(rightPrefix, equiConditions, nonEquiConditions);

View File

@ -20,6 +20,7 @@
package org.apache.druid.segment.serde;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.column.Types;
import javax.annotation.Nullable;
import java.util.concurrent.ConcurrentHashMap;
@ -62,6 +63,7 @@ public class ComplexMetrics
value.getClass().getName()
);
} else {
Types.registerStrategy(type, serde.getObjectStrategy());
return value;
}
}

View File

@ -25,9 +25,12 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import org.apache.druid.data.input.Row;
import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.virtual.ExpressionSelectors;
@ -96,7 +99,9 @@ public class ExpressionTransform implements Transform
@Override
public Object eval(final Row row)
{
return ExpressionSelectors.coerceEvalToSelectorObject(expr.eval(name -> getValueFromRow(row, name)));
return ExpressionSelectors.coerceEvalToSelectorObject(
expr.eval(InputBindings.forFunction(name -> getValueFromRow(row, name)))
);
}
}
@ -107,7 +112,11 @@ public class ExpressionTransform implements Transform
} else {
Object raw = row.getRaw(column);
if (raw instanceof List) {
return ExprEval.coerceListToArray((List) raw, true);
NonnullPair<ExpressionType, Object[]> coerced = ExprEval.coerceListToArray((List) raw, true);
if (coerced == null) {
return null;
}
return coerced.rhs;
}
return raw;
}
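A short note on the helper used above (not part of the diff): ExprEval.coerceListToArray now returns a NonnullPair carrying the coerced type alongside the values, and it may return null, which is why getValueFromRow null-checks it before unwrapping. A minimal sketch of the new contract, with illustrative values:
// the exact type reported in "lhs" is whatever the coercion decided
NonnullPair<ExpressionType, Object[]> coerced = ExprEval.coerceListToArray(ImmutableList.of(1L, 2L), true);
Object[] values = coerced == null ? null : coerced.rhs;
ExpressionType coercedType = coerced == null ? null : coerced.lhs;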

View File

@ -59,7 +59,7 @@ public class Transformer
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
rowSupplierForValueMatcher::get,
RowSignature.empty(),
RowSignature::empty, // sad
false
)
);

View File

@ -24,11 +24,13 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.collect.Iterables;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.query.extraction.ExtractionFn;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseObjectColumnValueSelector;
@ -159,7 +161,7 @@ public class ExpressionSelectors
final Expr.ObjectBinding bindings = createBindings(plan.getAnalysis(), columnSelectorFactory);
// Optimization for constant expressions
if (bindings.equals(ExprUtils.nilBindings())) {
if (bindings.equals(InputBindings.nilBindings())) {
return new ConstantExprEvalSelector(plan.getExpression().eval(bindings));
}
@ -261,11 +263,12 @@ public class ExpressionSelectors
List<String> columns
)
{
final Map<String, Supplier<Object>> suppliers = new HashMap<>();
final Map<String, Pair<ExpressionType, Supplier<Object>>> suppliers = new HashMap<>();
for (String columnName : columns) {
final ColumnCapabilities columnCapabilities = columnSelectorFactory.getColumnCapabilities(columnName);
final boolean multiVal = columnCapabilities != null && columnCapabilities.hasMultipleValues().isTrue();
final Supplier<Object> supplier;
final ExpressionType expressionType = ExpressionType.fromColumnType(columnCapabilities);
if (columnCapabilities == null || columnCapabilities.isArray()) {
// Unknown ValueType or array type. Try making an Object selector and see if that gives us anything useful.
@ -285,30 +288,48 @@ public class ExpressionSelectors
multiVal
);
} else {
// Unhandleable ValueType (COMPLEX).
supplier = null;
// complex type just pass straight through
ColumnValueSelector<?> selector = columnSelectorFactory.makeColumnValueSelector(columnName);
if (!(selector instanceof NilColumnValueSelector)) {
supplier = selector::getObject;
} else {
supplier = null;
}
}
if (supplier != null) {
suppliers.put(columnName, supplier);
suppliers.put(columnName, new Pair<>(expressionType, supplier));
}
}
if (suppliers.isEmpty()) {
return ExprUtils.nilBindings();
return InputBindings.nilBindings();
} else if (suppliers.size() == 1 && columns.size() == 1) {
// If there's only one column (and it has a supplier), we can skip the Map and just use that supplier when
// asked for something.
final String column = Iterables.getOnlyElement(suppliers.keySet());
final Supplier<Object> supplier = Iterables.getOnlyElement(suppliers.values());
final Pair<ExpressionType, Supplier<Object>> supplier = Iterables.getOnlyElement(suppliers.values());
return identifierName -> {
// There's only one binding, and it must be the single column, so it can safely be ignored in production.
assert column.equals(identifierName);
return supplier.get();
return new Expr.ObjectBinding()
{
@Nullable
@Override
public Object get(String name)
{
// There's only one binding, and it must be the single column, so it can safely be ignored in production.
assert column.equals(name);
return supplier.rhs.get();
}
@Nullable
@Override
public ExpressionType getType(String name)
{
return supplier.lhs;
}
};
} else {
return InputBindings.withSuppliers(suppliers);
return InputBindings.withTypedSuppliers(suppliers);
}
}
@ -350,9 +371,9 @@ public class ExpressionSelectors
} else {
// column selector factories hate you and use [] and [null] interchangeably for nullish data
if (row.size() == 0) {
return new String[]{null};
return new Object[]{null};
}
final String[] strings = new String[row.size()];
final Object[] strings = new Object[row.size()];
// noinspection SSBasedInspection
for (int i = 0; i < row.size(); i++) {
strings[i] = selector.lookupName(row.get(i));
@ -382,25 +403,31 @@ public class ExpressionSelectors
// Might be Numbers and Strings. Use a selector that double-checks.
return () -> {
final Object val = selector.getObject();
if (val instanceof Number || val instanceof String || (val != null && val.getClass().isArray())) {
return val;
} else if (val instanceof List) {
return ExprEval.coerceListToArray((List) val, true);
if (val instanceof List) {
NonnullPair<ExpressionType, Object[]> coerced = ExprEval.coerceListToArray((List) val, true);
if (coerced == null) {
return null;
}
return coerced.rhs;
} else {
return null;
return val;
}
};
} else if (clazz.isAssignableFrom(List.class)) {
return () -> {
final Object val = selector.getObject();
if (val != null) {
return ExprEval.coerceListToArray((List) val, true);
NonnullPair<ExpressionType, Object[]> coerced = ExprEval.coerceListToArray((List) val, true);
if (coerced == null) {
return null;
}
return coerced.rhs;
}
return null;
};
} else {
// No numbers or strings.
return null;
// No numbers or strings, just pass it through
return selector::getObject;
}
}
@ -412,16 +439,7 @@ public class ExpressionSelectors
public static Object coerceEvalToSelectorObject(ExprEval eval)
{
if (eval.type().isArray()) {
switch (eval.type().getElementType().getType()) {
case STRING:
return Arrays.stream(eval.asStringArray()).collect(Collectors.toList());
case DOUBLE:
return Arrays.stream(eval.asDoubleArray()).collect(Collectors.toList());
case LONG:
return Arrays.stream(eval.asLongArray()).collect(Collectors.toList());
default:
}
return Arrays.stream(eval.asArray()).collect(Collectors.toList());
}
return eval.value();
}
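For reference (not part of the diff), the binding helpers that replace the raw lambdas above: InputBindings.forFunction wraps an untyped name lookup, while withTypedSuppliers carries an ExpressionType alongside each value supplier so expressions can inspect input types. The map contents below are illustrative only; the shape mirrors HyperUniqueExpressionsTest later in this patch.
// minimal sketch; values are hypothetical
Expr.ObjectBinding untyped = InputBindings.forFunction(name -> "some value");
Expr.ObjectBinding typed = InputBindings.withTypedSuppliers(
    new ImmutableMap.Builder<String, Pair<ExpressionType, Supplier<Object>>>()
        .put("x", new Pair<>(ExpressionType.LONG, () -> 1234L))
        .build()
);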

View File

@ -23,9 +23,9 @@ import com.google.common.base.Preconditions;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.vector.ConstantVectorSelectors;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
@ -52,7 +52,7 @@ public class ExpressionVectorSelectors
// only constant expressions are currently supported, nothing else should get here
if (plan.isConstant()) {
String constant = plan.getExpression().eval(ExprUtils.nilBindings()).asString();
String constant = plan.getExpression().eval(InputBindings.nilBindings()).asString();
return ConstantVectorSelectors.singleValueDimensionVectorSelector(factory.getReadableVectorInspector(), constant);
}
if (plan.is(ExpressionPlan.Trait.SINGLE_INPUT_SCALAR) && (plan.getOutputType() != null && plan.getOutputType().is(ExprType.STRING))) {
@ -75,7 +75,7 @@ public class ExpressionVectorSelectors
if (plan.isConstant()) {
return ConstantVectorSelectors.vectorValueSelector(
factory.getReadableVectorInspector(),
(Number) plan.getExpression().eval(ExprUtils.nilBindings()).value()
(Number) plan.getExpression().eval(InputBindings.nilBindings()).value()
);
}
final Expr.VectorInputBinding bindings = createVectorBindings(plan.getAnalysis(), factory);
@ -94,7 +94,7 @@ public class ExpressionVectorSelectors
if (plan.isConstant()) {
return ConstantVectorSelectors.vectorObjectSelector(
factory.getReadableVectorInspector(),
plan.getExpression().eval(ExprUtils.nilBindings()).value()
plan.getExpression().eval(InputBindings.nilBindings()).value()
);
}

View File

@ -34,7 +34,7 @@ import java.util.stream.Collectors;
* Expression column value selector that examines a set of 'unknown' type input bindings on a row by row basis,
* transforming the expression to handle multi-value list typed inputs as they are encountered.
*
* Currently, string dimensions are the only bindings which might appear as a {@link String} or a {@link String[]}, so
* Currently, string dimensions are the only bindings which might appear as a {@link String} or a {@link Object[]}, so
* numbers are eliminated from the set of 'unknown' bindings to check as they are encountered.
*/
public class RowBasedExpressionColumnValueSelector extends ExpressionColumnValueSelector
@ -94,7 +94,7 @@ public class RowBasedExpressionColumnValueSelector extends ExpressionColumnValue
{
Object binding = bindings.get(x);
if (binding != null) {
if (binding instanceof String[]) {
if (binding instanceof Object[] && ((Object[]) binding).length > 0) {
return true;
} else if (binding instanceof Number) {
ignoredColumns.add(x);

View File

@ -20,6 +20,7 @@
package org.apache.druid.segment.virtual;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExpressionType;
import javax.annotation.Nullable;
@ -28,6 +29,13 @@ public class SingleInputBindings implements Expr.ObjectBinding
@Nullable
private Object value;
private final ExpressionType type;
public SingleInputBindings(ExpressionType type)
{
this.type = type;
}
@Override
public Object get(final String name)
{
@ -38,4 +46,11 @@ public class SingleInputBindings implements Expr.ObjectBinding
{
this.value = value;
}
@Nullable
@Override
public ExpressionType getType(String name)
{
return type;
}
}
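A small illustration (not part of the diff) of the now-typed single-input binding: the selector classes below construct it with the input's ExpressionType, and the binding hands that type back for whatever name is asked.
// minimal sketch; the bound value itself is set by the owning selector before each eval
SingleInputBindings bindings = new SingleInputBindings(ExpressionType.STRING);
ExpressionType type = bindings.getType("anyName"); // ExpressionType.STRING
Object current = bindings.get("anyName");          // whatever value was last set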

View File

@ -24,6 +24,7 @@ import it.unimi.dsi.fastutil.longs.Long2ObjectLinkedOpenHashMap;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.ColumnValueSelector;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
@ -41,7 +42,7 @@ public class SingleLongInputCachingExpressionColumnValueSelector implements Colu
private final ColumnValueSelector selector;
private final Expr expression;
private final SingleInputBindings bindings = new SingleInputBindings();
private final SingleInputBindings bindings = new SingleInputBindings(ExpressionType.LONG);
@Nullable
private final LruEvalCache lruEvalCache;

View File

@ -25,6 +25,8 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectLinkedOpenHashMap;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.DimensionDictionarySelector;
@ -63,7 +65,7 @@ public class SingleStringInputCachingExpressionColumnValueSelector implements Co
this.expression = Preconditions.checkNotNull(expression, "expression");
final Supplier<Object> inputSupplier = ExpressionSelectors.supplierFromDimensionSelector(selector, false);
this.bindings = name -> inputSupplier.get();
this.bindings = InputBindings.singleProvider(ExpressionType.STRING, name -> inputSupplier.get());
if (selector.getValueCardinality() == DimensionDictionarySelector.CARDINALITY_UNKNOWN) {
throw new ISE("Selector must have a dictionary");

View File

@ -23,6 +23,7 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.filter.ValueMatcher;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.DimensionDictionarySelector;
@ -48,7 +49,7 @@ public class SingleStringInputDeferredEvaluationExpressionDimensionSelector impl
{
private final DimensionSelector selector;
private final Expr expression;
private final SingleInputBindings bindings = new SingleInputBindings();
private final SingleInputBindings bindings = new SingleInputBindings(ExpressionType.STRING);
public SingleStringInputDeferredEvaluationExpressionDimensionSelector(
final DimensionSelector selector,

View File

@ -26,6 +26,7 @@ import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.query.expression.TestExprMacroTable;
@ -189,7 +190,7 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
ImmutableSet.of("x"),
null,
"0",
null,
"<LONG>[]",
true,
"array_set_add(__acc, x)",
"array_set_add_all(__acc, expr_agg_name)",
@ -410,6 +411,52 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
Assert.assertEquals(ColumnType.STRING, agg.getFinalizedType());
}
@Test
public void testComplexType()
{
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
"expr_agg_name",
ImmutableSet.of("some_column"),
null,
"hyper_unique()",
null,
null,
"hyper_unique_add(some_column, __acc)",
"hyper_unique_add(__acc, expr_agg_name)",
null,
null,
new HumanReadableBytes(2048),
TestExprMacroTable.INSTANCE
);
Assert.assertEquals(HyperUniquesAggregatorFactory.TYPE, agg.getType());
Assert.assertEquals(HyperUniquesAggregatorFactory.TYPE, agg.getCombiningFactory().getType());
Assert.assertEquals(HyperUniquesAggregatorFactory.TYPE, agg.getFinalizedType());
}
@Test
public void testComplexTypeFinalized()
{
ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
"expr_agg_name",
ImmutableSet.of("some_column"),
null,
"hyper_unique()",
null,
null,
"hyper_unique_add(some_column, __acc)",
"hyper_unique_add(__acc, expr_agg_name)",
null,
"hyper_unique_estimate(o)",
new HumanReadableBytes(2048),
TestExprMacroTable.INSTANCE
);
Assert.assertEquals(HyperUniquesAggregatorFactory.TYPE, agg.getType());
Assert.assertEquals(HyperUniquesAggregatorFactory.TYPE, agg.getCombiningFactory().getType());
Assert.assertEquals(ColumnType.DOUBLE, agg.getFinalizedType());
}
@Test
public void testResultArraySignature()
{
@ -544,6 +591,34 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
"fold((x, acc) -> x + acc, o, 0)",
new HumanReadableBytes(2048),
TestExprMacroTable.INSTANCE
),
new ExpressionLambdaAggregatorFactory(
"complex_expr",
ImmutableSet.of("some_column"),
null,
"hyper_unique()",
null,
null,
"hyper_unique_add(some_column, __acc)",
"hyper_unique_add(__acc, expr_agg_name)",
null,
null,
new HumanReadableBytes(2048),
TestExprMacroTable.INSTANCE
),
new ExpressionLambdaAggregatorFactory(
"complex_expr_finalized",
ImmutableSet.of("some_column"),
null,
"hyper_unique()",
null,
null,
"hyper_unique_add(some_column, __acc)",
"hyper_unique_add(__acc, expr_agg_name)",
null,
"hyper_unique_estimate(o)",
new HumanReadableBytes(2048),
TestExprMacroTable.INSTANCE
)
)
.postAggregators(
@ -552,7 +627,9 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
new FieldAccessPostAggregator("double-array-expr-access", "double_array_expr_finalized"),
new FinalizingFieldAccessPostAggregator("double-array-expr-finalize", "double_array_expr_finalized"),
new FieldAccessPostAggregator("long-array-expr-access", "long_array_expr_finalized"),
new FinalizingFieldAccessPostAggregator("long-array-expr-finalize", "long_array_expr_finalized")
new FinalizingFieldAccessPostAggregator("long-array-expr-finalize", "long_array_expr_finalized"),
new FieldAccessPostAggregator("complex-expr-access", "complex_expr_finalized"),
new FinalizingFieldAccessPostAggregator("complex-expr-finalize", "complex_expr_finalized")
)
.build();
@ -576,6 +653,10 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
.add("double_array_expr_finalized", null)
// long because fold type equals finalized type, even though merge type is array
.add("long_array_expr_finalized", ColumnType.LONG)
.add("complex_expr", HyperUniquesAggregatorFactory.TYPE)
// type does not equal finalized type. (combining factory type does equal finalized type,
// but this signature doesn't use combining factory)
.add("complex_expr_finalized", null)
// fold type is string
.add("string-array-expr-access", ColumnType.STRING)
// finalized type is string
@ -588,6 +669,8 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
.add("long-array-expr-access", ColumnType.LONG)
// finalized type is long
.add("long-array-expr-finalize", ColumnType.LONG)
.add("complex-expr-access", HyperUniquesAggregatorFactory.TYPE)
.add("complex-expr-finalize", ColumnType.DOUBLE)
.build(),
new TimeseriesQueryQueryToolChest().resultArraySignature(query)
);

View File

@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
public class ExpressionLambdaAggregatorTest extends InitializedNullHandlingTest
{
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Test
public void testEstimateString()
{
ExpressionLambdaAggregator.estimateAndCheckMaxBytes(ExprEval.ofType(ExpressionType.STRING, "hello"), 10);
}
@Test
public void testEstimateStringTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage("Unable to serialize [STRING], size [12] is larger than max [5]");
ExpressionLambdaAggregator.estimateAndCheckMaxBytes(ExprEval.ofType(ExpressionType.STRING, "too big"), 5);
}
@Test
public void testEstimateStringArray()
{
ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {"a", "b", "c", "d"}),
30
);
}
@Test
public void testEstimateStringArrayTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage("Unable to serialize [ARRAY<STRING>], size [25] is larger than max [15]");
ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {"a", "b", "c", "d"}),
15
);
}
@Test
public void testEstimateLongArray()
{
ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1L, 2L, 3L, 4L}),
64
);
}
@Test
public void testEstimateLongArrayTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage("Unable to serialize [ARRAY<LONG>], size [41] is larger than max [24]");
ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1L, 2L, 3L, 4L}),
24
);
}
@Test
public void testEstimateDoubleArray()
{
ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2.0, 3.0, 4.0}),
64
);
}
@Test
public void testEstimateDoubleArrayTooBig()
{
expectedException.expect(ISE.class);
expectedException.expectMessage("Unable to serialize [ARRAY<DOUBLE>], size [41] is larger than max [24]");
ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2.0, 3.0, 4.0}),
24
);
}
}

View File

@ -135,7 +135,7 @@ public class CaseInsensitiveExprMacroTest extends MacroTestBase
final ExprEval<?> result = eval(
"icontains_string(a, null)",
InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))
InputBindings.nilBindings()
);
Assert.assertEquals(
ExprEval.ofBoolean(true, ExprType.LONG).value(),
@ -146,7 +146,7 @@ public class CaseInsensitiveExprMacroTest extends MacroTestBase
@Test
public void testEmptyStringSearchOnNull()
{
final ExprEval<?> result = eval("icontains_string(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
final ExprEval<?> result = eval("icontains_string(a, '')", InputBindings.nilBindings());
Assert.assertEquals(
ExprEval.ofBoolean(!NullHandling.sqlCompatible(), ExprType.LONG).value(),
result.value()

View File

@ -123,7 +123,7 @@ public class ContainsExprMacroTest extends MacroTestBase
expectException(IllegalArgumentException.class, "Function[contains_string] substring must be a string literal");
}
final ExprEval<?> result = eval("contains_string(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
final ExprEval<?> result = eval("contains_string(a, null)", InputBindings.nilBindings());
Assert.assertEquals(
ExprEval.ofBoolean(true, ExprType.LONG).value(),
result.value()
@ -133,7 +133,7 @@ public class ContainsExprMacroTest extends MacroTestBase
@Test
public void testEmptyStringSearchOnNull()
{
final ExprEval<?> result = eval("contains_string(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
final ExprEval<?> result = eval("contains_string(a, '')", InputBindings.nilBindings());
Assert.assertEquals(
ExprEval.ofBoolean(!NullHandling.sqlCompatible(), ExprType.LONG).value(),
result.value()

View File

@ -0,0 +1,256 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.expression;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.hll.HyperLogLogCollector;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
public class HyperUniqueExpressionsTest extends InitializedNullHandlingTest
{
private static final ExprMacroTable MACRO_TABLE = new ExprMacroTable(
ImmutableList.of(
new HyperUniqueExpressions.HllCreateExprMacro(),
new HyperUniqueExpressions.HllAddExprMacro(),
new HyperUniqueExpressions.HllEstimateExprMacro(),
new HyperUniqueExpressions.HllRoundEstimateExprMacro()
)
);
private static final String SOME_STRING = "foo";
private static final long SOME_LONG = 1234L;
private static final double SOME_DOUBLE = 1.234;
Expr.ObjectBinding inputBindings = InputBindings.withTypedSuppliers(
new ImmutableMap.Builder<String, Pair<ExpressionType, Supplier<Object>>>()
.put("hll", new Pair<>(HyperUniqueExpressions.TYPE, HyperLogLogCollector::makeLatestCollector))
.put("string", new Pair<>(ExpressionType.STRING, () -> SOME_STRING))
.put("long", new Pair<>(ExpressionType.LONG, () -> SOME_LONG))
.put("double", new Pair<>(ExpressionType.DOUBLE, () -> SOME_DOUBLE))
.put("nullString", new Pair<>(ExpressionType.STRING, () -> null))
.put("nullLong", new Pair<>(ExpressionType.LONG, () -> null))
.put("nullDouble", new Pair<>(ExpressionType.DOUBLE, () -> null))
.build()
);
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Test
public void testCreate()
{
Expr expr = Parser.parse("hyper_unique()", MACRO_TABLE);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(0.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0);
}
@Test
public void testString()
{
Expr expr = Parser.parse("hyper_unique_add('foo', hyper_unique())", MACRO_TABLE);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(1.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
expr = Parser.parse("hyper_unique_add('bar', hyper_unique_add('foo', hyper_unique()))", MACRO_TABLE);
eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(2.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
expr = Parser.parse("hyper_unique_add(string, hyper_unique())", MACRO_TABLE);
eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(1.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
expr = Parser.parse("hyper_unique_add(nullString, hyper_unique())", MACRO_TABLE);
eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 1.0 : 0.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
}
@Test
public void testLong()
{
Expr expr = Parser.parse("hyper_unique_add(1234, hyper_unique())", MACRO_TABLE);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(1.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
expr = Parser.parse("hyper_unique_add(1234, hyper_unique_add(5678, hyper_unique()))", MACRO_TABLE);
eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(2.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
expr = Parser.parse("hyper_unique_add(long, hyper_unique())", MACRO_TABLE);
eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(1.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
expr = Parser.parse("hyper_unique_add(nullLong, hyper_unique())", MACRO_TABLE);
eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 1.0 : 0.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
}
@Test
public void testDouble()
{
Expr expr = Parser.parse("hyper_unique_add(1.234, hyper_unique())", MACRO_TABLE);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(1.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
expr = Parser.parse("hyper_unique_add(1.234, hyper_unique_add(5.678, hyper_unique()))", MACRO_TABLE);
eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(2.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
expr = Parser.parse("hyper_unique_add(double, hyper_unique())", MACRO_TABLE);
eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(1.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
expr = Parser.parse("hyper_unique_add(nullDouble, hyper_unique())", MACRO_TABLE);
eval = expr.eval(inputBindings);
Assert.assertEquals(HyperUniqueExpressions.TYPE, eval.type());
Assert.assertTrue(eval.value() instanceof HyperLogLogCollector);
Assert.assertEquals(NullHandling.replaceWithDefault() ? 1.0 : 0.0, ((HyperLogLogCollector) eval.value()).estimateCardinality(), 0.01);
}
@Test
public void testEstimate()
{
Expr expr = Parser.parse("hyper_unique_estimate(hyper_unique_add(1.234, hyper_unique()))", MACRO_TABLE);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(ExpressionType.DOUBLE, eval.type());
Assert.assertEquals(1.0, eval.asDouble(), 0.01);
}
@Test
public void testEstimateRound()
{
Expr expr = Parser.parse("hyper_unique_round_estimate(hyper_unique_add(1.234, hyper_unique()))", MACRO_TABLE);
ExprEval eval = expr.eval(inputBindings);
Assert.assertEquals(ExpressionType.LONG, eval.type());
Assert.assertEquals(1L, eval.asLong(), 0.01);
}
@Test
public void testCreateWrongArgsCount()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[hyper_unique] must have no arguments");
Parser.parse("hyper_unique(100)", MACRO_TABLE);
}
@Test
public void testAddWrongArgsCount()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[hyper_unique_add] must have 2 arguments");
Parser.parse("hyper_unique_add(100, hyper_unique(), 100)", MACRO_TABLE);
}
@Test
public void testAddWrongArgType()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[hyper_unique_add] must take a hyper-log-log collector as the second argument");
Expr expr = Parser.parse("hyper_unique_add(long, string)", MACRO_TABLE);
expr.eval(inputBindings);
}
@Test
public void testEstimateWrongArgsCount()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[hyper_unique_estimate] must have 1 argument");
Parser.parse("hyper_unique_estimate(hyper_unique(), 100)", MACRO_TABLE);
}
@Test
public void testEstimateWrongArgTypes()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[hyper_unique_estimate] must take a hyper-log-log collector as input");
Expr expr = Parser.parse("hyper_unique_estimate(100)", MACRO_TABLE);
expr.eval(inputBindings);
}
@Test
public void testRoundEstimateWrongArgsCount()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[hyper_unique_round_estimate] must have 1 argument");
Parser.parse("hyper_unique_round_estimate(hyper_unique(), 100)", MACRO_TABLE);
}
@Test
public void testRoundEstimateWrongArgTypes()
{
expectedException.expect(IAE.class);
expectedException.expectMessage("Function[hyper_unique_round_estimate] must take a hyper-log-log collector as input");
Expr expr = Parser.parse("hyper_unique_round_estimate(string)", MACRO_TABLE);
expr.eval(inputBindings);
}
}

View File

@ -22,6 +22,7 @@ package org.apache.druid.query.expression;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.InputBindings;
import org.junit.Assert;
import org.junit.Test;
@ -179,7 +180,7 @@ public class IPv4AddressMatchExprMacroTest extends MacroTestBase
private boolean eval(Expr... args)
{
Expr expr = apply(Arrays.asList(args));
ExprEval eval = expr.eval(ExprUtils.nilBindings());
ExprEval eval = expr.eval(InputBindings.nilBindings());
return eval.asBoolean();
}

View File

@ -22,6 +22,7 @@ package org.apache.druid.query.expression;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.InputBindings;
import org.junit.Assert;
import org.junit.Test;
@ -151,7 +152,7 @@ public class IPv4AddressParseExprMacroTest extends MacroTestBase
private Object eval(Expr arg)
{
Expr expr = apply(Collections.singletonList(arg));
ExprEval eval = expr.eval(ExprUtils.nilBindings());
ExprEval eval = expr.eval(InputBindings.nilBindings());
return eval.value();
}
}

View File

@ -22,6 +22,7 @@ package org.apache.druid.query.expression;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.InputBindings;
import org.junit.Assert;
import org.junit.Test;
@ -147,7 +148,7 @@ public class IPv4AddressStringifyExprMacroTest extends MacroTestBase
private Object eval(Expr arg)
{
Expr expr = apply(Collections.singletonList(arg));
ExprEval eval = expr.eval(ExprUtils.nilBindings());
ExprEval eval = expr.eval(InputBindings.nilBindings());
return eval.value();
}
}

View File

@ -128,14 +128,14 @@ public class RegexpExtractExprMacroTest extends MacroTestBase
expectException(IllegalArgumentException.class, "Function[regexp_extract] pattern must be a string literal");
}
final ExprEval<?> result = eval("regexp_extract(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
final ExprEval<?> result = eval("regexp_extract(a, null)", InputBindings.nilBindings());
Assert.assertNull(result.value());
}
@Test
public void testEmptyStringPatternOnNull()
{
final ExprEval<?> result = eval("regexp_extract(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
final ExprEval<?> result = eval("regexp_extract(a, '')", InputBindings.nilBindings());
Assert.assertNull(result.value());
}
}

View File

@ -122,7 +122,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase
expectException(IllegalArgumentException.class, "Function[regexp_like] pattern must be a string literal");
}
final ExprEval<?> result = eval("regexp_like(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
final ExprEval<?> result = eval("regexp_like(a, null)", InputBindings.nilBindings());
Assert.assertEquals(
ExprEval.ofLongBoolean(true).value(),
result.value()
@ -132,7 +132,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase
@Test
public void testEmptyStringPatternOnNull()
{
final ExprEval<?> result = eval("regexp_like(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
final ExprEval<?> result = eval("regexp_like(a, '')", InputBindings.nilBindings());
Assert.assertEquals(
ExprEval.ofLongBoolean(NullHandling.replaceWithDefault()).value(),
result.value()

View File

@ -43,7 +43,11 @@ public class TestExprMacroTable extends ExprMacroTable
new TimestampShiftExprMacro(),
new TrimExprMacro.BothTrimExprMacro(),
new TrimExprMacro.LeftTrimExprMacro(),
new TrimExprMacro.RightTrimExprMacro()
new TrimExprMacro.RightTrimExprMacro(),
new HyperUniqueExpressions.HllCreateExprMacro(),
new HyperUniqueExpressions.HllAddExprMacro(),
new HyperUniqueExpressions.HllEstimateExprMacro(),
new HyperUniqueExpressions.HllRoundEstimateExprMacro()
)
);
}

View File

@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableList;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.InputBindings;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@ -53,7 +54,7 @@ public class TimestampExtractExprMacroTest
ExprEval.of("2001-02-16").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.DECADE.toString()).toExpr()
));
Assert.assertEquals(200, expression.eval(ExprUtils.nilBindings()).asInt());
Assert.assertEquals(200, expression.eval(InputBindings.nilBindings()).asInt());
}
@Test
@ -64,7 +65,7 @@ public class TimestampExtractExprMacroTest
ExprEval.of("2000-12-16").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.CENTURY.toString()).toExpr()
));
Assert.assertEquals(20, expression.eval(ExprUtils.nilBindings()).asInt());
Assert.assertEquals(20, expression.eval(InputBindings.nilBindings()).asInt());
}
@Test
@ -75,7 +76,7 @@ public class TimestampExtractExprMacroTest
ExprEval.of("2001-02-16").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.CENTURY.toString()).toExpr()
));
Assert.assertEquals(21, expression.eval(ExprUtils.nilBindings()).asInt());
Assert.assertEquals(21, expression.eval(InputBindings.nilBindings()).asInt());
}
@Test
@ -86,7 +87,7 @@ public class TimestampExtractExprMacroTest
ExprEval.of("2000-12-16").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.MILLENNIUM.toString()).toExpr()
));
Assert.assertEquals(2, expression.eval(ExprUtils.nilBindings()).asInt());
Assert.assertEquals(2, expression.eval(InputBindings.nilBindings()).asInt());
}
@Test
@ -97,6 +98,6 @@ public class TimestampExtractExprMacroTest
ExprEval.of("2001-02-16").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.MILLENNIUM.toString()).toExpr()
));
Assert.assertEquals(3, expression.eval(ExprUtils.nilBindings()).asInt());
Assert.assertEquals(3, expression.eval(InputBindings.nilBindings()).asInt());
}
}

View File

@ -26,6 +26,8 @@ import org.apache.druid.java.util.common.IAE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.joda.time.DateTime;
import org.joda.time.Days;
import org.joda.time.Minutes;
@ -102,7 +104,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.withPeriodAdded(Months.ONE, step).getMillis(),
expr.eval(ExprUtils.nilBindings()).asLong()
expr.eval(InputBindings.nilBindings()).asLong()
);
}
@ -119,7 +121,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.withPeriodAdded(Months.ONE, step).getMillis(),
expr.eval(ExprUtils.nilBindings()).asLong()
expr.eval(InputBindings.nilBindings()).asLong()
);
}
@ -136,7 +138,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.withPeriodAdded(Months.ONE, step).getMillis(),
expr.eval(ExprUtils.nilBindings()).asLong()
expr.eval(InputBindings.nilBindings()).asLong()
);
}
@ -152,7 +154,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.withPeriodAdded(Minutes.ONE, 1).getMillis(),
expr.eval(ExprUtils.nilBindings()).asLong()
expr.eval(InputBindings.nilBindings()).asLong()
);
}
@ -168,7 +170,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.withPeriodAdded(Days.ONE, 1).getMillis(),
expr.eval(ExprUtils.nilBindings()).asLong()
expr.eval(InputBindings.nilBindings()).asLong()
);
}
@ -185,7 +187,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.toDateTime(DateTimes.inferTzFromString("America/Los_Angeles")).withPeriodAdded(Years.ONE, 1).getMillis(),
expr.eval(ExprUtils.nilBindings()).asLong()
expr.eval(InputBindings.nilBindings()).asLong()
);
}
@ -206,6 +208,13 @@ public class TimestampShiftMacroTest extends MacroTestBase
timestamp.toDateTime(DateTimes.inferTzFromString("America/Los_Angeles")).withPeriodAdded(Years.ONE, step).getMillis(),
expr.eval(new Expr.ObjectBinding()
{
@Nullable
@Override
public ExpressionType getType(String name)
{
return null;
}
@Nullable
@Override
public Object get(String name)
@ -232,9 +241,9 @@ public class TimestampShiftMacroTest extends MacroTestBase
);
if (NullHandling.replaceWithDefault()) {
Assert.assertEquals(2678400000L, expr.eval(ExprUtils.nilBindings()).value());
Assert.assertEquals(2678400000L, expr.eval(InputBindings.nilBindings()).value());
} else {
Assert.assertNull(expr.eval(ExprUtils.nilBindings()).value());
Assert.assertNull(expr.eval(InputBindings.nilBindings()).value());
}
}
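
The anonymous binding patched in this test now implements getType alongside get; a complete minimal version under that reading, where the identifier "n" and its values are placeholders:

Expr.ObjectBinding binding = new Expr.ObjectBinding()
{
  @Nullable
  @Override
  public ExpressionType getType(String name)
  {
    // Report LONG for the single hypothetical identifier "n", unknown otherwise.
    return "n".equals(name) ? ExpressionType.LONG : null;
  }

  @Nullable
  @Override
  public Object get(String name)
  {
    return "n".equals(name) ? 1L : null;
  }
};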

View File

@ -217,7 +217,7 @@ public class InDimFilterTest extends InitializedNullHandlingTest
final RowBasedColumnSelectorFactory<MapBasedRow> columnSelectorFactory = RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
() -> new MapBasedRow(0, row),
RowSignature.builder().add("dim", ColumnType.STRING).build(),
() -> RowSignature.builder().add("dim", ColumnType.STRING).build(),
true
);
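
This call site, like several below, now passes a supplier of the RowSignature instead of the signature itself; a sketch of the new call shape, reusing the surrounding test's row map:

RowSignature signature = RowSignature.builder().add("dim", ColumnType.STRING).build();
RowBasedColumnSelectorFactory<MapBasedRow> factory = RowBasedColumnSelectorFactory.create(
    RowAdapters.standardRow(),
    () -> new MapBasedRow(0, row),   // row: the Map already defined by the test
    () -> signature,                 // the signature is now provided lazily
    true
);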

View File

@ -11442,6 +11442,91 @@ public class GroupByQueryRunnerTest extends InitializedNullHandlingTest
TestHelper.assertExpectedObjects(expectedResults, results, "groupBy");
}
@Test
public void testGroupByWithExpressionAggregatorWithComplex()
{
cannotVectorize();
final GroupByQuery query = makeQueryBuilder()
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
.setDimensions(Collections.emptyList())
.setAggregatorSpecs(
new CardinalityAggregatorFactory(
"car",
ImmutableList.of(new DefaultDimensionSpec("quality", "quality")),
false
),
new ExpressionLambdaAggregatorFactory(
"carExpr",
ImmutableSet.of("quality"),
null,
"hyper_unique()",
null,
null,
"hyper_unique_add(quality, __acc)",
"hyper_unique_add(carExpr, __acc)",
null,
"hyper_unique_estimate(o)",
null,
TestExprMacroTable.INSTANCE
)
)
.setGranularity(QueryRunnerTestHelper.ALL_GRAN)
.build();
List<ResultRow> expectedResults = Collections.singletonList(
makeRow(query, "1970-01-01", "car", QueryRunnerTestHelper.UNIQUES_9, "carExpr", QueryRunnerTestHelper.UNIQUES_9)
);
Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
TestHelper.assertExpectedObjects(expectedResults, results, "subquery-cardinality");
}
@Test
public void testGroupByWithExpressionAggregatorWithComplexOnSubquery()
{
final GroupByQuery subquery = makeQueryBuilder()
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
.setQuerySegmentSpec(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
.setDimensions(new DefaultDimensionSpec("market", "market"), new DefaultDimensionSpec("quality", "quality"))
.setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("index", "index"))
.setGranularity(QueryRunnerTestHelper.ALL_GRAN)
.build();
final GroupByQuery query = makeQueryBuilder()
.setDataSource(subquery)
.setQuerySegmentSpec(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
.setDimensions(Collections.emptyList())
.setAggregatorSpecs(
new CardinalityAggregatorFactory(
"car",
ImmutableList.of(new DefaultDimensionSpec("quality", "quality")),
false
),
new ExpressionLambdaAggregatorFactory(
"carExpr",
ImmutableSet.of("quality"),
null,
"hyper_unique()",
null,
null,
"hyper_unique_add(quality, __acc)",
null,
null,
"hyper_unique_estimate(o)",
null,
TestExprMacroTable.INSTANCE
)
)
.setGranularity(QueryRunnerTestHelper.ALL_GRAN)
.build();
List<ResultRow> expectedResults = Collections.singletonList(
makeRow(query, "1970-01-01", "car", QueryRunnerTestHelper.UNIQUES_9, "carExpr", QueryRunnerTestHelper.UNIQUES_9)
);
Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
TestHelper.assertExpectedObjects(expectedResults, results, "subquery-cardinality");
}
@Test
public void testGroupByWithExpressionAggregatorWithArrays()
{
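
Reading the new ExpressionLambdaAggregatorFactory above positionally, the arguments appear to break down as annotated below; the comments are my informal reading of the constructor, not javadoc from the patch:

new ExpressionLambdaAggregatorFactory(
    "carExpr",                           // output column name
    ImmutableSet.of("quality"),          // input fields visible to the expressions
    null,                                // accumulator identifier (defaults to __acc)
    "hyper_unique()",                    // initial per-segment accumulator value
    null,                                // initial combining accumulator value
    null,                                // null-unless-aggregated flag
    "hyper_unique_add(quality, __acc)",  // fold expression, applied per row
    "hyper_unique_add(carExpr, __acc)",  // combine expression, applied across partial results
    null,                                // compare expression
    "hyper_unique_estimate(o)",          // finalize expression
    null,                                // maximum accumulator size
    TestExprMacroTable.INSTANCE          // macro table supplying the hyper_unique_* functions
);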

View File

@ -6075,6 +6075,71 @@ public class TopNQueryRunnerTest extends InitializedNullHandlingTest
assertExpectedResults(expectedResults, query);
}
@Test
public void testExpressionAggregatorComplex()
{
// sorted by array hyperunique expression
TopNQuery query = new TopNQueryBuilder()
.dataSource(QueryRunnerTestHelper.DATA_SOURCE)
.granularity(QueryRunnerTestHelper.ALL_GRAN)
.dimension(QueryRunnerTestHelper.MARKET_DIMENSION)
.metric("carExpr")
.threshold(4)
.intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
.aggregators(
ImmutableList.of(
new CardinalityAggregatorFactory(
"car",
ImmutableList.of(new DefaultDimensionSpec("quality", "quality")),
false
),
new ExpressionLambdaAggregatorFactory(
"carExpr",
ImmutableSet.of("quality"),
null,
"hyper_unique()",
null,
null,
"hyper_unique_add(quality, __acc)",
"hyper_unique_add(carExpr, __acc)",
null,
"hyper_unique_estimate(o)",
null,
TestExprMacroTable.INSTANCE
)
)
)
.build();
List<Result<TopNResultValue>> expectedResults = Collections.singletonList(
new Result<>(
DateTimes.of("2011-01-12T00:00:00.000Z"),
new TopNResultValue(
Arrays.<Map<String, Object>>asList(
ImmutableMap.<String, Object>builder()
.put(QueryRunnerTestHelper.MARKET_DIMENSION, "spot")
.put("car", 9.019833517963864)
.put("carExpr", 9.019833517963864)
.build(),
ImmutableMap.<String, Object>builder()
.put(QueryRunnerTestHelper.MARKET_DIMENSION, "total_market")
.put("car", 2.000977198748901)
.put("carExpr", 2.000977198748901)
.build(),
ImmutableMap.<String, Object>builder()
.put(QueryRunnerTestHelper.MARKET_DIMENSION, "upfront")
.put("car", 2.000977198748901)
.put("carExpr", 2.000977198748901)
.build()
)
)
)
);
assertExpectedResults(expectedResults, query);
}
private static Map<String, Object> makeRowWithNulls(
String dimName,
@Nullable Object dimValue,

View File

@ -392,7 +392,7 @@ public class TestHelper
Assert.assertEquals(
message,
(Object[]) expectedValue,
(Object[]) ExprEval.coerceListToArray((List) actualValue, true)
(Object[]) ExprEval.coerceListToArray((List) actualValue, true).rhs
);
} else {
Assert.assertArrayEquals(
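
The added .rhs suggests coerceListToArray now returns a type/array pair rather than a bare Object[]; a hedged sketch of consuming both halves, assuming Druid's NonnullPair with its public lhs/rhs fields and an illustrative input list:

final List<String> actualList = ImmutableList.of("a", "b");
final NonnullPair<ExpressionType, Object[]> coerced = ExprEval.coerceListToArray(actualList, true);
final ExpressionType coercedType = coerced.lhs;   // inferred array type (assumption)
final Object[] coercedValues = coerced.rhs;       // the coerced values, as asserted above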

View File

@ -754,7 +754,7 @@ public abstract class BaseFilterTest extends InitializedNullHandlingTest
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
rowSupplier::get,
rowSignatureBuilder.build(),
rowSignatureBuilder::build,
false
)
)

View File

@ -203,7 +203,7 @@ public class ExpressionVirtualColumnTest extends InitializedNullHandlingTest
private static final ColumnSelectorFactory COLUMN_SELECTOR_FACTORY = RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
CURRENT_ROW::get,
RowSignature.empty(),
RowSignature::empty,
false
);
@ -743,7 +743,7 @@ public class ExpressionVirtualColumnTest extends InitializedNullHandlingTest
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
CURRENT_ROW::get,
RowSignature.builder().add("x", ColumnType.LONG).build(),
RowSignature.builder().add("x", ColumnType.LONG)::build,
false
),
Parser.parse(SCALE_LONG.getExpression(), TestExprMacroTable.INSTANCE)
@ -766,7 +766,7 @@ public class ExpressionVirtualColumnTest extends InitializedNullHandlingTest
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
CURRENT_ROW::get,
RowSignature.builder().add("x", ColumnType.DOUBLE).build(),
RowSignature.builder().add("x", ColumnType.DOUBLE)::build,
false
),
Parser.parse(SCALE_FLOAT.getExpression(), TestExprMacroTable.INSTANCE)
@ -789,7 +789,7 @@ public class ExpressionVirtualColumnTest extends InitializedNullHandlingTest
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
CURRENT_ROW::get,
RowSignature.builder().add("x", ColumnType.FLOAT).build(),
RowSignature.builder().add("x", ColumnType.FLOAT)::build,
false
),
Parser.parse(SCALE_FLOAT.getExpression(), TestExprMacroTable.INSTANCE)

View File

@ -276,7 +276,7 @@ public class ListFilteredVirtualColumnSelectorTest extends InitializedNullHandli
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
() -> new MapBasedRow(0L, ImmutableMap.of(COLUMN_NAME, ImmutableList.of("a", "b", "c", "d"))),
rowSignature,
() -> rowSignature,
false
),
VirtualColumns.create(ImmutableList.of(virtualColumn))

View File

@ -28,6 +28,7 @@ import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.expression.CaseInsensitiveContainsExprMacro;
import org.apache.druid.query.expression.ContainsExprMacro;
import org.apache.druid.query.expression.GuiceExprMacroTable;
import org.apache.druid.query.expression.HyperUniqueExpressions;
import org.apache.druid.query.expression.IPv4AddressMatchExprMacro;
import org.apache.druid.query.expression.IPv4AddressParseExprMacro;
import org.apache.druid.query.expression.IPv4AddressStringifyExprMacro;
@ -65,6 +66,10 @@ public class ExpressionModule implements DruidModule
.add(TrimExprMacro.BothTrimExprMacro.class)
.add(TrimExprMacro.LeftTrimExprMacro.class)
.add(TrimExprMacro.RightTrimExprMacro.class)
.add(HyperUniqueExpressions.HllCreateExprMacro.class)
.add(HyperUniqueExpressions.HllAddExprMacro.class)
.add(HyperUniqueExpressions.HllEstimateExprMacro.class)
.add(HyperUniqueExpressions.HllRoundEstimateExprMacro.class)
.build();
@Override

View File

@ -27,6 +27,7 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.DimFilter;
@ -102,7 +103,7 @@ public class ArrayContainsOperatorConversion extends BaseExpressionDimFilterOper
if (expr.isLiteral()) {
// Evaluate the expression to get out the array elements.
// We can safely pass a noop ObjectBinding if the expression is literal.
ExprEval<?> exprEval = expr.eval(name -> null);
ExprEval<?> exprEval = expr.eval(InputBindings.nilBindings());
String[] arrayElements = exprEval.asStringArray();
if (arrayElements == null || arrayElements.length == 0) {
// If arrayElements is empty which means rightExpr is an empty array,

View File

@ -28,6 +28,7 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.InDimFilter;
@ -109,7 +110,7 @@ public class ArrayOverlapOperatorConversion extends BaseExpressionDimFilterOpera
if (expr.isLiteral()) {
// Evaluate the expression to take out the array elements.
// We can safely pass null if the expression is literal.
ExprEval<?> exprEval = expr.eval(name -> null);
ExprEval<?> exprEval = expr.eval(InputBindings.nilBindings());
String[] arrayElements = exprEval.asStringArray();
if (arrayElements == null || arrayElements.length == 0) {
// If arrayElements is empty which means complexExpr is an empty array,

View File

@ -30,8 +30,8 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.virtual.ListFilteredVirtualColumn;
import org.apache.druid.sql.calcite.expression.AliasedOperatorConversion;
@ -334,7 +334,7 @@ public class MultiValueStringOperatorConversions
if (!expr.isLiteral()) {
return null;
}
String[] lit = expr.eval(ExprUtils.nilBindings()).asStringArray();
String[] lit = expr.eval(InputBindings.nilBindings()).asStringArray();
if (lit == null || lit.length == 0) {
return null;
}

View File

@ -29,6 +29,7 @@ import org.apache.druid.java.util.common.IAE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DruidExpression;
@ -74,10 +75,12 @@ public class DruidRexExecutor implements RexExecutor
final Expr expr = Parser.parse(druidExpression.getExpression(), plannerContext.getExprMacroTable());
final ExprEval exprResult = expr.eval(
name -> {
// Sanity check. Bindings should not be used for a constant expression.
throw new UnsupportedOperationException();
}
InputBindings.forFunction(
name -> {
// Sanity check. Bindings should not be used for a constant expression.
throw new UnsupportedOperationException();
}
)
);
final RexNode literal;

View File

@ -340,6 +340,8 @@ public class QueryMaker
coercedValue = Arrays.asList((Long[]) value);
} else if (value instanceof Double[]) {
coercedValue = Arrays.asList((Double[]) value);
} else if (value instanceof Object[]) {
coercedValue = Arrays.asList((Object[]) value);
} else {
throw new ISE("Cannot coerce[%s] to %s", value.getClass().getName(), sqlType);
}

View File

@ -24,8 +24,11 @@ import com.google.common.collect.ImmutableSet;
import junitparams.JUnitParamsRunner;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.query.Druids;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
@ -141,21 +144,49 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
}
@Test
public void testSelectNonConstantArrayExpressionFromTableFailForMultival() throws Exception
public void testSelectNonConstantArrayExpressionFromTableForMultival() throws Exception
{
// without expression output type inference to prevent this, the automatic translation will try to turn this into
final String sql = "SELECT ARRAY[CONCAT(dim3, 'word'),'up'] as arr, dim1 FROM foo LIMIT 5";
final Query<?> scanQuery = newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(expressionVirtualColumn("v0", "array(concat(\"dim3\",'word'),'up')", ColumnType.STRING))
.columns("dim1", "v0")
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.limit(5)
.context(QUERY_CONTEXT_DEFAULT)
.build();
ExpressionProcessing.initializeForTests(true);
// if nested arrays are allowed, dim3 is a multi-valued string column, so the automatic translation will turn this
// expression into
//
// `map((dim3) -> array(concat(dim3,'word'),'up'), dim3)`
//
// This error message will get better in the future. The error without translation would be:
//
// org.apache.druid.java.util.common.RE: Unhandled array constructor element type [ARRAY<STRING>]
// this works, but we still translate the output into a string since that is the current output type
// in some future this might not auto-convert to a string type (when we support grouping on arrays maybe?)
expectedException.expect(RuntimeException.class);
expectedException.expectMessage("Unhandled map function output type [ARRAY<STRING>]");
testQuery(
"SELECT ARRAY[CONCAT(dim3, 'word'),'up'] as arr, dim1 FROM foo LIMIT 5",
ImmutableList.of(),
sql,
ImmutableList.of(scanQuery),
ImmutableList.of(
new Object[]{"[[\"aword\",\"up\"],[\"bword\",\"up\"]]", ""},
new Object[]{"[[\"bword\",\"up\"],[\"cword\",\"up\"]]", "10.1"},
new Object[]{"[[\"dword\",\"up\"]]", "2"},
new Object[]{"[[\"word\",\"up\"]]", "1"},
useDefault ? new Object[]{"[[\"word\",\"up\"]]", "def"} : new Object[]{"[[null,\"up\"]]", "def"}
)
);
ExpressionProcessing.initializeForTests(null);
// if nested arrays are not enabled, this doesn't work
expectedException.expect(IAE.class);
expectedException.expectMessage("Cannot create a nested array type [ARRAY<ARRAY<STRING>>], 'druid.expressions.allowNestedArrays' must be set to true");
testQuery(
sql,
ImmutableList.of(scanQuery),
ImmutableList.of()
);
}
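
The reworked test now runs the same SQL twice, once with nested arrays enabled and once without; per the expected error message, the switch appears to be the druid.expressions.allowNestedArrays property. A small sketch of the toggle pattern, where the try/finally is my addition (the test itself simply calls the two in sequence):

// Enable nested arrays for a block of assertions, then restore the default configuration;
// passing null appears to mean "use the defaults", matching CalciteTestBase's setup below.
ExpressionProcessing.initializeForTests(true);
try {
  // assertions that rely on ARRAY<ARRAY<STRING>> values being constructible
}
finally {
  ExpressionProcessing.initializeForTests(null);
}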

View File

@ -301,7 +301,7 @@ class ExpressionTestHelper
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
() -> new MapBasedRow(0L, bindings),
rowSignature,
() -> rowSignature,
false
),
VirtualColumns.create(virtualColumns)

View File

@ -21,6 +21,7 @@ package org.apache.druid.sql.calcite.util;
import com.google.common.collect.ImmutableList;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.http.SqlParameter;
import org.junit.BeforeClass;
@ -36,5 +37,6 @@ public abstract class CalciteTestBase
{
Calcites.setSystemProperties();
NullHandling.initializeForTests();
ExpressionProcessing.initializeForTests(null);
}
}