mirror of
https://github.com/apache/druid.git
synced 2025-02-06 01:58:20 +00:00
Vectorizing earliest string aggregator
This commit is contained in:
parent
2b556f6b19
commit
59118ae885
@ -32,12 +32,18 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.AggregatorUtil;
|
||||
import org.apache.druid.query.aggregation.BufferAggregator;
|
||||
import org.apache.druid.query.aggregation.SerializablePairLongString;
|
||||
import org.apache.druid.query.aggregation.VectorAggregator;
|
||||
import org.apache.druid.query.cache.CacheKeyBuilder;
|
||||
import org.apache.druid.segment.BaseObjectColumnValueSelector;
|
||||
import org.apache.druid.segment.ColumnInspector;
|
||||
import org.apache.druid.segment.ColumnSelectorFactory;
|
||||
import org.apache.druid.segment.NilColumnValueSelector;
|
||||
import org.apache.druid.segment.column.ColumnCapabilities;
|
||||
import org.apache.druid.segment.column.ColumnHolder;
|
||||
import org.apache.druid.segment.column.ColumnType;
|
||||
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
|
||||
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
|
||||
import org.apache.druid.segment.vector.VectorObjectSelector;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.nio.ByteBuffer;
|
||||
@ -154,6 +160,26 @@ public class StringFirstAggregatorFactory extends AggregatorFactory
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
|
||||
{
|
||||
ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName);
|
||||
VectorObjectSelector vSelector = selectorFactory.makeObjectSelector(fieldName);
|
||||
BaseLongVectorValueSelector timeSelector = (BaseLongVectorValueSelector) selectorFactory.makeValueSelector(
|
||||
timeColumn);
|
||||
if (capabilities != null) {
|
||||
return new StringFirstVectorAggregator(timeSelector, vSelector, maxStringBytes);
|
||||
} else {
|
||||
return new StringFirstVectorAggregator(null, vSelector, maxStringBytes);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canVectorize(ColumnInspector columnInspector)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Comparator getComparator()
|
||||
{
|
||||
|
@ -0,0 +1,175 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.query.aggregation.first;
|
||||
|
||||
import org.apache.druid.java.util.common.DateTimes;
|
||||
import org.apache.druid.query.aggregation.SerializablePairLongString;
|
||||
import org.apache.druid.query.aggregation.VectorAggregator;
|
||||
import org.apache.druid.segment.DimensionHandlerUtils;
|
||||
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
|
||||
import org.apache.druid.segment.vector.VectorObjectSelector;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
public class StringFirstVectorAggregator implements VectorAggregator
|
||||
{
|
||||
private static final SerializablePairLongString INIT = new SerializablePairLongString(
|
||||
DateTimes.MAX.getMillis(),
|
||||
null
|
||||
);
|
||||
private final BaseLongVectorValueSelector timeSelector;
|
||||
private final VectorObjectSelector valueSelector;
|
||||
private final int maxStringBytes;
|
||||
protected long firstTime;
|
||||
|
||||
public StringFirstVectorAggregator(
|
||||
BaseLongVectorValueSelector timeSelector,
|
||||
VectorObjectSelector valueSelector,
|
||||
int maxStringBytes
|
||||
) {
|
||||
this.timeSelector = timeSelector;
|
||||
this.valueSelector = valueSelector;
|
||||
this.maxStringBytes = maxStringBytes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void init(ByteBuffer buf, int position)
|
||||
{
|
||||
StringFirstLastUtils.writePair(buf, position, INIT, maxStringBytes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
|
||||
{
|
||||
if (timeSelector == null) {
|
||||
return;
|
||||
}
|
||||
long[] times = timeSelector.getLongVector();
|
||||
Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
|
||||
firstTime = buf.getLong(position);
|
||||
int index;
|
||||
for(int i=startRow; i<endRow; i++) {
|
||||
if (times[i] > firstTime) {
|
||||
break;
|
||||
}
|
||||
index = i;
|
||||
final boolean foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeStrings[index]);
|
||||
if (foldNeeded) {
|
||||
final SerializablePairLongString inPair = StringFirstLastUtils.readPairFromVectorSelectorsAtIndex(
|
||||
timeSelector,
|
||||
valueSelector,
|
||||
index
|
||||
);
|
||||
if (inPair != null) {
|
||||
final long firstTime = buf.getLong(position);
|
||||
if (inPair.lhs < firstTime) {
|
||||
StringFirstLastUtils.writePair(
|
||||
buf,
|
||||
position,
|
||||
new SerializablePairLongString(inPair.lhs, inPair.rhs),
|
||||
maxStringBytes
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
final long time = times[index];
|
||||
if (time < firstTime) {
|
||||
final String value = DimensionHandlerUtils.convertObjectToString(objectsWhichMightBeStrings[index]);
|
||||
firstTime = time;
|
||||
StringFirstLastUtils.writePair(
|
||||
buf,
|
||||
position,
|
||||
new SerializablePairLongString(time, value),
|
||||
maxStringBytes
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void aggregate(ByteBuffer buf, int numRows, int[] positions, @Nullable int[] rows, int positionOffset)
|
||||
{
|
||||
long[] timeVector = timeSelector.getLongVector();
|
||||
Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
|
||||
|
||||
// iterate once over the object vector to find first non null element and
|
||||
// determine if the type is Pair or not
|
||||
boolean foldNeeded = false;
|
||||
for (Object obj : objectsWhichMightBeStrings) {
|
||||
if (obj == null) {
|
||||
continue;
|
||||
} else {
|
||||
foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(obj);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < numRows; i++) {
|
||||
int position = positions[i] + positionOffset;
|
||||
int row = rows == null ? i : rows[i];
|
||||
long firstTime = buf.getLong(position);
|
||||
if (timeVector[row] < firstTime) {
|
||||
if (foldNeeded) {
|
||||
final SerializablePairLongString inPair = StringFirstLastUtils.readPairFromVectorSelectorsAtIndex(
|
||||
timeSelector,
|
||||
valueSelector,
|
||||
row
|
||||
);
|
||||
if (inPair != null) {
|
||||
if (inPair.lhs < firstTime) {
|
||||
StringFirstLastUtils.writePair(
|
||||
buf,
|
||||
position,
|
||||
new SerializablePairLongString(inPair.lhs, inPair.rhs),
|
||||
maxStringBytes
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
final String value = DimensionHandlerUtils.convertObjectToString(objectsWhichMightBeStrings[row]);
|
||||
firstTime = timeVector[row];
|
||||
StringFirstLastUtils.writePair(
|
||||
buf,
|
||||
position,
|
||||
new SerializablePairLongString(firstTime, value),
|
||||
maxStringBytes
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public Object get(ByteBuffer buf, int position)
|
||||
{
|
||||
return StringFirstLastUtils.readPair(buf, position);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close()
|
||||
{
|
||||
// nothing to close
|
||||
}
|
||||
}
|
@ -0,0 +1,167 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.apache.druid.query.aggregation.first;
|
||||
|
||||
import org.apache.druid.java.util.common.DateTimes;
|
||||
import org.apache.druid.java.util.common.Pair;
|
||||
import org.apache.druid.query.aggregation.SerializablePairLongString;
|
||||
import org.apache.druid.query.aggregation.VectorAggregator;
|
||||
import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
|
||||
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
|
||||
import org.apache.druid.segment.vector.VectorObjectSelector;
|
||||
import org.apache.druid.testing.InitializedNullHandlingTest;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Answers;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.Mockito;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class StringFirstVectorAggregatorTest extends InitializedNullHandlingTest
|
||||
{
|
||||
private static final double EPSILON = 1e-5;
|
||||
private static final String[] VALUES = new String[]{"a", "b", null, "c"};
|
||||
private static final boolean[] NULLS = new boolean[]{false, false, true, false};
|
||||
private static final String NAME = "NAME";
|
||||
private static final String FIELD_NAME = "FIELD_NAME";
|
||||
private static final String TIME_COL = "__time";
|
||||
private long[] times = {2436, 6879, 7888, 8224};
|
||||
private long[] timesSame = {2436, 2436};
|
||||
private SerializablePairLongString[] pairs = {
|
||||
new SerializablePairLongString(2345001L, "first"),
|
||||
new SerializablePairLongString(2345100L, "notFirst")
|
||||
};
|
||||
|
||||
@Mock
|
||||
private VectorObjectSelector selector;
|
||||
@Mock
|
||||
private VectorObjectSelector selectorForPairs;
|
||||
@Mock
|
||||
private BaseLongVectorValueSelector timeSelector;
|
||||
@Mock
|
||||
private BaseLongVectorValueSelector timeSelectorForPairs;
|
||||
private ByteBuffer buf;
|
||||
private StringFirstVectorAggregator target;
|
||||
private StringFirstVectorAggregator targetWithPairs;
|
||||
|
||||
private StringFirstAggregatorFactory stringFirstAggregatorFactory;
|
||||
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
|
||||
private VectorColumnSelectorFactory selectorFactory;
|
||||
|
||||
@Before
|
||||
public void setup()
|
||||
{
|
||||
byte[] randomBytes = new byte[1024];
|
||||
ThreadLocalRandom.current().nextBytes(randomBytes);
|
||||
buf = ByteBuffer.wrap(randomBytes);
|
||||
Mockito.doReturn(VALUES).when(selector).getObjectVector();
|
||||
Mockito.doReturn(times).when(timeSelector).getLongVector();
|
||||
Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector();
|
||||
Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector();
|
||||
target = new StringFirstVectorAggregator(timeSelector, selector, 10);
|
||||
targetWithPairs = new StringFirstVectorAggregator(timeSelectorForPairs, selectorForPairs, 10);
|
||||
clearBufferForPositions(0, 0);
|
||||
|
||||
|
||||
Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME);
|
||||
Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL);
|
||||
stringFirstAggregatorFactory = new StringFirstAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAggregateWithPairs()
|
||||
{
|
||||
targetWithPairs.aggregate(buf, 0, 0, pairs.length);
|
||||
Pair<Long, String> result = (Pair<Long, String>) targetWithPairs.get(buf, 0);
|
||||
//Should come 0 as the last value as the left of the pair is greater
|
||||
Assert.assertEquals(pairs[0].lhs.longValue(), result.lhs.longValue());
|
||||
Assert.assertEquals(pairs[0].rhs, result.rhs);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFactory()
|
||||
{
|
||||
Assert.assertTrue(stringFirstAggregatorFactory.canVectorize(selectorFactory));
|
||||
VectorAggregator vectorAggregator = stringFirstAggregatorFactory.factorizeVector(selectorFactory);
|
||||
Assert.assertNotNull(vectorAggregator);
|
||||
Assert.assertEquals(StringFirstVectorAggregator.class, vectorAggregator.getClass());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void initValueShouldBeMaxDate()
|
||||
{
|
||||
target.init(buf, 0);
|
||||
long initVal = buf.getLong(0);
|
||||
Assert.assertEquals(DateTimes.MAX.getMillis(), initVal);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void aggregate()
|
||||
{
|
||||
target.aggregate(buf, 0, 0, VALUES.length);
|
||||
Pair<Long, String> result = (Pair<Long, String>) target.get(buf, 0);
|
||||
Assert.assertEquals(times[0], result.lhs.longValue());
|
||||
Assert.assertEquals(VALUES[0], result.rhs);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void aggregateBatchWithoutRows()
|
||||
{
|
||||
int[] positions = new int[]{0, 43, 70};
|
||||
int positionOffset = 2;
|
||||
clearBufferForPositions(positionOffset, positions);
|
||||
target.aggregate(buf, 3, positions, null, positionOffset);
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
Pair<Long, String> result = (Pair<Long, String>) target.get(buf, positions[i] + positionOffset);
|
||||
Assert.assertEquals(times[i], result.lhs.longValue());
|
||||
Assert.assertEquals(VALUES[i], result.rhs);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void aggregateBatchWithRows()
|
||||
{
|
||||
int[] positions = new int[]{0, 43, 70};
|
||||
int[] rows = new int[]{3, 2, 0};
|
||||
int positionOffset = 2;
|
||||
clearBufferForPositions(positionOffset, positions);
|
||||
target.aggregate(buf, 3, positions, rows, positionOffset);
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
Pair<Long, String> result = (Pair<Long, String>) target.get(buf, positions[i] + positionOffset);
|
||||
Assert.assertEquals(times[rows[i]], result.lhs.longValue());
|
||||
Assert.assertEquals(VALUES[rows[i]], result.rhs);
|
||||
}
|
||||
}
|
||||
|
||||
private void clearBufferForPositions(int offset, int... positions)
|
||||
{
|
||||
for (int position : positions) {
|
||||
target.init(buf, offset + position);
|
||||
}
|
||||
}
|
||||
}
|
@ -198,12 +198,13 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
||||
final String fieldName = getColumnName(plannerContext, virtualColumnRegistry, args.get(0), rexNodes.get(0));
|
||||
|
||||
if (!rowSignature.contains(ColumnHolder.TIME_COLUMN_NAME) && (aggregatorType == AggregatorType.LATEST || aggregatorType == AggregatorType.EARLIEST)) {
|
||||
throw new ISE("%s() aggregator depends on __time column, the underlying datasource "
|
||||
plannerContext.setPlanningError("%s() aggregator depends on __time column, the underlying datasource "
|
||||
+ "or extern function you are querying doesn't contain __time column, "
|
||||
+ "Please use %s_BY() and specify the time column you want to use",
|
||||
aggregatorType.name(),
|
||||
aggregatorType.name()
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
final AggregatorFactory theAggFactory;
|
||||
|
@ -638,8 +638,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
||||
public void testEarliestAggregators()
|
||||
{
|
||||
notMsqCompatible();
|
||||
// Cannot vectorize EARLIEST aggregator.
|
||||
skipVectorize();
|
||||
|
||||
testQuery(
|
||||
"SELECT "
|
||||
@ -1071,8 +1069,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
||||
public void testPrimitiveEarliestInSubquery()
|
||||
{
|
||||
notMsqCompatible();
|
||||
// Cannot vectorize EARLIEST aggregator.
|
||||
skipVectorize();
|
||||
|
||||
testQuery(
|
||||
"SELECT SUM(val1), SUM(val2), SUM(val3) FROM (SELECT dim2, EARLIEST(m1) AS val1, EARLIEST(cnt) AS val2, EARLIEST(m2) AS val3 FROM foo GROUP BY dim2)",
|
||||
@ -1170,9 +1166,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
||||
@Test
|
||||
public void testStringEarliestInSubquery()
|
||||
{
|
||||
// Cannot vectorize EARLIEST aggregator.
|
||||
skipVectorize();
|
||||
|
||||
testQuery(
|
||||
"SELECT SUM(val) FROM (SELECT dim2, EARLIEST(dim1, 10) AS val FROM foo GROUP BY dim2)",
|
||||
ImmutableList.of(
|
||||
@ -1424,8 +1417,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
||||
public void testFirstLatestAggregatorsSkipNulls()
|
||||
{
|
||||
notMsqCompatible();
|
||||
// Cannot vectorize EARLIEST aggregator.
|
||||
skipVectorize();
|
||||
|
||||
final DimFilter filter;
|
||||
if (useDefault) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user