Fix for schema mismatch to go down using the non vectorize path till we update the vectorized aggs properly (#14924)

* Fix for schema mismatch to go down using the non vectorize path till we update the vectorized aggs properly

* Fixing a failed test

* Updating numericNilAgg

* Moving to use default values in case of nil agg

* Adding the same for first agg

* Fixing a test

* fixing vectorized string agg for last/first with cast if numeric

* Updating tests to remove mockito and cover the case of string first/last on non string columns

* Updating a test to vectorize

* Addressing review comments: Name change to NilVectorAggregator and using static variables now

* fixing intellij inspections
This commit is contained in:
Soumyava 2023-09-13 13:15:14 -07:00 committed by GitHub
parent 7f757e33f0
commit bf99d2c7b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 497 additions and 83 deletions

View File

@ -123,7 +123,7 @@ public class DoubleAnyAggregatorFactory extends AggregatorFactory
if (capabilities == null || capabilities.isNumeric()) { if (capabilities == null || capabilities.isNumeric()) {
return new DoubleAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName)); return new DoubleAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else { } else {
return NumericNilVectorAggregator.doubleNilVectorAggregator(); return NilVectorAggregator.doubleNilVectorAggregator();
} }
} }

View File

@ -120,7 +120,7 @@ public class FloatAnyAggregatorFactory extends AggregatorFactory
if (capabilities == null || capabilities.isNumeric()) { if (capabilities == null || capabilities.isNumeric()) {
return new FloatAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName)); return new FloatAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else { } else {
return NumericNilVectorAggregator.floatNilVectorAggregator(); return NilVectorAggregator.floatNilVectorAggregator();
} }
} }

View File

@ -119,7 +119,7 @@ public class LongAnyAggregatorFactory extends AggregatorFactory
if (capabilities == null || capabilities.isNumeric()) { if (capabilities == null || capabilities.isNumeric()) {
return new LongAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName)); return new LongAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else { } else {
return NumericNilVectorAggregator.longNilVectorAggregator(); return NilVectorAggregator.longNilVectorAggregator();
} }
} }

View File

@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.any; package org.apache.druid.query.aggregation.any;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
@ -28,24 +29,28 @@ import java.nio.ByteBuffer;
/** /**
* A vector aggregator that returns the default numeric value. * A vector aggregator that returns the default numeric value.
*/ */
public class NumericNilVectorAggregator implements VectorAggregator public class NilVectorAggregator implements VectorAggregator
{ {
private static final NumericNilVectorAggregator DOUBLE_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator( private static final NilVectorAggregator DOUBLE_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator(
NullHandling.defaultDoubleValue() NullHandling.defaultDoubleValue()
); );
private static final NumericNilVectorAggregator FLOAT_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator( private static final NilVectorAggregator FLOAT_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator(
NullHandling.defaultFloatValue() NullHandling.defaultFloatValue()
); );
private static final NumericNilVectorAggregator LONG_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator( private static final NilVectorAggregator LONG_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator(
NullHandling.defaultLongValue() NullHandling.defaultLongValue()
); );
public static final SerializablePair<Long, Double> DOUBLE_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultDoubleValue());
public static final SerializablePair<Long, Long> LONG_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultLongValue());
public static final SerializablePair<Long, Float> FLOAT_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultFloatValue());
/** /**
* @return A vectorized aggregator that returns the default double value. * @return A vectorized aggregator that returns the default double value.
*/ */
public static NumericNilVectorAggregator doubleNilVectorAggregator() public static NilVectorAggregator doubleNilVectorAggregator()
{ {
return DOUBLE_NIL_VECTOR_AGGREGATOR; return DOUBLE_NIL_VECTOR_AGGREGATOR;
} }
@ -53,7 +58,7 @@ public class NumericNilVectorAggregator implements VectorAggregator
/** /**
* @return A vectorized aggregator that returns the default float value. * @return A vectorized aggregator that returns the default float value.
*/ */
public static NumericNilVectorAggregator floatNilVectorAggregator() public static NilVectorAggregator floatNilVectorAggregator()
{ {
return FLOAT_NIL_VECTOR_AGGREGATOR; return FLOAT_NIL_VECTOR_AGGREGATOR;
} }
@ -61,7 +66,7 @@ public class NumericNilVectorAggregator implements VectorAggregator
/** /**
* @return A vectorized aggregator that returns the default long value. * @return A vectorized aggregator that returns the default long value.
*/ */
public static NumericNilVectorAggregator longNilVectorAggregator() public static NilVectorAggregator longNilVectorAggregator()
{ {
return LONG_NIL_VECTOR_AGGREGATOR; return LONG_NIL_VECTOR_AGGREGATOR;
} }
@ -69,7 +74,12 @@ public class NumericNilVectorAggregator implements VectorAggregator
@Nullable @Nullable
private final Object returnValue; private final Object returnValue;
private NumericNilVectorAggregator(@Nullable Object returnValue) public static NilVectorAggregator of(Object returnValue)
{
return new NilVectorAggregator(returnValue);
}
private NilVectorAggregator(@Nullable Object returnValue)
{ {
this.returnValue = returnValue; this.returnValue = returnValue;
} }

