Fix for schema mismatch to go down using the non vectorize path till we update the vectorized aggs properly (#14924)

* Fix for schema mismatch to go down using the non vectorize path till we update the vectorized aggs properly

* Fixing a failed test

* Updating numericNilAgg

* Moving to use default values in case of nil agg

* Adding the same for first agg

* Fixing a test

* fixing vectorized string agg for last/first with cast if numeric

* Updating tests to remove mockito and cover the case of string first/last on non string columns

* Updating a test to vectorize

* Addressing review comments: Name change to NilVectorAggregator and using static variables now

* fixing intellij inspections
This commit is contained in:
Soumyava 2023-09-13 13:15:14 -07:00 committed by GitHub
parent 7f757e33f0
commit bf99d2c7b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 497 additions and 83 deletions

View File

@ -123,7 +123,7 @@ public class DoubleAnyAggregatorFactory extends AggregatorFactory
if (capabilities == null || capabilities.isNumeric()) {
return new DoubleAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else {
return NumericNilVectorAggregator.doubleNilVectorAggregator();
return NilVectorAggregator.doubleNilVectorAggregator();
}
}

View File

@ -120,7 +120,7 @@ public class FloatAnyAggregatorFactory extends AggregatorFactory
if (capabilities == null || capabilities.isNumeric()) {
return new FloatAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else {
return NumericNilVectorAggregator.floatNilVectorAggregator();
return NilVectorAggregator.floatNilVectorAggregator();
}
}

View File

@ -119,7 +119,7 @@ public class LongAnyAggregatorFactory extends AggregatorFactory
if (capabilities == null || capabilities.isNumeric()) {
return new LongAnyVectorAggregator(selectorFactory.makeValueSelector(fieldName));
} else {
return NumericNilVectorAggregator.longNilVectorAggregator();
return NilVectorAggregator.longNilVectorAggregator();
}
}

View File

@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.any;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
@ -28,24 +29,28 @@ import java.nio.ByteBuffer;
/**
* A vector aggregator that returns the default numeric value.
*/
public class NumericNilVectorAggregator implements VectorAggregator
public class NilVectorAggregator implements VectorAggregator
{
private static final NumericNilVectorAggregator DOUBLE_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator(
private static final NilVectorAggregator DOUBLE_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator(
NullHandling.defaultDoubleValue()
);
private static final NumericNilVectorAggregator FLOAT_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator(
private static final NilVectorAggregator FLOAT_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator(
NullHandling.defaultFloatValue()
);
private static final NumericNilVectorAggregator LONG_NIL_VECTOR_AGGREGATOR = new NumericNilVectorAggregator(
private static final NilVectorAggregator LONG_NIL_VECTOR_AGGREGATOR = new NilVectorAggregator(
NullHandling.defaultLongValue()
);
public static final SerializablePair<Long, Double> DOUBLE_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultDoubleValue());
public static final SerializablePair<Long, Long> LONG_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultLongValue());
public static final SerializablePair<Long, Float> FLOAT_NIL_PAIR = new SerializablePair<>(0L, NullHandling.defaultFloatValue());
/**
* @return A vectorized aggregator that returns the default double value.
*/
public static NumericNilVectorAggregator doubleNilVectorAggregator()
public static NilVectorAggregator doubleNilVectorAggregator()
{
return DOUBLE_NIL_VECTOR_AGGREGATOR;
}
@ -53,7 +58,7 @@ public class NumericNilVectorAggregator implements VectorAggregator
/**
* @return A vectorized aggregator that returns the default float value.
*/
public static NumericNilVectorAggregator floatNilVectorAggregator()
public static NilVectorAggregator floatNilVectorAggregator()
{
return FLOAT_NIL_VECTOR_AGGREGATOR;
}
@ -61,7 +66,7 @@ public class NumericNilVectorAggregator implements VectorAggregator
/**
* @return A vectorized aggregator that returns the default long value.
*/
public static NumericNilVectorAggregator longNilVectorAggregator()
public static NilVectorAggregator longNilVectorAggregator()
{
return LONG_NIL_VECTOR_AGGREGATOR;
}
@ -69,7 +74,12 @@ public class NumericNilVectorAggregator implements VectorAggregator
@Nullable
private final Object returnValue;
private NumericNilVectorAggregator(@Nullable Object returnValue)
public static NilVectorAggregator of(Object returnValue)
{
return new NilVectorAggregator(returnValue);
}
private NilVectorAggregator(@Nullable Object returnValue)
{
this.returnValue = returnValue;
}

View File

@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseDoubleColumnValueSelector;
@ -149,7 +149,7 @@ public class DoubleFirstAggregatorFactory extends AggregatorFactory
timeColumn);
return new DoubleFirstVectorAggregator(timeSelector, valueSelector);
}
return NumericNilVectorAggregator.doubleNilVectorAggregator();
return NilVectorAggregator.of(NilVectorAggregator.DOUBLE_NIL_PAIR);
}
@Override

