Vectorize numeric latest aggregators (#12439)

* Vectorizing Latest aggregator Part 1

* Updating benchmark tests

* Changing appropriate logic for vectors for null handling

* Introducing an abstract class and moving the commonalities there

* Adding vectorization for StringLast aggregator (initial version)

* Updated bufferized version of numeric aggregators

* Adding some javadocs

* Making sure this PR vectorizes numeric latest agg only

* Adding another benchmarking test

* Fixing intellij inspections

* Adding tests for double

* Adding test cases for long and float

* Updating testcases

* Checkstyle oops..

* One tiny change in test case

* Fixing spotbug and rhs not being used
This commit is contained in:
somu-imply 2022-04-26 11:33:08 -07:00 committed by GitHub
parent 564d6defd4
commit 027935dcff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 948 additions and 19 deletions

View File

@ -200,7 +200,15 @@ public class SqlExpressionBenchmark
// 36: time shift + non-expr agg (group by), uniform distribution low cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long4), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 37: time shift + expr agg (group by), uniform distribution high cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long5), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3"
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long5), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 38: LATEST aggregator
"SELECT LATEST(long1) FROM foo",
// 39: LATEST aggregator double
"SELECT LATEST(double4) FROM foo",
// 40: LATEST aggregator double
"SELECT LATEST(float3) FROM foo",
// 41: LATEST aggregator double
"SELECT LATEST(float3), LATEST(long1), LATEST(double4) FROM foo"
);
@Param({"5000000"})
@ -252,7 +260,11 @@ public class SqlExpressionBenchmark
"34",
"35",
"36",
"37"
"37",
"38",
"39",
"40",
"41"
})
private String query;

View File

@ -29,16 +29,23 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory;
import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseDoubleColumnValueSelector;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -109,6 +116,29 @@ public class DoubleLastAggregatorFactory extends AggregatorFactory
}
}
@Override
public boolean canVectorize(ColumnInspector columnInspector)
{
return true;
}
@Override
public VectorAggregator factorizeVector(
VectorColumnSelectorFactory columnSelectorFactory
)
{
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
//time is always long
BaseLongVectorValueSelector timeSelector = (BaseLongVectorValueSelector) columnSelectorFactory.makeValueSelector(
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
return new DoubleLastVectorAggregator(timeSelector, valueSelector);
} else {
return NumericNilVectorAggregator.doubleNilVectorAggregator();
}
}
@Override
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
{

View File

@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized version of on heap 'last' aggregator for column selectors with type DOUBLE..
*/
public class DoubleLastVectorAggregator extends NumericLastVectorAggregator
{
double lastValue;
public DoubleLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector)
{
super(timeSelector, valueSelector);
lastValue = 0;
}
@Override
void putValue(ByteBuffer buf, int position, int index)
{
lastValue = valueSelector.getDoubleVector()[index];
buf.putDouble(position, lastValue);
}
@Override
public void initValue(ByteBuffer buf, int position)
{
buf.putDouble(position, 0);
}
@Nullable
@Override
public Object get(ByteBuffer buf, int position)
{
final boolean rhsNull = isValueNull(buf, position);
return new SerializablePair<>(buf.getLong(position), rhsNull ? null : buf.getDouble(position + VALUE_OFFSET));
}
}

View File

