diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
index 323ce413e6f..6b93be7d708 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstLastUtils.java
@@ -27,6 +27,8 @@ import org.apache.druid.segment.DimensionHandlerUtils;
 import org.apache.druid.segment.NilColumnValueSelector;
 import org.apache.druid.segment.column.ColumnCapabilities;
 import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorObjectSelector;
 
 import javax.annotation.Nullable;
 import java.nio.ByteBuffer;
@@ -59,6 +61,57 @@ public class StringFirstLastUtils
            || SerializablePairLongString.class.isAssignableFrom(clazz);
   }
 
+  /**
+   * Returns whether an object *might* be a SerializablePairLongString.
+   *
+   * @param obj the object to inspect; may be null
+   * @return false for null; otherwise true when the object's class is SerializablePairLongString, one of its
+   *         supertypes, or one of its subtypes
+   */
+  public static boolean objectNeedsFoldCheck(@Nullable Object obj)
+  {
+    if (obj == null) {
+      return false;
+    }
+    final Class<?> clazz = obj.getClass();
+    return clazz.isAssignableFrom(SerializablePairLongString.class)
+           || SerializablePairLongString.class.isAssignableFrom(clazz);
+  }
+
+  /**
+   * Reads the value at a particular index of the vector selectors as a time/string pair. If the object is already
+   * a SerializablePairLongString its own time and string are used; otherwise the time comes from the time selector
+   * at the same index. Index-out-of-bounds issues are the responsibility of the caller.
+   *
+   * @return the pair at that index, or null if the object at that index is null
+   */
+  @Nullable
+  public static SerializablePairLongString readPairFromVectorSelectorsAtIndex(
+      BaseLongVectorValueSelector timeSelector,
+      VectorObjectSelector valueSelector,
+      int index
+  )
+  {
+    final Object object = valueSelector.getObjectVector()[index];
+    if (object == null) {
+      // Don't aggregate nulls.
+      return null;
+    }
+
+    final long time;
+    final String string;
+    if (object instanceof SerializablePairLongString) {
+      final SerializablePairLongString pair = (SerializablePairLongString) object;
+      time = pair.lhs;
+      string = pair.rhs;
+    } else {
+      time = timeSelector.getLongVector()[index];
+      string = DimensionHandlerUtils.convertObjectToString(object);
+    }
+
+    return new SerializablePairLongString(time, string);
+  }
+
   @Nullable
   public static SerializablePairLongString readPairFromSelectors(
       final BaseLongColumnValueSelector timeSelector,
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
index 71bf66e6082..39c5b29647c 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastAggregatorFactory.java
@@ -31,14 +31,20 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.AggregatorUtil;
 import org.apache.druid.query.aggregation.BufferAggregator;
 import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.VectorAggregator;
 import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory;
 import org.apache.druid.query.aggregation.first.StringFirstLastUtils;
 import org.apache.druid.query.cache.CacheKeyBuilder;
 import org.apache.druid.segment.BaseObjectColumnValueSelector;
+import
org.apache.druid.segment.ColumnInspector;
 import org.apache.druid.segment.ColumnSelectorFactory;
 import org.apache.druid.segment.NilColumnValueSelector;
+import org.apache.druid.segment.column.ColumnCapabilities;
 import org.apache.druid.segment.column.ColumnHolder;
 import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.segment.vector.VectorObjectSelector;
 
 import javax.annotation.Nullable;
 import java.nio.ByteBuffer;
@@ -141,6 +147,30 @@ public class StringLastAggregatorFactory extends AggregatorFactory
     }
   }
 
