Vectorize numeric latest aggregators (#12439)

* Vectorizing Latest aggregator Part 1

* Updating benchmark tests

* Updating the null handling logic for vectors

* Introducing an abstract class and moving the commonalities there

* Adding vectorization for StringLast aggregator (initial version)

* Updated the buffered version of the numeric aggregators

* Adding some javadocs

* Making sure this PR vectorizes numeric latest agg only

* Adding another benchmarking test

* Fixing intellij inspections

* Adding tests for double

* Adding test cases for long and float

* Updating test cases

* Fixing checkstyle issues

* One tiny change in test case

* Fixing SpotBugs warning about rhs not being used
Authored by somu-imply on 2022-04-26 11:33:08 -07:00, committed by GitHub
commit 027935dcff (parent 564d6defd4)
13 changed files with 948 additions and 19 deletions

View File

@ -200,7 +200,15 @@ public class SqlExpressionBenchmark
// 36: time shift + non-expr agg (group by), uniform distribution low cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long4), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 37: time shift + expr agg (group by), uniform distribution high cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long5), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 38: LATEST aggregator
"SELECT LATEST(long1) FROM foo",
// 39: LATEST aggregator double
"SELECT LATEST(double4) FROM foo",
// 40: LATEST aggregator float
"SELECT LATEST(float3) FROM foo",
// 41: LATEST aggregators, multiple (float, long, double)
"SELECT LATEST(float3), LATEST(long1), LATEST(double4) FROM foo"
);
@Param({"5000000"})
@ -252,7 +260,11 @@ public class SqlExpressionBenchmark
"34", "34",
"35", "35",
"36", "36",
"37" "37",
"38",
"39",
"40",
"41"
})
private String query;

View File

@ -29,16 +29,23 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory;
import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseDoubleColumnValueSelector;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -109,6 +116,29 @@ public class DoubleLastAggregatorFactory extends AggregatorFactory
}
}
@Override
public boolean canVectorize(ColumnInspector columnInspector)
{
return true;
}
@Override
public VectorAggregator factorizeVector(
VectorColumnSelectorFactory columnSelectorFactory
)
{
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
//time is always long
BaseLongVectorValueSelector timeSelector = (BaseLongVectorValueSelector) columnSelectorFactory.makeValueSelector(
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
return new DoubleLastVectorAggregator(timeSelector, valueSelector);
} else {
return NumericNilVectorAggregator.doubleNilVectorAggregator();
}
}
@Override
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
{

View File

@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized version of the on-heap 'last' aggregator for column selectors with type DOUBLE.
*/
public class DoubleLastVectorAggregator extends NumericLastVectorAggregator
{
double lastValue;
public DoubleLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector)
{
super(timeSelector, valueSelector);
lastValue = 0;
}
@Override
void putValue(ByteBuffer buf, int position, int index)
{
lastValue = valueSelector.getDoubleVector()[index];
buf.putDouble(position, lastValue);
}
@Override
public void initValue(ByteBuffer buf, int position)
{
buf.putDouble(position, 0);
}
@Nullable
@Override
public Object get(ByteBuffer buf, int position)
{
final boolean rhsNull = isValueNull(buf, position);
return new SerializablePair<>(buf.getLong(position), rhsNull ? null : buf.getDouble(position + VALUE_OFFSET));
}
}

View File

@ -29,16 +29,23 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.first.FloatFirstAggregatorFactory;
import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -121,6 +128,29 @@ public class FloatLastAggregatorFactory extends AggregatorFactory
}
}
@Override
public boolean canVectorize(ColumnInspector columnInspector)
{
return true;
}
@Override
public VectorAggregator factorizeVector(
VectorColumnSelectorFactory columnSelectorFactory
)
{
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
//time is always long
BaseLongVectorValueSelector timeSelector = (BaseLongVectorValueSelector) columnSelectorFactory.makeValueSelector(
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
return new FloatLastVectorAggregator(timeSelector, valueSelector);
} else {
return NumericNilVectorAggregator.floatNilVectorAggregator();
}
}
@Override
public Comparator getComparator()
{

View File

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized version of the on-heap 'last' aggregator for column selectors with type FLOAT.
*/
public class FloatLastVectorAggregator extends NumericLastVectorAggregator
{
float lastValue;
public FloatLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector)
{
super(timeSelector, valueSelector);
lastValue = 0;
}
@Override
void putValue(ByteBuffer buf, int position, int index)
{
lastValue = valueSelector.getFloatVector()[index];
buf.putFloat(position, lastValue);
}
@Override
public void initValue(ByteBuffer buf, int position)
{
buf.putFloat(position, 0);
}
@Nullable
@Override
public Object get(ByteBuffer buf, int position)
{
final boolean rhsNull = isValueNull(buf, position);
return new SerializablePair<>(buf.getLong(position), rhsNull ? null : buf.getFloat(position + VALUE_OFFSET));
}
}

View File

@ -29,15 +29,22 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseLongColumnValueSelector;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -120,6 +127,28 @@ public class LongLastAggregatorFactory extends AggregatorFactory
}
}
@Override
public boolean canVectorize(ColumnInspector columnInspector)
{
return true;
}
@Override
public VectorAggregator factorizeVector(
VectorColumnSelectorFactory columnSelectorFactory
)
{
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
BaseLongVectorValueSelector timeSelector = (BaseLongVectorValueSelector) columnSelectorFactory.makeValueSelector(
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
return new LongLastVectorAggregator(timeSelector, valueSelector);
} else {
return NumericNilVectorAggregator.longNilVectorAggregator();
}
}
@Override
public Comparator getComparator()
{

View File

@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized version of the on-heap 'last' aggregator for column selectors with type LONG.
*/
public class LongLastVectorAggregator extends NumericLastVectorAggregator
{
long lastValue;
public LongLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector)
{
super(timeSelector, valueSelector);
lastValue = 0;
}
@Override
public void initValue(ByteBuffer buf, int position)
{
buf.putLong(position, 0);
}
@Override
void putValue(ByteBuffer buf, int position, int index)
{
lastValue = valueSelector.getLongVector()[index];
buf.putLong(position, lastValue);
}
/**
* @return a {@link SerializablePair} of the latest time and the latest value stored at the position in the buffer.
*/
@Nullable
@Override
public Object get(ByteBuffer buf, int position)
{
final boolean rhsNull = isValueNull(buf, position);
return new SerializablePair<>(buf.getLong(position), rhsNull ? null : buf.getLong(position + VALUE_OFFSET));
}
}

View File

@ -72,7 +72,7 @@ public abstract class NumericLastAggregator<TSelector extends BaseNullableColumn
}
/**
* Store the current primitive typed 'last' value
*/
abstract void setCurrentValue();
}

View File

@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Base type for vectorized versions of the on-heap 'last' aggregator for primitive numeric column selectors.
*/
public abstract class NumericLastVectorAggregator implements VectorAggregator
{
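// Buffer layout of one aggregation slot: an 8-byte long holding the latest time,
// followed by a 1-byte null flag, followed by the primitive latest value itself.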
static final int NULL_OFFSET = Long.BYTES;
static final int VALUE_OFFSET = NULL_OFFSET + Byte.BYTES;
final VectorValueSelector valueSelector;
private final boolean useDefault = NullHandling.replaceWithDefault();
private final VectorValueSelector timeSelector;
private long lastTime;
public NumericLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector)
{
this.timeSelector = timeSelector;
this.valueSelector = valueSelector;
lastTime = Long.MIN_VALUE;
}
@Override
public void init(ByteBuffer buf, int position)
{
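// Start from the smallest possible time so that the first row seen always becomes the latest.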
buf.putLong(position, Long.MIN_VALUE);
buf.put(position + NULL_OFFSET, useDefault ? NullHandling.IS_NOT_NULL_BYTE : NullHandling.IS_NULL_BYTE);
initValue(buf, position + VALUE_OFFSET);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
final long[] timeVector = timeSelector.getLongVector();
final boolean[] nullValueVector = valueSelector.getNullVector();
boolean nullAbsent = false;
lastTime = buf.getLong(position);
// The selector returns a null vector only if the column actually contains nulls;
// a null nullValueVector therefore means no nulls are present, so flag that here.
if (nullValueVector == null) {
nullAbsent = true;
}
// The time vector is sorted, so the latest time sits at the end of the batch.
// Scan the value vector from the back to find the latest non-null row.
int index = endRow - 1;
if (!useDefault && !nullAbsent) {
for (int i = endRow - 1; i >= startRow; i--) {
if (!nullValueVector[i]) {
index = i;
break;
}
}
}
// index now points at the latest row to use: the latest non-null row when nulls are tracked, otherwise the last row
final long latestTime = timeVector[index];
if (latestTime >= lastTime) {
lastTime = latestTime;
if (useDefault || nullValueVector == null || !nullValueVector[index]) {
updateTimeWithValue(buf, position, lastTime, index);
} else {
updateTimeWithNull(buf, position, lastTime);
}
}
}
/**
* Checks whether the aggregated value at a position in the buffer is null.
*
* @param buf byte buffer storing the byte array representation of the aggregate
* @param position offset within the byte buffer at which the current aggregate value is stored
* @return true if the aggregated value stored at this position is null, false otherwise
*/
boolean isValueNull(ByteBuffer buf, int position)
{
return buf.get(position + NULL_OFFSET) == NullHandling.IS_NULL_BYTE;
}
@Override
public void aggregate(
ByteBuffer buf,
int numRows,
int[] positions,
@Nullable int[] rows,
int positionOffset
)
{
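// positions[i] (plus positionOffset) is the buffer slot for output row i;
// rows remaps output rows to selector rows and may be null, in which case row i maps to itself.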
boolean[] nulls = useDefault ? null : valueSelector.getNullVector();
long[] timeVector = timeSelector.getLongVector();
for (int i = 0; i < numRows; i++) {
int position = positions[i] + positionOffset;
int row = rows == null ? i : rows[i];
long lastTime = buf.getLong(position);
if (timeVector[row] >= lastTime) {
if (useDefault || nulls == null || !nulls[row]) {
updateTimeWithValue(buf, position, timeVector[row], row);
} else {
updateTimeWithNull(buf, position, timeVector[row]);
}
}
}
}
/**
* Updates the latest time in the buffer and stores the corresponding non-null value next to it.
*
* @param buf byte buffer storing the byte array representation of the aggregate
* @param position offset within the byte buffer at which the current aggregate value is stored
* @param time the time to be stored in the buffer as the latest time
* @param index the index into the value vector of the row holding the latest value
*/
void updateTimeWithValue(ByteBuffer buf, int position, long time, int index)
{
buf.putLong(position, time);
buf.put(position + NULL_OFFSET, NullHandling.IS_NOT_NULL_BYTE);
putValue(buf, position + VALUE_OFFSET, index);
}
/**
* Updates the latest time in the buffer and marks the value as null.
*
* @param buf byte buffer storing the byte array representation of the aggregate
* @param position offset within the byte buffer at which the current aggregate value is stored
* @param time the time to be stored in the buffer as the latest time
*/
void updateTimeWithNull(ByteBuffer buf, int position, long time)
{
buf.putLong(position, time);
buf.put(position + NULL_OFFSET, NullHandling.IS_NULL_BYTE);
}
/**
* Abstract method that subclasses override to write the initial value into the buffer.
*/
abstract void initValue(ByteBuffer buf, int position);
/**
* Abstract method that subclasses override to write the latest value into the buffer,
* using the buffer accessor appropriate for their primitive type.
*/
abstract void putValue(ByteBuffer buf, int position, int index);
@Override
public void close()
{
// no resources to cleanup
}
}
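For readers tracing the offsets above, here is a minimal sketch (not part of the PR; the helper name LastSlotReader is illustrative) that decodes one aggregation slot the same way the concrete subclasses' get() implementations do, assuming a DOUBLE-typed column:

import org.apache.druid.common.config.NullHandling;
import java.nio.ByteBuffer;

// Illustrative helper only: reads back one slot written by a DOUBLE 'last' vector aggregator.
// Slot layout per position: [ long time (8 bytes) | null flag (1 byte) | double value (8 bytes) ].
public class LastSlotReader
{
  private static final int NULL_OFFSET = Long.BYTES;
  private static final int VALUE_OFFSET = NULL_OFFSET + Byte.BYTES;

  public static String describe(ByteBuffer buf, int position)
  {
    final long latestTime = buf.getLong(position);
    final boolean valueIsNull = buf.get(position + NULL_OFFSET) == NullHandling.IS_NULL_BYTE;
    final Double value = valueIsNull ? null : buf.getDouble(position + VALUE_OFFSET);
    return "latestTime=" + latestTime + ", value=" + value;
  }
}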

View File

@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class DoubleLastVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final double EPSILON = 1e-5;
private static final double[] VALUES = new double[]{7.8d, 11, 23.67, 60};
private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private long[] times = {2436, 6879, 7888, 8224};
@Mock
private VectorValueSelector selector;
@Mock
private VectorValueSelector timeSelector;
private ByteBuffer buf;
private DoubleLastVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getDoubleVector();
Mockito.doReturn(times).when(timeSelector).getLongVector();
target = new DoubleLastVectorAggregator(timeSelector, selector);
clearBufferForPositions(0, 0);
}
@Test
public void initValueShouldInitZero()
{
target.initValue(buf, 0);
double initVal = buf.getDouble(0);
Assert.assertEquals(0, initVal, EPSILON);
}
@Test
public void aggregate()
{
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Double> result = (Pair<Long, Double>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Double> result = (Pair<Long, Double>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Double> result = (Pair<Long, Double>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[i], result.lhs.longValue());
Assert.assertEquals(VALUES[i], result.rhs, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Double> result = (Pair<Long, Double>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[rows[i]], result.lhs.longValue());
Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON);
}
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
target.init(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class FloatLastVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final double EPSILON = 1e-5;
private static final float[] VALUES = new float[]{7.2f, 15.6f, 2.1f, 150.0f};
private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private long[] times = {2436, 6879, 7888, 8224};
@Mock
private VectorValueSelector selector;
@Mock
private VectorValueSelector timeSelector;
private ByteBuffer buf;
private FloatLastVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getFloatVector();
Mockito.doReturn(times).when(timeSelector).getLongVector();
target = new FloatLastVectorAggregator(timeSelector, selector);
clearBufferForPositions(0, 0);
}
@Test
public void initValueShouldBeZero()
{
target.initValue(buf, 0);
float initVal = buf.getFloat(0);
Assert.assertEquals(0.0f, initVal, EPSILON);
}
@Test
public void aggregate()
{
target.init(buf, 0);
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Float> result = (Pair<Long, Float>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Float> result = (Pair<Long, Float>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Float> result = (Pair<Long, Float>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[i], result.lhs.longValue());
Assert.assertEquals(VALUES[i], result.rhs, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Float> result = (Pair<Long, Float>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[rows[i]], result.lhs.longValue());
Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON);
}
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
target.init(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class LongLastVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final double EPSILON = 1e-5;
private static final long[] VALUES = new long[]{7, 15, 2, 150};
private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private long[] times = {2436, 6879, 7888, 8224};
@Mock
private VectorValueSelector selector;
@Mock
private VectorValueSelector timeSelector;
private ByteBuffer buf;
private LongLastVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getLongVector();
Mockito.doReturn(times).when(timeSelector).getLongVector();
target = new LongLastVectorAggregator(timeSelector, selector);
clearBufferForPositions(0, 0);
}
@Test
public void initValueShouldInitZero()
{
target.initValue(buf, 0);
long initVal = buf.getLong(0);
Assert.assertEquals(0, initVal);
}
@Test
public void aggregate()
{
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Long> result = (Pair<Long, Long>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Long> result = (Pair<Long, Long>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Long> result = (Pair<Long, Long>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[i], result.lhs.longValue());
Assert.assertEquals(VALUES[i], result.rhs, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Long> result = (Pair<Long, Long>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[rows[i]], result.lhs.longValue());
Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON);
}
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
target.init(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -647,12 +647,44 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testLatestVectorAggregators() throws Exception
{
testQuery(
"SELECT "
+ "LATEST(cnt), LATEST(cnt + 1), LATEST(m1), LATEST(m1+1) "
+ "FROM druid.numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.virtualColumns(
expressionVirtualColumn("v0", "(\"cnt\" + 1)", ColumnType.LONG),
expressionVirtualColumn("v1", "(\"m1\" + 1)", ColumnType.FLOAT)
)
.aggregators(
aggregators(
new LongLastAggregatorFactory("a0", "cnt", null),
new LongLastAggregatorFactory("a1", "v0", null),
new FloatLastAggregatorFactory("a2", "m1", null),
new FloatLastAggregatorFactory("a3", "v1", null)
)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1L, 2L, 6.0f, 7.0f}
)
);
}
@Test
public void testLatestAggregators() throws Exception
{
// Cannot vectorize until StringLast is vectorized
skipVectorize();
testQuery(
"SELECT "
+ "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), "
@ -834,9 +866,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testPrimitiveLatestInSubquery() throws Exception
{
// Cannot vectorize LATEST aggregator.
skipVectorize();
testQuery(
"SELECT SUM(val1), SUM(val2), SUM(val3) FROM (SELECT dim2, LATEST(m1) AS val1, LATEST(cnt) AS val2, LATEST(m2) AS val3 FROM foo GROUP BY dim2)",
ImmutableList.of(
@ -881,6 +910,40 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testPrimitiveLatestInSubqueryGroupBy() throws Exception
{
testQuery(
"SELECT dim2, LATEST(m1) AS val1 FROM foo GROUP BY dim2",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0")))
.setAggregatorSpecs(aggregators(
new FloatLastAggregatorFactory("a0", "m1", null)
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
NullHandling.sqlCompatible()
? ImmutableList.of(
new Object[]{null, 6.0f},
new Object[]{"", 3.0f},
new Object[]{"a", 4.0f},
new Object[]{"abc", 5.0f}
)
: ImmutableList.of(
new Object[]{"", 6.0f},
new Object[]{"a", 4.0f},
new Object[]{"abc", 5.0f}
)
);
}
// This tests the off-heap (buffer) version of the EarliestAggregator (Double/Float/Long)
@Test
public void testPrimitiveEarliestInSubquery() throws Exception
@ -936,7 +999,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testStringLatestInSubquery() throws Exception
{
// Cannot vectorize LATEST aggregator for Strings
skipVectorize();
testQuery(
@ -1176,9 +1239,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testLatestAggregatorsNumericNull() throws Exception
{
// Cannot vectorize LATEST aggregator.
skipVectorize();
testQuery(
"SELECT LATEST(l1), LATEST(d1), LATEST(f1) FROM druid.numfoo",
ImmutableList.of(
@ -1209,7 +1269,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testFirstLatestAggregatorsSkipNulls() throws Exception
{
// Cannot vectorize EARLIEST aggregator.
skipVectorize();
final DimFilter filter;
@ -1465,8 +1525,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testOrderByLatestFloat() throws Exception
{
// Cannot vectorize LATEST aggregator.
skipVectorize();
List<Object[]> expected;
if (NullHandling.replaceWithDefault()) {
expected = ImmutableList.of(
@ -1513,8 +1572,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testOrderByLatestDouble() throws Exception
{
// Cannot vectorize LATEST aggregator.
skipVectorize();
List<Object[]> expected;
if (NullHandling.replaceWithDefault()) {
expected = ImmutableList.of(
@ -1560,8 +1617,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testOrderByLatestLong() throws Exception
{
// Cannot vectorize LATEST aggregator.
skipVectorize();
List<Object[]> expected;
if (NullHandling.replaceWithDefault()) {
expected = ImmutableList.of(