@ -29,16 +29,23 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.first.FloatFirstAggregatorFactory;
import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -121,6 +128,29 @@ public class FloatLastAggregatorFactory extends AggregatorFactory
}
}
@Override
public boolean canVectorize(ColumnInspector columnInspector)
{
return true;
}
@Override
public VectorAggregator factorizeVector(
VectorColumnSelectorFactory columnSelectorFactory
)
{
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
//time is always long
BaseLongVectorValueSelector timeSelector = (BaseLongVectorValueSelector) columnSelectorFactory.makeValueSelector(
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
return new FloatLastVectorAggregator(timeSelector, valueSelector);
} else {
return NumericNilVectorAggregator.floatNilVectorAggregator();
}
}
@Override
public Comparator getComparator()
{

View File

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized version of on heap 'last' aggregator for column selectors with type FLOAT..
*/
public class FloatLastVectorAggregator extends NumericLastVectorAggregator
{
float lastValue;
public FloatLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector)
{
super(timeSelector, valueSelector);
lastValue = 0;
}
@Override
void putValue(ByteBuffer buf, int position, int index)
{
lastValue = valueSelector.getFloatVector()[index];
buf.putFloat(position, lastValue);
}
@Override
public void initValue(ByteBuffer buf, int position)
{
buf.putFloat(position, 0);
}
@Nullable
@Override
public Object get(ByteBuffer buf, int position)
{
final boolean rhsNull = isValueNull(buf, position);
return new SerializablePair<>(buf.getLong(position), rhsNull ? null : buf.getFloat(position + VALUE_OFFSET));
}
}

View File

@ -29,15 +29,22 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.any.NumericNilVectorAggregator;
import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseLongColumnValueSelector;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.NilColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@ -120,6 +127,28 @@ public class LongLastAggregatorFactory extends AggregatorFactory
}
}
@Override
public boolean canVectorize(ColumnInspector columnInspector)
{
return true;
}
@Override
public VectorAggregator factorizeVector(
VectorColumnSelectorFactory columnSelectorFactory
)
{
ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector valueSelector = columnSelectorFactory.makeValueSelector(fieldName);
BaseLongVectorValueSelector timeSelector = (BaseLongVectorValueSelector) columnSelectorFactory.makeValueSelector(
timeColumn);
if (capabilities == null || capabilities.isNumeric()) {
return new LongLastVectorAggregator(timeSelector, valueSelector);
} else {
return NumericNilVectorAggregator.longNilVectorAggregator();
}
}
@Override
public Comparator getComparator()
{

View File

@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.collections.SerializablePair;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Vectorized version of on heap 'last' aggregator for column selectors with type LONG..
*/
public class LongLastVectorAggregator extends NumericLastVectorAggregator
{
long lastValue;
public LongLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector)
{
super(timeSelector, valueSelector);
lastValue = 0;
}
@Override
public void initValue(ByteBuffer buf, int position)
{
buf.putLong(position, 0);
}
@Override
void putValue(ByteBuffer buf, int position, int index)
{
lastValue = valueSelector.getLongVector()[index];
buf.putLong(position, lastValue);
}
/**
* @return The primitive object stored at the position in the buffer.
*/
@Nullable
@Override
public Object get(ByteBuffer buf, int position)
{
final boolean rhsNull = isValueNull(buf, position);
return new SerializablePair<>(buf.getLong(position), rhsNull ? null : buf.getLong(position + VALUE_OFFSET));
}
}

View File

@ -72,7 +72,7 @@ public abstract class NumericLastAggregator<TSelector extends BaseNullableColumn
}
/**
* Store the current primitive typed 'first' value
* Store the current primitive typed 'last' value
*/
abstract void setCurrentValue();
}

View File