+  @Override
+  public boolean canVectorize(ColumnInspector columnInspector)
+  {
+    return true;
+  }
+
+  /**
+   * Creates the vectorized aggregator for {@code fieldName}. When the selector factory reports no capabilities for
+   * the field, the aggregator is built with a null time selector, which makes it aggregate nothing.
+   */
+  @Override
+  public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
+  {
+    final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName);
+    final VectorObjectSelector valueSelector = selectorFactory.makeObjectSelector(fieldName);
+    final BaseLongVectorValueSelector timeSelector =
+        (BaseLongVectorValueSelector) selectorFactory.makeValueSelector(timeColumn);
+    return new StringLastVectorAggregator(
+        capabilities == null ? null : timeSelector,
+        valueSelector,
+        maxStringBytes
+    );
+  }
+
   @Override
   public Comparator getComparator()
   {
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
new file mode 100644
index 00000000000..045360ba616
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregator.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.last;
+
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.query.aggregation.first.StringFirstLastUtils;
+import org.apache.druid.segment.DimensionHandlerUtils;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorObjectSelector;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+
+/**
+ * Vectorized aggregator that keeps, for each buffer position, the time/string pair with the greatest timestamp
+ * seen so far. State is written and read back through {@link StringFirstLastUtils}.
+ */
+public class StringLastVectorAggregator implements VectorAggregator
+{
+  private static final SerializablePairLongString INIT = new SerializablePairLongString(
+      DateTimes.MIN.getMillis(),
+      null
+  );
+
+  // May be null (the factory passes null when the field has no capabilities); both aggregate methods then do nothing.
+  @Nullable
+  private final BaseLongVectorValueSelector timeSelector;
+  private final VectorObjectSelector valueSelector;
+  private final int maxStringBytes;
+  protected long lastTime;
+
+  public StringLastVectorAggregator(
+      @Nullable final BaseLongVectorValueSelector timeSelector,
+      final VectorObjectSelector valueSelector,
+      final int maxStringBytes
+  )
+  {
+    this.timeSelector = timeSelector;
+    this.valueSelector = valueSelector;
+    this.maxStringBytes = maxStringBytes;
+  }
+
+  @Override
+  public void init(ByteBuffer buf, int position)
+  {
+    StringFirstLastUtils.writePair(buf, position, INIT, maxStringBytes);
+  }
+
+  @Override
+  public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
+  {
+    if (timeSelector == null) {
+      return;
+    }
+    final long[] times = timeSelector.getLongVector();
+    final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
+
+    lastTime = buf.getLong(position);
+    // Walk backwards and stop at the first non-null row whose time is older than the tracked time.
+    for (int i = endRow - 1; i >= startRow; i--) {
+      final Object object = objectsWhichMightBeStrings[i];
+      if (object == null) {
+        continue;
+      }
+      if (times[i] < lastTime) {
+        break;
+      }
+      if (StringFirstLastUtils.objectNeedsFoldCheck(object)) {
+        // Less efficient code path when folding is a possibility (we must read the value selector first just in case
+        // it's a foldable object).
+        final SerializablePairLongString inPair = StringFirstLastUtils.readPairFromVectorSelectorsAtIndex(
+            timeSelector,
+            valueSelector,
+            i
+        );
+        // NOTE(review): the lastTime field is not advanced on this path; the break check above compares it with
+        // the time column, while a folded pair is compared with the buffered time through its own timestamp.
+        if (inPair != null && inPair.lhs >= buf.getLong(position)) {
+          StringFirstLastUtils.writePair(buf, position, inPair, maxStringBytes);
+        }
+      } else {
+        final long time = times[i];
+        if (time >= lastTime) {
+          lastTime = time;
+          StringFirstLastUtils.writePair(
+              buf,
+              position,
+              new SerializablePairLongString(time, DimensionHandlerUtils.convertObjectToString(object)),
+              maxStringBytes
+          );
+        }
+      }
+    }
+  }
+
+  @Override
+  public void aggregate(
+      ByteBuffer buf,
+      int numRows,
+      int[] positions,
+      @Nullable int[] rows,
+      int positionOffset
+  )
+  {
+    if (timeSelector == null) {
+      // Same guard as the single-position aggregate; without it this path throws a NullPointerException.
+      return;
+    }
+    final long[] timeVector = timeSelector.getLongVector();
+    final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
+    final boolean foldNeeded = firstNonNullNeedsFold(objectsWhichMightBeStrings);
+
+    for (int i = 0; i < numRows; i++) {
+      final int position = positions[i] + positionOffset;
+      final int row = rows == null ? i : rows[i];
+      final long bufferedTime = buf.getLong(position);
+      if (timeVector[row] < bufferedTime) {
+        continue;
+      }
+      if (foldNeeded) {
+        final SerializablePairLongString inPair = StringFirstLastUtils.readPairFromVectorSelectorsAtIndex(
+            timeSelector,
+            valueSelector,
+            row
+        );
+        if (inPair != null && inPair.lhs >= bufferedTime) {
+          StringFirstLastUtils.writePair(buf, position, inPair, maxStringBytes);
+        }
+      } else {
+        // Null objects are not skipped on this path; the row's time is written together with the converted value.
+        final String value = DimensionHandlerUtils.convertObjectToString(objectsWhichMightBeStrings[row]);
+        StringFirstLastUtils.writePair(
+            buf,
+            position,
+            new SerializablePairLongString(timeVector[row], value),
+            maxStringBytes
+        );
+      }
+    }
+  }
+
+  @Nullable
+  @Override
+  public Object get(ByteBuffer buf, int position)
+  {
+    return StringFirstLastUtils.readPair(buf, position);
+  }
+
+  @Override
+  public void close()
+  {
+    // nothing to close
+  }
+
+  /**
+   * Looks at the first non-null element of the object vector to decide whether the values might be
+   * SerializablePairLongString objects that need folding. Returns false when every element is null.
+   */
+  private static boolean firstNonNullNeedsFold(Object[] objects)
+  {
+    for (Object obj : objects) {
+      if (obj != null) {
+        return StringFirstLastUtils.objectNeedsFoldCheck(obj);
+      }
+    }
+    return false;
+  }
+}
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
new file mode 100644
index 00000000000..428ff3e3742
--- /dev/null
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastVectorAggregatorTest.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.last;
+
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.vector.BaseLongVectorValueSelector;
+import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.segment.vector.VectorObjectSelector;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Answers;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * Unit tests for {@link StringLastVectorAggregator} and its factory wiring, using mocked vector selectors.
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
+{
+  private static final String[] VALUES = new String[]{"a", "b", null, "c"};
+  private static final String NAME = "NAME";
+  private static final String FIELD_NAME = "FIELD_NAME";
+  private static final String TIME_COL = "__time";
+  private long[] times = {2436, 6879, 7888, 8224};
+  private long[] timesSame = {2436, 2436};
+  private SerializablePairLongString[] pairs = {
+      new SerializablePairLongString(2345100L, "last"),
+      new SerializablePairLongString(2345001L, "notLast")
+  };
+
+  @Mock
+  private VectorObjectSelector selector;
+  @Mock
+  private VectorObjectSelector selectorForPairs;
+  @Mock
+  private BaseLongVectorValueSelector timeSelector;
+  @Mock
+  private BaseLongVectorValueSelector timeSelectorForPairs;
+  private ByteBuffer buf;
+  private StringLastVectorAggregator target;
+  private StringLastVectorAggregator targetWithPairs;
+
+  private StringLastAggregatorFactory stringLastAggregatorFactory;
+  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
+  private VectorColumnSelectorFactory selectorFactory;
+
+  @Before
+  public void setup()
+  {
+    // The buffer starts out filled with random bytes; positions under test are initialised via init().
+    byte[] randomBytes = new byte[1024];
+    ThreadLocalRandom.current().nextBytes(randomBytes);
+    buf = ByteBuffer.wrap(randomBytes);
+    Mockito.doReturn(VALUES).when(selector).getObjectVector();
+    Mockito.doReturn(times).when(timeSelector).getLongVector();
+    Mockito.doReturn(timesSame).when(timeSelectorForPairs).getLongVector();
+    Mockito.doReturn(pairs).when(selectorForPairs).getObjectVector();
+    target = new StringLastVectorAggregator(timeSelector, selector, 10);
+    targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10);
+    clearBufferForPositions(0, 0);
+
+    Mockito.doReturn(selector).when(selectorFactory).makeObjectSelector(FIELD_NAME);
+    Mockito.doReturn(timeSelector).when(selectorFactory).makeValueSelector(TIME_COL);
+    stringLastAggregatorFactory = new StringLastAggregatorFactory(NAME, FIELD_NAME, TIME_COL, 10);
+  }
+
+  @Test
+  public void testAggregateWithPairs()
+  {
+    targetWithPairs.aggregate(buf, 0, 0, pairs.length);
+    Pair<Long, String> result = (Pair<Long, String>) targetWithPairs.get(buf, 0);
+    // pairs[0] is expected: its lhs timestamp is the greater one, although the time column values are equal.
+    Assert.assertEquals(pairs[0].lhs.longValue(), result.lhs.longValue());
+    Assert.assertEquals(pairs[0].rhs, result.rhs);
+  }
+
+  @Test
+  public void testFactory()
+  {
+    Assert.assertTrue(stringLastAggregatorFactory.canVectorize(selectorFactory));
+    VectorAggregator vectorAggregator = stringLastAggregatorFactory.factorizeVector(selectorFactory);
+    Assert.assertNotNull(vectorAggregator);
+    Assert.assertEquals(StringLastVectorAggregator.class, vectorAggregator.getClass());
+  }
+
+  @Test
+  public void initValueShouldBeMinDate()
+  {
+    target.init(buf, 0);
+    long initVal = buf.getLong(0);
+    Assert.assertEquals(DateTimes.MIN.getMillis(), initVal);
+  }
+
+  @Test
+  public void aggregate()
+  {
+    target.aggregate(buf, 0, 0, VALUES.length);
+    Pair<Long, String> result = (Pair<Long, String>) target.get(buf, 0);
+    Assert.assertEquals(times[3], result.lhs.longValue());
+    Assert.assertEquals(VALUES[3], result.rhs);
+  }
+
+  @Test
+  public void aggregateBatchWithoutRows()
+  {
+    int[] positions = new int[]{0, 43, 70};
+    int positionOffset = 2;
+    clearBufferForPositions(positionOffset, positions);
+    target.aggregate(buf, 3, positions, null, positionOffset);
+    for (int i = 0; i < positions.length; i++) {
+      Pair<Long, String> result = (Pair<Long, String>) target.get(buf, positions[i] + positionOffset);
+      Assert.assertEquals(times[i], result.lhs.longValue());
+      Assert.assertEquals(VALUES[i], result.rhs);
+    }
+  }
+
+  @Test
+  public void aggregateBatchWithRows()
+  {
+    // rows picks vector rows out of order and includes row 2, whose value is null.
+    int[] positions = new int[]{0, 43, 70};
+    int[] rows = new int[]{3, 2, 0};
+    int positionOffset = 2;
+    clearBufferForPositions(positionOffset, positions);
+    target.aggregate(buf, 3, positions, rows, positionOffset);
+    for (int i = 0; i < positions.length; i++) {
+      Pair<Long, String> result = (Pair<Long, String>) target.get(buf, positions[i] + positionOffset);
+      Assert.assertEquals(times[rows[i]], result.lhs.longValue());
+      Assert.assertEquals(VALUES[rows[i]], result.rhs);
+    }
+  }
+
+  private void clearBufferForPositions(int offset, int... positions)
+  {
+    for (int position : positions) {
+      target.init(buf, offset + position);
+    }
+  }
+}
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 9feb8226797..1af59214f8f 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -683,8 +683,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
   @Test
   public void testLatestAggregators() throws Exception
   {
-    // Cannot vectorize until StringLast is vectorized
-    skipVectorize();
+
     testQuery(
         "SELECT "
         + "LATEST(cnt), LATEST(m1), LATEST(dim1, 10), "
@@ -944,6 +943,39 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
     );
   }
 
