Fix for latest agg to handle nulls in time column. Also adding optimi… (#14911)

* Fix for latest agg to handle nulls in time column. Also adding optimization for dictionary encoded string columns

* One minor fix

* Adding more tests for the new class

* Changing the init to a putInt
This commit is contained in:
Soumyava 2023-09-13 17:37:26 -07:00 committed by GitHub
parent bf99d2c7b2
commit 5c42ac8c4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 261 additions and 16 deletions

View File

@ -57,7 +57,7 @@ public class SingleStringFirstDimensionVectorAggregator implements VectorAggrega
position + NumericFirstVectorAggregator.NULL_OFFSET,
useDefault ? NullHandling.IS_NOT_NULL_BYTE : NullHandling.IS_NULL_BYTE
);
buf.putLong(position + NumericFirstVectorAggregator.VALUE_OFFSET, 0);
buf.putInt(position + NumericFirstVectorAggregator.VALUE_OFFSET, 0);
}
@Override

View File

@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.SerializablePairLongString;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
public class SingleStringLastDimensionVectorAggregator implements VectorAggregator
{
private final VectorValueSelector timeSelector;
private final SingleValueDimensionVectorSelector valueDimensionVectorSelector;
private long lastTime;
private final int maxStringBytes;
private final boolean useDefault = NullHandling.replaceWithDefault();
public SingleStringLastDimensionVectorAggregator(
VectorValueSelector timeSelector,
SingleValueDimensionVectorSelector valueDimensionVectorSelector,
int maxStringBytes
)
{
this.timeSelector = timeSelector;
this.valueDimensionVectorSelector = valueDimensionVectorSelector;
this.maxStringBytes = maxStringBytes;
this.lastTime = Long.MIN_VALUE;
}
@Override
public void init(ByteBuffer buf, int position)
{
buf.putLong(position, Long.MIN_VALUE);
buf.put(
position + NumericLastVectorAggregator.NULL_OFFSET,
useDefault ? NullHandling.IS_NOT_NULL_BYTE : NullHandling.IS_NULL_BYTE
);
buf.putInt(position + NumericLastVectorAggregator.VALUE_OFFSET, 0);
}
@Override
public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
{
final long[] timeVector = timeSelector.getLongVector();
final boolean[] nullTimeVector = timeSelector.getNullVector();
final int[] valueVector = valueDimensionVectorSelector.getRowVector();
lastTime = buf.getLong(position);
int index;
long latestTime;
for (index = endRow - 1; index >= startRow; index--) {
if (nullTimeVector != null && nullTimeVector[index]) {
continue;
}
latestTime = timeVector[index];
if (latestTime > lastTime) {
lastTime = latestTime;
buf.putLong(position, lastTime);
buf.put(position + NumericLastVectorAggregator.NULL_OFFSET, NullHandling.IS_NOT_NULL_BYTE);
buf.putInt(position + NumericLastVectorAggregator.VALUE_OFFSET, valueVector[index]);
}
}
}
@Override
public void aggregate(ByteBuffer buf, int numRows, int[] positions, @Nullable int[] rows, int positionOffset)
{
final long[] timeVector = timeSelector.getLongVector();
final boolean[] nullTimeVector = timeSelector.getNullVector();
final int[] values = valueDimensionVectorSelector.getRowVector();
for (int i = numRows - 1; i >= 0; i--) {
if (nullTimeVector != null && nullTimeVector[i]) {
continue;
}
int position = positions[i] + positionOffset;
int row = rows == null ? i : rows[i];
lastTime = buf.getLong(position);
if (timeVector[row] > lastTime) {
lastTime = timeVector[row];
buf.putLong(position, lastTime);
buf.put(position + NumericLastVectorAggregator.NULL_OFFSET, NullHandling.IS_NOT_NULL_BYTE);
buf.putInt(position + NumericLastVectorAggregator.VALUE_OFFSET, values[row]);
}
}
}
@Nullable
@Override
public Object get(ByteBuffer buf, int position)
{
int index = buf.getInt(position + NumericLastVectorAggregator.VALUE_OFFSET);
long earliest = buf.getLong(position);
String strValue = valueDimensionVectorSelector.lookupName(index);
return new SerializablePairLongString(earliest, StringUtils.chop(strValue, maxStringBytes));
}
@Override
public void close()
{
// nothing to close
}
}