View File

@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseDoubleColumnValueSelector; import org.apache.druid.segment.BaseDoubleColumnValueSelector;
@ -149,7 +149,7 @@ public class DoubleFirstAggregatorFactory extends AggregatorFactory
timeColumn); timeColumn);
return new DoubleFirstVectorAggregator(timeSelector, valueSelector); return new DoubleFirstVectorAggregator(timeSelector, valueSelector);
} }
return NumericNilVectorAggregator.doubleNilVectorAggregator(); return NilVectorAggregator.of(NilVectorAggregator.DOUBLE_NIL_PAIR);
} }
@Override @Override

View File

@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseFloatColumnValueSelector; import org.apache.druid.segment.BaseFloatColumnValueSelector;
@ -138,7 +138,7 @@ public class FloatFirstAggregatorFactory extends AggregatorFactory
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn); VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn);
return new FloatFirstVectorAggregator(timeSelector, valueSelector); return new FloatFirstVectorAggregator(timeSelector, valueSelector);
} }
return NumericNilVectorAggregator.floatNilVectorAggregator(); return NilVectorAggregator.of(NilVectorAggregator.FLOAT_NIL_PAIR);
} }
@Override @Override

View File

@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseLongColumnValueSelector; import org.apache.druid.segment.BaseLongColumnValueSelector;
@ -138,7 +138,7 @@ public class LongFirstAggregatorFactory extends AggregatorFactory
timeColumn); timeColumn);
return new LongFirstVectorAggregator(timeSelector, valueSelector); return new LongFirstVectorAggregator(timeSelector, valueSelector);
} }
return NumericNilVectorAggregator.longNilVectorAggregator(); return NilVectorAggregator.of(NilVectorAggregator.LONG_NIL_PAIR);
} }
@Override @Override

View File