+  @Test
+  public void testStringLatestGroupBy() throws Exception
+  {
+    testQuery(
+        "SELECT dim2, LATEST(dim4,10) AS val1 FROM druid.numfoo GROUP BY dim2",
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.DATASOURCE3)
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(dimensions(new DefaultDimensionSpec("dim2", "_d0")))
+                        .setAggregatorSpecs(aggregators(
+                                                new StringLastAggregatorFactory("a0", "dim4", null, 10)
+                                            )
+                        )
+                        .setContext(QUERY_CONTEXT_DEFAULT)
+                        .build()
+        ),
+        NullHandling.sqlCompatible()
+        ? ImmutableList.of(
+            new Object[]{null, "b"},
+            new Object[]{"", "a"},
+            new Object[]{"a", "b"},
+            new Object[]{"abc", "b"}
+        )
+        : ImmutableList.of(
+            new Object[]{"", "b"},
+            new Object[]{"a", "b"},
+            new Object[]{"abc", "b"}
+        )
+    );
+  }
+
   // This test the off-heap (buffer) version of the EarliestAggregator (Double/Float/Long)
   @Test
   public void testPrimitiveEarliestInSubquery() throws Exception
@@ -999,9 +1031,6 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
   @Test
   public void testStringLatestInSubquery() throws Exception
   {
-    // Cannot vectorize LATEST aggregator for Strings
-    skipVectorize();
-
     testQuery(
         "SELECT SUM(val) FROM (SELECT dim2, LATEST(dim1, 10) AS val FROM foo GROUP BY dim2)",
         ImmutableList.of(