View File

@ -35,6 +35,7 @@ import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory;
import org.apache.druid.query.aggregation.first.StringFirstLastUtils;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.segment.BaseObjectColumnValueSelector;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
@ -43,6 +44,8 @@ import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.segment.vector.VectorValueSelector;
@ -160,6 +163,7 @@ public class StringLastAggregatorFactory extends AggregatorFactory
final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn);
if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName);
VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject(
@ -171,6 +175,18 @@ public class StringLastAggregatorFactory extends AggregatorFactory
);
return new StringLastVectorAggregator(timeSelector, objectSelector, maxStringBytes);
}
if (capabilities != null) {
if (capabilities.is(ValueType.STRING) && capabilities.isDictionaryEncoded().isTrue()) {
if (!capabilities.hasMultipleValues().isTrue()) {
SingleValueDimensionVectorSelector sSelector = selectorFactory.makeSingleValueDimensionSelector(
DefaultDimensionSpec.of(
fieldName));
return new SingleStringLastDimensionVectorAggregator(timeSelector, sSelector, maxStringBytes);
}
}
}
VectorObjectSelector vSelector = selectorFactory.makeObjectSelector(fieldName);
if (capabilities != null) {
return new StringLastVectorAggregator(timeSelector, vSelector, maxStringBytes);
@ -296,9 +312,9 @@ public class StringLastAggregatorFactory extends AggregatorFactory
}
StringLastAggregatorFactory that = (StringLastAggregatorFactory) o;
return maxStringBytes == that.maxStringBytes &&
Objects.equals(fieldName, that.fieldName) &&
Objects.equals(timeColumn, that.timeColumn) &&
Objects.equals(name, that.name);
Objects.equals(fieldName, that.fieldName) &&
Objects.equals(timeColumn, that.timeColumn) &&
Objects.equals(name, that.name);
}
@Override

View File