View File

@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
@ -138,7 +138,7 @@ public class FloatFirstAggregatorFactory extends AggregatorFactory
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn);
return new FloatFirstVectorAggregator(timeSelector, valueSelector);
}
return NumericNilVectorAggregator.floatNilVectorAggregator();
return NilVectorAggregator.of(NilVectorAggregator.FLOAT_NIL_PAIR);
}
@Override

View File

@ -30,7 +30,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseLongColumnValueSelector;
@ -138,7 +138,7 @@ public class LongFirstAggregatorFactory extends AggregatorFactory
timeColumn);
return new LongFirstVectorAggregator(timeSelector, valueSelector);
}
return NumericNilVectorAggregator.longNilVectorAggregator();
return NilVectorAggregator.of(NilVectorAggregator.LONG_NIL_PAIR);
}
@Override

View File

@ -42,11 +42,14 @@ import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.segment.virtual.ExpressionVectorSelectors;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -188,6 +191,17 @@ public class StringFirstAggregatorFactory extends AggregatorFactory
{
final VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn);
ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName);
if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName);
VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject(
selectorFactory.getReadableVectorInspector(),
fieldName,
valueSelector,
capabilities.toColumnType(),
ColumnType.STRING
);
return new StringFirstVectorAggregator(timeSelector, objectSelector, maxStringBytes);
}
if (capabilities != null) {
if (capabilities.is(ValueType.STRING) && capabilities.isDictionaryEncoded().isTrue()) {
// Case 1: Single value string with dimension selector

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.query.aggregation.AggregateCombiner;
import org.apache.druid.query.aggregation.Aggregator;
@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
@ -42,6 +43,7 @@ import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector;
@ -125,14 +127,12 @@ public class DoubleLastAggregatorFactory extends AggregatorFactory
)
{
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn);
return new DoubleLastVectorAggregator(timeSelector, valueSelector);
} else {
return NumericNilVectorAggregator.doubleNilVectorAggregator();
return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultDoubleValue()));
}
}

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.query.aggregation.AggregateCombiner;
import org.apache.druid.query.aggregation.Aggregator;
@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.aggregation.first.FloatFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
@ -42,6 +43,7 @@ import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector;
@ -136,15 +138,13 @@ public class FloatLastAggregatorFactory extends AggregatorFactory
VectorColumnSelectorFactory columnSelectorFactory
)
{
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn);
return new FloatLastVectorAggregator(timeSelector, valueSelector);
} else {
return NumericNilVectorAggregator.floatNilVectorAggregator();
return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultFloatValue()));
}
}

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.query.aggregation.AggregateCombiner;
import org.apache.druid.query.aggregation.Aggregator;
@ -30,7 +31,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.any.NilVectorAggregator;
import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
@ -42,6 +43,7 @@ import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector;
@ -136,14 +138,13 @@ public class LongLastAggregatorFactory extends AggregatorFactory
VectorColumnSelectorFactory columnSelectorFactory
)
{
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
VectorValueSelector timeSelector = columnSelectorFactory.makeValueSelector(timeColumn);
return new LongLastVectorAggregator(timeSelector, valueSelector);
} else {
return NumericNilVectorAggregator.longNilVectorAggregator();
return NilVectorAggregator.of(new SerializablePair<>(0L, NullHandling.defaultLongValue()));
}
}