@ -42,11 +42,14 @@ import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.segment.virtual.ExpressionVectorSelectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -188,6 +191,17 @@ public class StringFirstAggregatorFactory extends AggregatorFactory
{ {
final VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn); final VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn);
ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName);
if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName);
VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject(
selectorFactory.getReadableVectorInspector(),
fieldName,
valueSelector,
capabilities.toColumnType(),
ColumnType.STRING
);
return new StringFirstVectorAggregator(timeSelector, objectSelector, maxStringBytes);
}
if (capabilities != null) { if (capabilities != null) {
if (capabilities.is(ValueType.STRING) && capabilities.isDictionaryEncoded().isTrue()) { if (capabilities.is(ValueType.STRING) && capabilities.isDictionaryEncoded().isTrue()) {
// Case 1: Single value string with dimension selector // Case 1: Single value string with dimension selector

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import org.apache.druid.collections.SerializablePair; import org.apache.druid.collections.SerializablePair;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.UOE; import org.apache.druid.java.util.common.UOE;
import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.AggregateCombiner;
import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.query.aggregation.Aggregator;
@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory; import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
@ -42,6 +43,7 @@ import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.segment.vector.VectorValueSelector;
@ -125,14 +127,12 @@ public class DoubleLastAggregatorFactory extends AggregatorFactory
) )
{ {
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn);
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
return new DoubleLastVectorAggregator(timeSelector, valueSelector); return new DoubleLastVectorAggregator(timeSelector, valueSelector);
} else { } else {
return NumericNilVectorAggregator.doubleNilVectorAggregator(); return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultDoubleValue()));
} }
} }

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import org.apache.druid.collections.SerializablePair; import org.apache.druid.collections.SerializablePair;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.UOE; import org.apache.druid.java.util.common.UOE;
import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.AggregateCombiner;
import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.query.aggregation.Aggregator;
@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.aggregation.first.FloatFirstAggregatorFactory; import org.apache.druid.query.aggregation.first.FloatFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
@ -42,6 +43,7 @@ import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.segment.vector.VectorValueSelector;
@ -136,15 +138,13 @@ public class FloatLastAggregatorFactory extends AggregatorFactory
VectorColumnSelectorFactory columnSelectorFactory VectorColumnSelectorFactory columnSelectorFactory
) )
{ {
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn);
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
return new FloatLastVectorAggregator(timeSelector, valueSelector); return new FloatLastVectorAggregator(timeSelector, valueSelector);
} else { } else {
return NumericNilVectorAggregator.floatNilVectorAggregator(); return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultFloatValue()));
} }
} }

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import org.apache.druid.collections.SerializablePair; import org.apache.druid.collections.SerializablePair;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.UOE; import org.apache.druid.java.util.common.UOE;
import org.apache.druid.query.aggregation.AggregateCombiner; import org.apache.druid.query.aggregation.AggregateCombiner;
import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.query.aggregation.Aggregator;
@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator; import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory; import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
@ -42,6 +43,7 @@ import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.segment.vector.VectorValueSelector;
@ -136,14 +138,13 @@ public class LongLastAggregatorFactory extends AggregatorFactory
VectorColumnSelectorFactory columnSelectorFactory VectorColumnSelectorFactory columnSelectorFactory
) )
{ {
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName); final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName); if (Types.isNumeric(capabilities)) {
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector( VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
timeColumn); VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
return new LongLastVectorAggregator(timeSelector, valueSelector); return new LongLastVectorAggregator(timeSelector, valueSelector);
} else { } else {
return NumericNilVectorAggregator.longNilVectorAggregator(); return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultLongValue()));
} }
} }

View File

@ -42,9 +42,11 @@ import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.segment.virtual.ExpressionVectorSelectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -156,16 +158,25 @@ public class StringLastAggregatorFactory extends AggregatorFactory
public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory) public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
{ {
ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn);
if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName);
VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject(
selectorFactory.getReadableVectorInspector(),
fieldName,
valueSelector,
capabilities.toColumnType(),
ColumnType.STRING
);
return new StringLastVectorAggregator(timeSelector, objectSelector, maxStringBytes);
}
VectorObjectSelector vSelector = selectorFactory.makeObjectSelector(fieldName); VectorObjectSelector vSelector = selectorFactory.makeObjectSelector(fieldName);
VectorValueSelector timeSelector = selectorFactory.makeValueSelector(
timeColumn);
if (capabilities != null) { if (capabilities != null) {
return new StringLastVectorAggregator(timeSelector, vSelector, maxStringBytes); return new StringLastVectorAggregator(timeSelector, vSelector, maxStringBytes);
} else { } else {
return new StringLastVectorAggregator(null, vSelector, maxStringBytes); return new StringLastVectorAggregator(null, vSelector, maxStringBytes);
} }
} }
@Override @Override

View File