@ -64,8 +64,9 @@ public class StringLastVectorAggregator implements VectorAggregator
if (timeSelector == null) {
return;
}
long[] times = timeSelector.getLongVector();
Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
final long[] times = timeSelector.getLongVector();
final boolean[] nullTimeVector = timeSelector.getNullVector();
final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
lastTime = buf.getLong(position);
int index;
@ -76,6 +77,9 @@ public class StringLastVectorAggregator implements VectorAggregator
if (times[i] <= lastTime) {
continue;
}
if (nullTimeVector != null && nullTimeVector[i]) {
continue;
}
index = i;
final boolean foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeStrings[index]);
if (foldNeeded) {
@ -127,22 +131,24 @@ public class StringLastVectorAggregator implements VectorAggregator
if (timeSelector == null) {
return;
}
long[] timeVector = timeSelector.getLongVector();
Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
final long[] timeVector = timeSelector.getLongVector();
final boolean[] nullTimeVector = timeSelector.getNullVector();
final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
// iterate once over the object vector to find first non null element and
// determine if the type is Pair or not
boolean foldNeeded = false;
for (Object obj : objectsWhichMightBeStrings) {
if (obj == null) {
continue;
} else {
if (obj != null) {
foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(obj);
break;
}
}
for (int i = 0; i < numRows; i++) {
if (nullTimeVector != null && nullTimeVector[i]) {
continue;
}
int position = positions[i] + positionOffset;
int row = rows == null ? i : rows[i];
long lastTime = buf.getLong(position);

View File

@ -23,7 +23,9 @@ import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.aggregation.SerializablePairLongString;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.IdLookup;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
@ -49,11 +51,13 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
{
private static final double EPSILON = 1e-5;
private static final String[] VALUES = new String[]{"a", "b", null, "c"};
private static final int[] DICT_VALUES = new int[]{1, 2, 0, 3};
private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L};
private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"};
private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0};
private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private static final boolean[] NULLS1 = new boolean[]{false, false};
private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME";
private static final String FIELD_NAME_LONG = "LONG_NAME";
@ -74,6 +78,7 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
private StringLastAggregatorFactory stringLastAggregatorFactory;
private StringLastAggregatorFactory stringLastAggregatorFactory1;
private SingleStringLastDimensionVectorAggregator targetSingleDim;
private VectorColumnSelectorFactory selectorFactory;
@ -96,7 +101,7 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
@Override
public boolean[] getNullVector()
{
return NULLS;
return null;
}
};
nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
@ -163,9 +168,9 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
}
};
BaseLongVectorValueSelector timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
times.length,
timesSame.length,
0,
times.length
timesSame.length
))
{
@Override
@ -178,7 +183,7 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
@Override
public boolean[] getNullVector()
{
return new boolean[0];
return NULLS1;
}
};
VectorObjectSelector selectorForPairs = new VectorObjectSelector()
@ -212,7 +217,61 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
@Override
public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
return new SingleValueDimensionVectorSelector()
{
@Override
public int[] getRowVector()
{
return DICT_VALUES;
}
@Override
public int getValueCardinality()
{
return DICT_VALUES.length;
}
@Nullable
@Override
public String lookupName(int id)
{
switch (id) {
case 1:
return "a";
case 2:
return "b";
case 3:
return "c";
default:
return null;
}
}
@Override
public boolean nameLookupPossibleInAdvance()
{
return false;
}
@Nullable
@Override
public IdLookup idLookup()
{
return null;
}
@Override
public int getMaxVectorSize()
{
return DICT_VALUES.length;
}
@Override
public int getCurrentVectorSize()
{
return DICT_VALUES.length;
}
};
}
@Override
@ -257,6 +316,8 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
target = new StringLastVectorAggregator(timeSelector, selector, 10);
targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10);
targetSingleDim = new SingleStringLastDimensionVectorAggregator(timeSelector, selectorFactory.makeSingleValueDimensionSelector(
DefaultDimensionSpec.of(FIELD_NAME)), 10);
clearBufferForPositions(0, 0);
@ -361,6 +422,44 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
}
}
@Test
public void aggregateSingleDim()
{
targetSingleDim.aggregate(buf, 0, 0, VALUES.length);
Pair<Long, String> result = (Pair<Long, String>) targetSingleDim.get(buf, 0);
Assert.assertEquals(times[3], result.lhs.longValue());
Assert.assertEquals(VALUES[3], result.rhs);
}
@Test
public void aggregateBatchWithoutRowsSingleDim()
{
int[] positions = new int[]{0, 43, 70};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
targetSingleDim.aggregate(buf, 3, positions, null, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, String> result = (Pair<Long, String>) targetSingleDim.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[i], result.lhs.longValue());
Assert.assertEquals(VALUES[i], result.rhs);
}
}
@Test
public void aggregateBatchWithRowsSingleDim()
{
int[] positions = new int[]{0, 43, 70};
int[] rows = new int[]{3, 2, 0};
int positionOffset = 2;
clearBufferForPositions(positionOffset, positions);
targetSingleDim.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
Pair<Long, String> result = (Pair<Long, String>) targetSingleDim.get(buf, positions[i] + positionOffset);
Assert.assertEquals(times[rows[i]], result.lhs.longValue());
Assert.assertEquals(VALUES[rows[i]], result.rhs);
}
}
private void clearBufferForPositions(int offset, int... positions)
{
for (int position : positions) {