@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
* Base type for vectorized version of on heap 'last' aggregator for primitive numeric column selectors..
*/
public abstract class NumericLastVectorAggregator implements VectorAggregator
{
static final int NULL_OFFSET = Long.BYTES;
static final int VALUE_OFFSET = NULL_OFFSET + Byte.BYTES;
final VectorValueSelector valueSelector;
private final boolean useDefault = NullHandling.replaceWithDefault();
private final VectorValueSelector timeSelector;
private long lastTime;
public NumericLastVectorAggregator(VectorValueSelector timeSelector, VectorValueSelector valueSelector)
{
this.timeSelector = timeSelector;
this.valueSelector = valueSelector;
lastTime = Long.MIN_VALUE;
}
@Override
public void init(ByteBuffer buf, int position)
{
buf.putLong(position, Long.MIN_VALUE);
buf.put(position + NULL_OFFSET, useDefault ? NullHandling.IS_NOT_NULL_BYTE : NullHandling.IS_NULL_BYTE);
initValue(buf, position + VALUE_OFFSET);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
final long[] timeVector = timeSelector.getLongVector();
final boolean[] nullValueVector = valueSelector.getNullVector();
boolean nullAbsent = false;
lastTime = buf.getLong(position);
//check if nullVector is found or not
// the nullVector is null if no null values are found
// set the nullAbsent flag accordingly
if (nullValueVector == null) {
nullAbsent = true;
}
//the time vector is already sorted so the last element would be the latest
//traverse the value vector from the back (for latest)
int index = endRow - 1;
if (!useDefault && !nullAbsent) {
for (int i = endRow - 1; i >= startRow; i--) {
if (!nullValueVector[i]) {
index = i;
break;
}
}
}
//find the first non-null value
final long latestTime = timeVector[index];
if (latestTime >= lastTime) {
lastTime = latestTime;
if (useDefault || nullValueVector == null || !nullValueVector[index]) {
updateTimeWithValue(buf, position, lastTime, index);
} else {
updateTimeWithNull(buf, position, lastTime);
}
}
}
/**
*
* Checks if the aggregated value at a position in the buffer is null or not
*
* @param buf byte buffer storing the byte array representation of the aggregate
* @param position offset within the byte buffer at which the current aggregate value is stored
* @return
*/
boolean isValueNull(ByteBuffer buf, int position)
{
return buf.get(position + NULL_OFFSET) == NullHandling.IS_NULL_BYTE;
}
@Override
public void aggregate(
ByteBuffer buf,
int numRows,
int[] positions,
@Nullable int[] rows,
int positionOffset
)
{
boolean[] nulls = useDefault ? null : valueSelector.getNullVector();
long[] timeVector = timeSelector.getLongVector();
for (int i = 0; i < numRows; i++) {
int position = positions[i] + positionOffset;
int row = rows == null ? i : rows[i];
long lastTime = buf.getLong(position);
if (timeVector[row] >= lastTime) {
if (useDefault || nulls == null || !nulls[row]) {
updateTimeWithValue(buf, position, timeVector[row], row);
} else {
updateTimeWithNull(buf, position, timeVector[row]);
}
}
}
}
/**
*
* @param buf byte buffer storing the byte array representation of the aggregate
* @param position offset within the byte buffer at which the current aggregate value is stored
* @param time the time to be updated in the buffer as the last time
* @param index the index of the vectorized vector which is the last value
*/
void updateTimeWithValue(ByteBuffer buf, int position, long time, int index)
{
buf.putLong(position, time);
buf.put(position + NULL_OFFSET, NullHandling.IS_NOT_NULL_BYTE);
putValue(buf, position + VALUE_OFFSET, index);
}
/**
*
* @param buf byte buffer storing the byte array representation of the aggregate
* @param position offset within the byte buffer at which the current aggregate value is stored
* @param time the time to be updated in the buffer as the last time
*/
void updateTimeWithNull(ByteBuffer buf, int position, long time)
{
buf.putLong(position, time);
buf.put(position + NULL_OFFSET, NullHandling.IS_NULL_BYTE);
}
/**
*Abstract function which needs to be overridden by subclasses to set the initial value
*/
abstract void initValue(ByteBuffer buf, int position);
/**
*Abstract function which needs to be overridden by subclasses to set the
* latest value in the buffer depending on the datatype
*/
abstract void putValue(ByteBuffer buf, int position, int index);
@Override
public void close()
{
// no resources to cleanup
}
}

View File