View File

@ -42,9 +42,11 @@ import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.segment.virtual.ExpressionVectorSelectors;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -156,16 +158,25 @@ public class StringLastAggregatorFactory extends AggregatorFactory
public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
{
ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName);
final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn);
if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName);
VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject(
selectorFactory.getReadableVectorInspector(),
fieldName,
valueSelector,
capabilities.toColumnType(),
ColumnType.STRING
);
return new StringLastVectorAggregator(timeSelector, objectSelector, maxStringBytes);
}
VectorObjectSelector vSelector = selectorFactory.makeObjectSelector(fieldName);
VectorValueSelector timeSelector = selectorFactory.makeValueSelector(
timeColumn);
if (capabilities != null) {
return new StringLastVectorAggregator(timeSelector, vSelector, maxStringBytes);
} else {
return new StringLastVectorAggregator(null, vSelector, maxStringBytes);
}
}
@Override

View File

@ -23,35 +23,45 @@ import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.aggregation.SerializablePairLongString;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
import org.apache.druid.segment.vector.NoFilterVectorOffset;
import org.apache.druid.segment.vector.ReadableVectorInspector;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Answers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final double EPSILON = 1e-5;
private static final String[] VALUES = new String[]{"a", "b", null, "c"};
private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L};
private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"};
private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0};
private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME";
private static final String FIELD_NAME_LONG = "LONG_NAME";
private static final String TIME_COL = "__time";
private long[] times = {2436, 6879, 7888, 8224};
private long[] timesSame = {2436, 2436};
private SerializablePairLongString[] pairs = {
private final long[] times = {2436, 6879, 7888, 8224};
private final long[] timesSame = {2436, 2436};
private final SerializablePairLongString[] pairs = {
new SerializablePairLongString(2345001L, "first"),
new SerializablePairLongString(2345100L, "notFirst")
};
@ -69,8 +79,10 @@ public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest
private StringFirstVectorAggregator targetWithPairs;
private StringFirstAggregatorFactory stringFirstAggregatorFactory;
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private StringFirstAggregatorFactory stringFirstAggregatorFactory1;
private VectorColumnSelectorFactory selectorFactory;
private VectorValueSelector nonStringValueSelector;
@Before
public void setup()
@ -78,19 +90,189 @@ public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getObjectVector();
Mockito.doReturn(times).when(timeSelector).getLongVector();
Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector();
Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector();
timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length))
{
@Override
public long[] getLongVector()
{
return times;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return null;
}
};
selector = new VectorObjectSelector()
{
@Override
public Object[] getObjectVector()
{
return VALUES;
}
@Override
public int getMaxVectorSize()
{
return 4;
}
@Override
public int getCurrentVectorSize()
{
return 0;
}
};
timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
timesSame.length,
0,
timesSame.length
))
{
@Override
public long[] getLongVector()
{
return timesSame;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return null;
}
};
selectorForPairs = new VectorObjectSelector()
{
@Override
public Object[] getObjectVector()
{
return pairs;
}
@Override
public int getMaxVectorSize()
{
return 2;
}
@Override
public int getCurrentVectorSize()
{
return 0;
}
};
nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
LONG_VALUES.length,
0,
LONG_VALUES.length
))
{
@Override
public long[] getLongVector()
{
return LONG_VALUES;
}
@Override
public float[] getFloatVector()
{
return FLOAT_VALUES;
}
@Override
public double[] getDoubleVector()
{
return DOUBLE_VALUES;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return NULLS;
}
@Override
public int getMaxVectorSize()
{
return 4;
}
@Override
public int getCurrentVectorSize()
{
return 4;
}
};
selectorFactory = new VectorColumnSelectorFactory()
{
@Override
public ReadableVectorInspector getReadableVectorInspector()
{
return new NoFilterVectorOffset(VALUES.length, 0, VALUES.length);
}
@Override
public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
}
@Override
public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
}
@Override
public VectorValueSelector makeValueSelector(String column)
{
if (TIME_COL.equals(column)) {
return timeSelector;
} else if (FIELD_NAME_LONG.equals(column)) {
return nonStringValueSelector;
}
return null;
}
@Override
public VectorObjectSelector makeObjectSelector(String column)
{
if (FIELD_NAME.equals(column)) {
return selector;
} else {
return null;
}
}
@Nullable
@Override
public ColumnCapabilities getColumnCapabilities(String column)
{
if (FIELD_NAME.equals(column)) {
return ColumnCapabilitiesImpl.createSimpleSingleValueStringColumnCapabilities();
} else if (FIELD_NAME_LONG.equals(column)) {
return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG);
}
return null;
}
};
target = new StringFirstVectorAggregator(timeSelector, selector, 10);
targetWithPairs = new StringFirstVectorAggregator(timeSelectorForPairs, selectorForPairs, 10);
clearBufferForPositions(0, 0);
Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME);
Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL);
stringFirstAggregatorFactory = new StringFirstAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10);
stringFirstAggregatorFactory1 = new StringFirstAggregatorFactory(NAME, FIELD_NAME_LONG, TIME_COL, 10);
}
@Test
@ -129,6 +311,19 @@ public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest
Assert.assertEquals(VALUES[0], result.rhs);
}
@Test
public void testStringEarliestOnNonStringColumns()
{
Assert.assertTrue(stringFirstAggregatorFactory1.canVectorize(selectorFactory));
VectorAggregator vectorAggregator = stringFirstAggregatorFactory1.factorizeVector(selectorFactory);
Assert.assertNotNull(vectorAggregator);
Assert.assertEquals(StringFirstVectorAggregator.class, vectorAggregator.getClass());
vectorAggregator.aggregate(buf, 0, 0, LONG_VALUES.length);
Pair<Long, String> result = (Pair<Long, String>) vectorAggregator.get(buf, 0);
Assert.assertEquals(times[0], result.lhs.longValue());
Assert.assertEquals(STRING_VALUES[0], result.rhs);
}
@Test
public void aggregateBatchWithoutRows()
{

View File

@ -23,74 +23,245 @@ import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.aggregation.SerializablePairLongString;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
import org.apache.druid.segment.vector.NoFilterVectorOffset;
import org.apache.druid.segment.vector.ReadableVectorInspector;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Answers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final double EPSILON = 1e-5;
private static final String[] VALUES = new String[]{"a", "b", null, "c"};
private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L};
private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"};
private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0};
private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME";
private static final String FIELD_NAME_LONG = "LONG_NAME";
private static final String TIME_COL = "__time";
private long[] times = {2436, 6879, 7888, 8224};
private long[] timesSame = {2436, 2436};
private SerializablePairLongString[] pairs = {
private final long[] times = {2436, 6879, 7888, 8224};
private final long[] timesSame = {2436, 2436};
private final SerializablePairLongString[] pairs = {
new SerializablePairLongString(2345100L, "last"),
new SerializablePairLongString(2345001L, "notLast")
};
@Mock
private VectorObjectSelector selector;
@Mock
private VectorObjectSelector selectorForPairs;
@Mock
private BaseLongVectorValueSelector timeSelector;
@Mock
private BaseLongVectorValueSelector timeSelectorForPairs;
private VectorValueSelector nonStringValueSelector;
private ByteBuffer buf;
private StringLastVectorAggregator target;
private StringLastVectorAggregator targetWithPairs;
private StringLastAggregatorFactory stringLastAggregatorFactory;
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private StringLastAggregatorFactory stringLastAggregatorFactory1;
private VectorColumnSelectorFactory selectorFactory;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getObjectVector();
Mockito.doReturn(times).when(timeSelector).getLongVector();
Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector();
Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector();
timeSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(times.length, 0, times.length))
{
@Override
public long[] getLongVector()
{
return times;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return NULLS;
}
};
nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
LONG_VALUES.length,
0,
LONG_VALUES.length
))
{
@Override
public long[] getLongVector()
{
return LONG_VALUES;
}
@Override
public float[] getFloatVector()
{
return FLOAT_VALUES;
}
@Override
public double[] getDoubleVector()
{
return DOUBLE_VALUES;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return NULLS;
}
@Override
public int getMaxVectorSize()
{
return 4;
}
@Override
public int getCurrentVectorSize()
{
return 4;
}
};
selector = new VectorObjectSelector()
{
@Override
public Object[] getObjectVector()
{
return VALUES;
}
@Override
public int getMaxVectorSize()
{
return 0;
}
@Override
public int getCurrentVectorSize()
{
return 0;
}
};
BaseLongVectorValueSelector timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
times.length,
0,
times.length
))
{
@Override
public long[] getLongVector()
{
return timesSame;
}
@Nullable
@Override
public boolean[] getNullVector()
{
return new boolean[0];
}
};
VectorObjectSelector selectorForPairs = new VectorObjectSelector()
{
@Override
public Object[] getObjectVector()
{
return pairs;
}
@Override
public int getMaxVectorSize()
{
return 2;
}
@Override
public int getCurrentVectorSize()
{
return 2;
}
};
selectorFactory = new VectorColumnSelectorFactory()
{
@Override
public ReadableVectorInspector getReadableVectorInspector()
{
return new NoFilterVectorOffset(LONG_VALUES.length, 0, LONG_VALUES.length);
}
@Override
public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
}
@Override
public MultiValueDimensionVectorSelector makeMultiValueDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
}
@Override
public VectorValueSelector makeValueSelector(String column)
{
if (TIME_COL.equals(column)) {
return timeSelector;
} else if (FIELD_NAME_LONG.equals(column)) {
return nonStringValueSelector;
}
return null;
}
@Override
public VectorObjectSelector makeObjectSelector(String column)
{
if (FIELD_NAME.equals(column)) {
return selector;
} else {
return null;
}
}
@Nullable
@Override
public ColumnCapabilities getColumnCapabilities(String column)
{
if (FIELD_NAME.equals(column)) {
return ColumnCapabilitiesImpl.createSimpleSingleValueStringColumnCapabilities();
} else if (FIELD_NAME_LONG.equals(column)) {
return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG);
}
return null;
}
};
target = new StringLastVectorAggregator(timeSelector, selector, 10);
targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10);
clearBufferForPositions(0, 0);
Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME);
Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL);
stringLastAggregatorFactory = new StringLastAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10);
stringLastAggregatorFactory1 = new StringLastAggregatorFactory(NAME, FIELD_NAME_LONG, TIME_COL, 10);
}
@Test
@ -112,6 +283,19 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass());
}
@Test
public void testStringLastOnNonStringColumns()
{
Assert.assertTrue(stringLastAggregatorFactory1.canVectorize(selectorFactory));
VectorAggregator vectorAggregator = stringLastAggregatorFactory1.factorizeVector(selectorFactory);
Assert.assertNotNull(vectorAggregator);
Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass());
vectorAggregator.aggregate(buf, 0, 0, LONG_VALUES.length);
Pair<Long, String> result = (Pair<Long, String>) vectorAggregator.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(STRING_VALUES[3], result.rhs);
}
@Test
public void initValueShouldBeMinDate()
{

View File

@ -632,7 +632,6 @@ public class CalciteSimpleQueryTest extends BaseCalciteQueryTest
@Test
public void testEarliestByLatestByWithExpression()
{
cannotVectorize();
testBuilder()
.sql("SELECT\n"
+ " channel\n"