@ -23,35 +23,45 @@ import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.aggregation.SerializablePairLongString; import org.apache.druid.query.aggregation.SerializablePairLongString;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.vector.BaseLongVectorValueSelector; import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
import org.apache.druid.segment.vector.NoFilterVectorOffset;
import org.apache.druid.segment.vector.ReadableVectorInspector;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest; import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Answers;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest
{ {
private static final double EPSILON = 1e-5; private static final double EPSILON = 1e-5;
private static final String[] VALUES = new String[]{"a", "b", null, "c"}; private static final String[] VALUES = new String[]{"a", "b", null, "c"};
private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L};
private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"};
private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0};
private static final boolean[] NULLS = new boolean[]{false, false, true, false}; private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private static final String NAME = "NAME"; private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME"; private static final String FIELD_NAME = "FIELD_NAME";
private static final String FIELD_NAME_LONG = "LONG_NAME";
private static final String TIME_COL = "__time"; private static final String TIME_COL = "__time";
private long[] times = {2436, 6879, 7888, 8224}; private final long[] times = {2436, 6879, 7888, 8224};
private long[] timesSame = {2436, 2436}; private final long[] timesSame = {2436, 2436};
private SerializablePairLongString[] pairs = { private final SerializablePairLongString[] pairs = {
new SerializablePairLongString(2345001L, "first"), new SerializablePairLongString(2345001L, "first"),
new SerializablePairLongString(2345100L, "notFirst") new SerializablePairLongString(2345100L, "notFirst")
}; };
@ -69,8 +79,10 @@ public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest
private StringFirstVectorAggregator targetWithPairs; private StringFirstVectorAggregator targetWithPairs;
private StringFirstAggregatorFactory stringFirstAggregatorFactory; private StringFirstAggregatorFactory stringFirstAggregatorFactory;
@Mock(answer = Answers.RETURNS_DEEP_STUBS) private StringFirstAggregatorFactory stringFirstAggregatorFactory1;
private VectorColumnSelectorFactory selectorFactory; private VectorColumnSelectorFactory selectorFactory;
private VectorValueSelector nonStringValueSelector;
@Before @Before
public void setup() public void setup()
@ -78,19 +90,189 @@ public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest
byte[] randomBytes = new byte[1024]; byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes); ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes); buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getObjectVector();
Mockito.doReturn(times).when(timeSelector).getLongVector(); timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length))
Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector(); {
Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector(); @Override
public long[] getLongVector()
{
return times;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return null;
}
};
selector = new VectorObjectSelector()
{
@Override
public Object[] getObjectVector()
{
return VALUES;
}
@Override
public int getMaxVectorSize()
{
return 4;
}
@Override
public int getCurrentVectorSize()
{
return 0;
}
};
timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
timesSame.length,
0,
timesSame.length
))
{
@Override
public long[] getLongVector()
{
return timesSame;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return null;
}
};
selectorForPairs = new VectorObjectSelector()
{
@Override
public Object[] getObjectVector()
{
return pairs;
}
@Override
public int getMaxVectorSize()
{
return 2;
}
@Override
public int getCurrentVectorSize()
{
return 0;
}
};
nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
LONG_VALUES.length,
0,
LONG_VALUES.length
))
{
@Override
public long[] getLongVector()
{
return LONG_VALUES;
}
@Override
public float[] getFloatVector()
{
return FLOAT_VALUES;
}
@Override
public double[] getDoubleVector()
{
return DOUBLE_VALUES;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return NULLS;
}
@Override
public int getMaxVectorSize()
{
return 4;
}
@Override
public int getCurrentVectorSize()
{
return 4;
}
};
selectorFactory = new VectorColumnSelectorFactory()
{
@Override
public ReadableVectorInspector getReadableVectorInspector()
{
return new NoFilterVectorOffset(VALUES.length, 0, VALUES.length);
}
@Override
public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
}
@Override
public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
}
@Override
public VectorValueSelector makeValueSelector(String column)
{
if (TIME_COL.equals(column)) {
return timeSelector;
} else if (FIELD_NAME_LONG.equals(column)) {
return nonStringValueSelector;
}
return null;
}
@Override
public VectorObjectSelector makeObjectSelector(String column)
{
if (FIELD_NAME.equals(column)) {
return selector;
} else {
return null;
}
}
@Nullable
@Override
public ColumnCapabilities getColumnCapabilities(String column)
{
if (FIELD_NAME.equals(column)) {
return ColumnCapabilitiesImpl.createSimpleSingleValueStringColumnCapabilities();
} else if (FIELD_NAME_LONG.equals(column)) {
return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG);
}
return null;
}
};
target = new StringFirstVectorAggregator(timeSelector, selector, 10); target = new StringFirstVectorAggregator(timeSelector, selector, 10);
targetWithPairs = new StringFirstVectorAggregator(timeSelectorForPairs, selectorForPairs, 10); targetWithPairs = new StringFirstVectorAggregator(timeSelectorForPairs, selectorForPairs, 10);
clearBufferForPositions(0, 0); clearBufferForPositions(0, 0);
Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME);
Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL);
stringFirstAggregatorFactory = new StringFirstAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10); stringFirstAggregatorFactory = new StringFirstAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10);
stringFirstAggregatorFactory1 = new StringFirstAggregatorFactory(NAME, FIELD_NAME_LONG, TIME_COL, 10);
} }
@Test @Test
@ -129,6 +311,19 @@ public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest
Assert.assertEquals(VALUES[0], result.rhs); Assert.assertEquals(VALUES[0], result.rhs);
} }
@Test
public void testStringEarliestOnNonStringColumns()
{
Assert.assertTrue(stringFirstAggregatorFactory1.canVectorize(selectorFactory));
VectorAggregator vectorAggregator = stringFirstAggregatorFactory1.factorizeVector(selectorFactory);
Assert.assertNotNull(vectorAggregator);
Assert.assertEquals(StringFirstVectorAggregator.class, vectorAggregator.getClass());
vectorAggregator.aggregate(buf, 0, 0, LONG_VALUES.length);
Pair<Long, String> result = (Pair<Long, String>) vectorAggregator.get(buf, 0);
Assert.assertEquals(times[0], result.lhs.longValue());
Assert.assertEquals(STRING_VALUES[0], result.rhs);
}
@Test @Test
public void aggregateBatchWithoutRows() public void aggregateBatchWithoutRows()
{ {

View File

@ -23,74 +23,245 @@ import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.aggregation.SerializablePairLongString; import org.apache.druid.query.aggregation.SerializablePairLongString;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.vector.BaseLongVectorValueSelector; import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
import org.apache.druid.segment.vector.NoFilterVectorOffset;
import org.apache.druid.segment.vector.ReadableVectorInspector;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest; import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Answers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
{ {
private static final double EPSILON = 1e-5; private static final double EPSILON = 1e-5;
private static final String[] VALUES = new String[]{"a", "b", null, "c"}; private static final String[] VALUES = new String[]{"a", "b", null, "c"};
private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L};
private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"};
private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0};
private static final boolean[] NULLS = new boolean[]{false, false, true, false}; private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private static final String NAME = "NAME"; private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME"; private static final String FIELD_NAME = "FIELD_NAME";
private static final String FIELD_NAME_LONG = "LONG_NAME";
private static final String TIME_COL = "__time"; private static final String TIME_COL = "__time";
private long[] times = {2436, 6879, 7888, 8224}; private final long[] times = {2436, 6879, 7888, 8224};
private long[] timesSame = {2436, 2436}; private final long[] timesSame = {2436, 2436};
private SerializablePairLongString[] pairs = { private final SerializablePairLongString[] pairs = {
new SerializablePairLongString(2345100L, "last"), new SerializablePairLongString(2345100L, "last"),
new SerializablePairLongString(2345001L, "notLast") new SerializablePairLongString(2345001L, "notLast")
}; };
@Mock
private VectorObjectSelector selector; private VectorObjectSelector selector;
@Mock
private VectorObjectSelector selectorForPairs;
@Mock
private BaseLongVectorValueSelector timeSelector; private BaseLongVectorValueSelector timeSelector;
@Mock private VectorValueSelector nonStringValueSelector;
private BaseLongVectorValueSelector timeSelectorForPairs;
private ByteBuffer buf; private ByteBuffer buf;
private StringLastVectorAggregator target; private StringLastVectorAggregator target;
private StringLastVectorAggregator targetWithPairs; private StringLastVectorAggregator targetWithPairs;
private StringLastAggregatorFactory stringLastAggregatorFactory; private StringLastAggregatorFactory stringLastAggregatorFactory;
@Mock(answer = Answers.RETURNS_DEEP_STUBS) private StringLastAggregatorFactory stringLastAggregatorFactory1;
private VectorColumnSelectorFactory selectorFactory; private VectorColumnSelectorFactory selectorFactory;
@Before @Before
public void setup() public void setup()
{ {
byte[] randomBytes = new byte[1024]; byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes); ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes); buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getObjectVector(); timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length))
Mockito.doReturn(times).when(timeSelector).getLongVector(); {
Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector(); @Override
Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector(); public long[] getLongVector()
{
return times;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return NULLS;
}
};
nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
LONG_VALUES.length,
0,
LONG_VALUES.length
))
{
@Override
public long[] getLongVector()
{
return LONG_VALUES;
}
@Override
public float[] getFloatVector()
{
return FLOAT_VALUES;
}
@Override
public double[] getDoubleVector()
{
return DOUBLE_VALUES;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return NULLS;
}
@Override
public int getMaxVectorSize()
{
return 4;
}
@Override
public int getCurrentVectorSize()
{
return 4;
}
};
selector = new VectorObjectSelector()
{
@Override
public Object[] getObjectVector()
{
return VALUES;
}
@Override
public int getMaxVectorSize()
{
return 0;
}
@Override
public int getCurrentVectorSize()
{
return 0;
}
};
BaseLongVectorValueSelector timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
times.length,
0,
times.length
))
{
@Override
public long[] getLongVector()
{
return timesSame;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return new boolean[0];
}
};
VectorObjectSelector selectorForPairs = new VectorObjectSelector()
{
@Override
public Object[] getObjectVector()
{
return pairs;
}
@Override
public int getMaxVectorSize()
{
return 2;
}
@Override
public int getCurrentVectorSize()
{
return 2;
}
};
selectorFactory = new VectorColumnSelectorFactory()
{
@Override
public ReadableVectorInspector getReadableVectorInspector()
{
return new NoFilterVectorOffset(LONG_VALUES.length, 0, LONG_VALUES.length);
}
@Override
public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
}
@Override
public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
}
@Override
public VectorValueSelector makeValueSelector(String column)
{
if (TIME_COL.equals(column)) {
return timeSelector;
} else if (FIELD_NAME_LONG.equals(column)) {
return nonStringValueSelector;
}
return null;
}
@Override
public VectorObjectSelector makeObjectSelector(String column)
{
if (FIELD_NAME.equals(column)) {
return selector;
} else {
return null;
}
}
@Nullable
@Override
public ColumnCapabilities getColumnCapabilities(String column)
{
if (FIELD_NAME.equals(column)) {
return ColumnCapabilitiesImpl.createSimpleSingleValueStringColumnCapabilities();
} else if (FIELD_NAME_LONG.equals(column)) {
return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG);
}
return null;
}
};
target = new StringLastVectorAggregator(timeSelector, selector, 10); target = new StringLastVectorAggregator(timeSelector, selector, 10);
targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10); targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10);
clearBufferForPositions(0, 0); clearBufferForPositions(0, 0);
Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME);
Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL);
stringLastAggregatorFactory = new StringLastAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10); stringLastAggregatorFactory = new StringLastAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10);
stringLastAggregatorFactory1 = new StringLastAggregatorFactory(NAME, FIELD_NAME_LONG, TIME_COL, 10);
} }
@Test @Test
@ -112,6 +283,19 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass()); Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass());
} }
@Test
public void testStringLastOnNonStringColumns()
{
Assert.assertTrue(stringLastAggregatorFactory1.canVectorize(selectorFactory));
VectorAggregator vectorAggregator = stringLastAggregatorFactory1.factorizeVector(selectorFactory);
Assert.assertNotNull(vectorAggregator);
Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass());
vectorAggregator.aggregate(buf, 0, 0, LONG_VALUES.length);
Pair<Long, String> result = (Pair<Long, String>) vectorAggregator.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(STRING_VALUES[3], result.rhs);
}
@Test @Test
public void initValueShouldBeMinDate() public void initValueShouldBeMinDate()
{ {

View File

@ -632,7 +632,6 @@ public class CalciteSimpleQueryTest extends BaseCalciteQueryTest
@Test @Test
public void testEarliestByLatestByWithExpression() public void testEarliestByLatestByWithExpression()
{ {
cannotVectorize();
testBuilder() testBuilder()
.sql("SELECT\n" .sql("SELECT\n"
+ " channel\n" + " channel\n"