@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class DoubleLastVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final double EPSILON = 1e-5;
private static final double[] VALUES = new double[]{7.8d, 11, 23.67, 60};
private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private long[] times = {2436, 6879, 7888, 8224};
@Mock
private VectorValueSelector selector;
@Mock
private VectorValueSelector timeSelector;
private ByteBuffer buf;
private DoubleLastVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getDoubleVector();
Mockito.doReturn(times).when(timeSelector).getLongVector();
target = new DoubleLastVectorAggregator(timeSelector, selector);
clearBufferForPositions(0, 0);
}
@Test
public void initValueShouldInitZero()
{
target.initValue(buf, 0);
double initVal = buf.getDouble(0);
Assert.assertEquals(0, initVal, EPSILON);
}
@Test
public void aggregate()
{
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Double> result = (Pair<Long, Double>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Double> result = (Pair<Long, Double>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Double> result = (Pair<Long, Double>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[i], result.lhs.longValue());
Assert.assertEquals(VALUES[i], result.rhs, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Double> result = (Pair<Long, Double>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[rows[i]], result.lhs.longValue());
Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON);
}
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
target.init(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class FloatLastVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final double EPSILON = 1e-5;
private static final float[] VALUES = new float[]{7.2f, 15.6f, 2.1f, 150.0f};
private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private long[] times = {2436, 6879, 7888, 8224};
@Mock
private VectorValueSelector selector;
@Mock
private VectorValueSelector timeSelector;
private ByteBuffer buf;
private FloatLastVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getFloatVector();
Mockito.doReturn(times).when(timeSelector).getLongVector();
target = new FloatLastVectorAggregator(timeSelector, selector);
clearBufferForPositions(0, 0);
}
@Test
public void initValueShouldBeZero()
{
target.initValue(buf, 0);
float initVal = buf.getFloat(0);
Assert.assertEquals(0.0f, initVal, EPSILON);
}
@Test
public void aggregate()
{
target.init(buf, 0);
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Float> result = (Pair<Long, Float>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Float> result = (Pair<Long, Float>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Float> result = (Pair<Long, Float>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[i], result.lhs.longValue());
Assert.assertEquals(VALUES[i], result.rhs, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Float> result = (Pair<Long, Float>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[rows[i]], result.lhs.longValue());
Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON);
}
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
target.init(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.segment.vector.VectorValueSelector;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
@RunWith(MockitoJUnitRunner.class)
public class LongLastVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final double EPSILON = 1e-5;
private static final long[] VALUES = new long[]{7, 15, 2, 150};
private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private long[] times = {2436, 6879, 7888, 8224};
@Mock
private VectorValueSelector selector;
@Mock
private VectorValueSelector timeSelector;
private ByteBuffer buf;
private LongLastVectorAggregator target;
@Before
public void setup()
{
byte[] randomBytes = new byte[1024];
ThreadLocalRandom.current().nextBytes(randomBytes);
buf = ByteBuffer.wrap(randomBytes);
Mockito.doReturn(VALUES).when(selector).getLongVector();
Mockito.doReturn(times).when(timeSelector).getLongVector();
target = new LongLastVectorAggregator(timeSelector, selector);
clearBufferForPositions(0, 0);
}
@Test
public void initValueShouldInitZero()
{
target.initValue(buf, 0);
long initVal = buf.getLong(0);
Assert.assertEquals(0, initVal);
}
@Test
public void aggregate()
{
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Long> result = (Pair<Long, Long>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateWithNulls()
{
mockNullsVector();
target.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, Long> result = (Pair<Long, Long>) target.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs, EPSILON);
}
@Test
public void aggregateBatchWithoutRows()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Long> result = (Pair<Long, Long>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[i], result.lhs.longValue());
Assert.assertEquals(VALUES[i], result.rhs, EPSILON);
}
}
@Test
public void aggregateBatchWithRows()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
target.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, Long> result = (Pair<Long, Long>) target.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[rows[i]], result.lhs.longValue());
Assert.assertEquals(VALUES[rows[i]], result.rhs, EPSILON);
}
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {
target.init(buf, offset + position);
}
}
private void mockNullsVector()
{
if (!NullHandling.replaceWithDefault()) {
Mockito.doReturn(NULLS).when(selector).getNullVector();
}
}
}

View File

@ -647,12 +647,44 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testLatestVectorAggregators() throws Exception
{
testQuery(
"SELECT "
+ "LATEST(cnt), LATEST(cnt + 1), LATEST(m1), LATEST(m1+1) "
+ "FROM druid.numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE3)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.virtualColumns(
expressionVirtualColumn("v0", "(\"cnt\" + 1)", ColumnType.LONG),
expressionVirtualColumn("v1", "(\"m1\" + 1)", ColumnType.FLOAT)
)
.aggregators(
aggregators(
new LongLastAggregatorFactory("a0", "cnt", null),
new LongLastAggregatorFactory("a1", "v0", null),
new FloatLastAggregatorFactory("a2", "m1", null),
new FloatLastAggregatorFactory("a3", "v1", null)
)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1L, 2L, 6.0f, 7.0f}
)
);
}
@Test
public void testLatestAggregators() throws Exception
{
// Cannot vectorize LATEST aggregator.
// Cannot vectorize until StringLast is vectorized
skipVectorize();
testQuery(
"SELECT "
+ "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), "
@ -834,9 +866,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testPrimitiveLatestInSubquery() throws Exception
{
// Cannot vectorize LATEST aggregator.
skipVectorize();
testQuery(
"SELECT SUM(val1), SUM(val2), SUM(val3) FROM (SELECT dim2, LATEST(m1) AS val1, LATEST(cnt) AS val2, LATEST(m2) AS val3 FROM foo GROUP BY dim2)",
ImmutableList.of(
@ -881,6 +910,40 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testPrimitiveLatestInSubqueryGroupBy() throws Exception
{
testQuery(
"SELECT dim2, LATEST(m1) AS val1 FROM foo GROUP BY dim2",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0")))
.setAggregatorSpecs(aggregators(
new FloatLastAggregatorFactory("a0", "m1", null)
)
)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
NullHandling.sqlCompatible()
? ImmutableList.of(
new Object[]{null, 6.0f},
new Object[]{"", 3.0f},
new Object[]{"a", 4.0f},
new Object[]{"abc", 5.0f}
)
: ImmutableList.of(
new Object[]{"", 6.0f},
new Object[]{"a", 4.0f},
new Object[]{"abc", 5.0f}
)
);
}
// This test the off-heap (buffer) version of the EarliestAggregator (Double/Float/Long)
@Test
public void testPrimitiveEarliestInSubquery() throws Exception
@ -936,7 +999,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testStringLatestInSubquery() throws Exception
{
// Cannot vectorize LATEST aggregator.
// Cannot vectorize LATEST aggregator for Strings
skipVectorize();
testQuery(
@ -1176,9 +1239,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testLatestAggregatorsNumericNull() throws Exception
{
// Cannot vectorize LATEST aggregator.
skipVectorize();
testQuery(
"SELECT LATEST(l1), LATEST(d1), LATEST(f1) FROM druid.numfoo",
ImmutableList.of(
@ -1209,7 +1269,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testFirstLatestAggregatorsSkipNulls() throws Exception
{
// Cannot vectorize LATEST aggregator.
// Cannot vectorize EARLIEST aggregator.
skipVectorize();
final DimFilter filter;
@ -1465,8 +1525,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testOrderByLatestFloat() throws Exception
{
// Cannot vectorize LATEST aggregator.
skipVectorize();
List<Object[]> expected;
if (NullHandling.replaceWithDefault()) {
expected = ImmutableList.of(
@ -1513,8 +1572,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testOrderByLatestDouble() throws Exception
{
// Cannot vectorize LATEST aggregator.
skipVectorize();
List<Object[]> expected;
if (NullHandling.replaceWithDefault()) {
expected = ImmutableList.of(
@ -1560,8 +1617,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
@Test
public void testOrderByLatestLong() throws Exception
{
// Cannot vectorize LATEST aggregator.
skipVectorize();
List<Object[]> expected;
if (NullHandling.replaceWithDefault()) {
expected = ImmutableList.of(