Fix for latest agg to handle nulls in time column. Also adding optimization for dictionary encoded string columns (#14911)

* Fix for latest agg to handle nulls in time column. Also adding optimization for dictionary encoded string columns

* One minor fix

* Adding more tests for the new class

* Changing the init to a putInt
This commit is contained in:
Soumyava 2023-09-13 17:37:26 -07:00 committed by GitHub
parent bf99d2c7b2
commit 5c42ac8c4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 261 additions and 16 deletions

View File

@ -57,7 +57,7 @@ public class SingleStringFirstDimensionVectorAggregator implements VectorAggrega
position + NumericFirstVectorAggregator.NULL_OFFSET, position + NumericFirstVectorAggregator.NULL_OFFSET,
useDefault ? NullHandling.IS_NOT_NULL_BYTE : NullHandling.IS_NULL_BYTE useDefault ? NullHandling.IS_NOT_NULL_BYTE : NullHandling.IS_NULL_BYTE
); );
buf.putLong(position + NumericFirstVectorAggregator.VALUE_OFFSET, 0); buf.putInt(position + NumericFirstVectorAggregator.VALUE_OFFSET, 0);
} }
@Override @Override

View File

@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.last;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.SerializablePairLongString;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import org.apache.druid.segment.vector.VectorValueSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
/**
 * Vectorized "string last" aggregator that reads its values from a
 * {@link SingleValueDimensionVectorSelector}.
 *
 * <p>Instead of materializing strings while aggregating, this aggregator stores the selector's row id
 * (an {@code int}) of the row with the greatest timestamp seen so far and only resolves it to a string,
 * via {@link SingleValueDimensionVectorSelector#lookupName(int)}, in {@link #get(ByteBuffer, int)}.
 *
 * <p>State written at each aggregation position:
 * <ul>
 *   <li>{@code position}: the greatest timestamp seen so far ({@code long}), initially
 *       {@link Long#MIN_VALUE};</li>
 *   <li>{@code position + NumericLastVectorAggregator.NULL_OFFSET}: a null-marker byte;</li>
 *   <li>{@code position + NumericLastVectorAggregator.VALUE_OFFSET}: the stored row id ({@code int}).</li>
 * </ul>
 *
 * <p>Rows whose timestamp is flagged in the time selector's null vector are ignored.
 */
public class SingleStringLastDimensionVectorAggregator implements VectorAggregator
{
  private final VectorValueSelector timeSelector;
  private final SingleValueDimensionVectorSelector valueDimensionVectorSelector;
  private final int maxStringBytes;
  private final boolean useDefault = NullHandling.replaceWithDefault();

  /**
   * @param timeSelector                 selector supplying the timestamps and their null flags
   * @param valueDimensionVectorSelector selector supplying one row id per row and the id-to-string lookup
   * @param maxStringBytes               limit passed to {@link StringUtils#chop} when the result is built
   */
  public SingleStringLastDimensionVectorAggregator(
      VectorValueSelector timeSelector,
      SingleValueDimensionVectorSelector valueDimensionVectorSelector,
      int maxStringBytes
  )
  {
    this.timeSelector = timeSelector;
    this.valueDimensionVectorSelector = valueDimensionVectorSelector;
    this.maxStringBytes = maxStringBytes;
  }

  /**
   * Resets the slot at {@code position}: timestamp {@link Long#MIN_VALUE}, row id 0, and a null marker that
   * depends on {@link NullHandling#replaceWithDefault()}.
   */
  @Override
  public void init(ByteBuffer buf, int position)
  {
    buf.putLong(position, Long.MIN_VALUE);
    buf.put(
        position + NumericLastVectorAggregator.NULL_OFFSET,
        useDefault ? NullHandling.IS_NOT_NULL_BYTE : NullHandling.IS_NULL_BYTE
    );
    buf.putInt(position + NumericLastVectorAggregator.VALUE_OFFSET, 0);
  }

  /**
   * Folds rows {@code [startRow, endRow)} of the current vector into the single slot at {@code position}.
   *
   * <p>Rows are visited from the end of the range backwards and a row only replaces the stored state when its
   * timestamp is strictly greater, so among rows sharing the greatest timestamp the one with the highest
   * index is kept.
   */
  @Override
  public void aggregate(ByteBuffer buf, int position, int startRow, int endRow)
  {
    final long[] timeVector = timeSelector.getLongVector();
    final boolean[] nullTimeVector = timeSelector.getNullVector();
    final int[] valueVector = valueDimensionVectorSelector.getRowVector();

    long lastTime = buf.getLong(position);
    for (int index = endRow - 1; index >= startRow; index--) {
      if (nullTimeVector != null && nullTimeVector[index]) {
        continue;
      }
      final long time = timeVector[index];
      if (time > lastTime) {
        lastTime = time;
        writeLatest(buf, position, time, valueVector[index]);
      }
    }
  }

  /**
   * Folds {@code numRows} rows into their individual slots. Entry {@code i} targets the slot at
   * {@code positions[i] + positionOffset} and reads vector row {@code rows[i]}, or row {@code i} when
   * {@code rows} is null.
   */
  @Override
  public void aggregate(ByteBuffer buf, int numRows, int[] positions, @Nullable int[] rows, int positionOffset)
  {
    final long[] timeVector = timeSelector.getLongVector();
    final boolean[] nullTimeVector = timeSelector.getNullVector();
    final int[] values = valueDimensionVectorSelector.getRowVector();

    for (int i = numRows - 1; i >= 0; i--) {
      final int row = rows == null ? i : rows[i];
      // The null flag must be read at the same index as the timestamp and value, i.e. the resolved row,
      // not the loop counter; the two differ whenever a rows mapping is supplied.
      if (nullTimeVector != null && nullTimeVector[row]) {
        continue;
      }
      final int position = positions[i] + positionOffset;
      final long time = timeVector[row];
      if (time > buf.getLong(position)) {
        writeLatest(buf, position, time, values[row]);
      }
    }
  }

  /**
   * Builds the result for the slot at {@code position}: the stored timestamp paired with the string looked up
   * for the stored row id, passed through {@link StringUtils#chop} with {@code maxStringBytes}.
   */
  @Nullable
  @Override
  public Object get(ByteBuffer buf, int position)
  {
    final int index = buf.getInt(position + NumericLastVectorAggregator.VALUE_OFFSET);
    final long lastTime = buf.getLong(position);
    final String strValue = valueDimensionVectorSelector.lookupName(index);
    return new SerializablePairLongString(lastTime, StringUtils.chop(strValue, maxStringBytes));
  }

  @Override
  public void close()
  {
    // nothing to close
  }

  /**
   * Overwrites the slot at {@code position} with a new timestamp and row id and marks it as not null.
   */
  private static void writeLatest(ByteBuffer buf, int position, long time, int dictionaryId)
  {
    buf.putLong(position, time);
    buf.put(position + NumericLastVectorAggregator.NULL_OFFSET, NullHandling.IS_NOT_NULL_BYTE);
    buf.putInt(position + NumericLastVectorAggregator.VALUE_OFFSET, dictionaryId);
  }
}

View File

@ -35,6 +35,7 @@ import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory; import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory;
import org.apache.druid.query.aggregation.first.StringFirstLastUtils; import org.apache.druid.query.aggregation.first.StringFirstLastUtils;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.segment.BaseObjectColumnValueSelector; import org.apache.druid.segment.BaseObjectColumnValueSelector;
import org.apache.druid.segment.ColumnInspector; import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnSelectorFactory;
@ -43,6 +44,8 @@ import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types; import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorObjectSelector; import org.apache.druid.segment.vector.VectorObjectSelector;
import org.apache.druid.segment.vector.VectorValueSelector; import org.apache.druid.segment.vector.VectorValueSelector;
@ -160,6 +163,7 @@ public class StringLastAggregatorFactory extends AggregatorFactory
final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName); final ColumnCapabilities capabilities = selectorFactory.getColumnCapabilities(fieldName);
VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn); VectorValueSelector timeSelector = selectorFactory.makeValueSelector(timeColumn);
if (Types.isNumeric(capabilities)) { if (Types.isNumeric(capabilities)) {
VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName); VectorValueSelector valueSelector = selectorFactory.makeValueSelector(fieldName);
VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject( VectorObjectSelector objectSelector = ExpressionVectorSelectors.castValueSelectorToObject(
@ -171,6 +175,18 @@ public class StringLastAggregatorFactory extends AggregatorFactory
); );
return new StringLastVectorAggregator(timeSelector, objectSelector, maxStringBytes); return new StringLastVectorAggregator(timeSelector, objectSelector, maxStringBytes);
} }
if (capabilities != null) {
if (capabilities.is(ValueType.STRING) && capabilities.isDictionaryEncoded().isTrue()) {
if (!capabilities.hasMultipleValues().isTrue()) {
SingleValueDimensionVectorSelector sSelector = selectorFactory.makeSingleValueDimensionSelector(
DefaultDimensionSpec.of(
fieldName));
return new SingleStringLastDimensionVectorAggregator(timeSelector, sSelector, maxStringBytes);
}
}
}
VectorObjectSelector vSelector = selectorFactory.makeObjectSelector(fieldName); VectorObjectSelector vSelector = selectorFactory.makeObjectSelector(fieldName);
if (capabilities != null) { if (capabilities != null) {
return new StringLastVectorAggregator(timeSelector, vSelector, maxStringBytes); return new StringLastVectorAggregator(timeSelector, vSelector, maxStringBytes);
@ -296,9 +312,9 @@ public class StringLastAggregatorFactory extends AggregatorFactory
} }
StringLastAggregatorFactory that = (StringLastAggregatorFactory) o; StringLastAggregatorFactory that = (StringLastAggregatorFactory) o;
return maxStringBytes == that.maxStringBytes && return maxStringBytes == that.maxStringBytes &&
Objects.equals(fieldName, that.fieldName) && Objects.equals(fieldName, that.fieldName) &&
Objects.equals(timeColumn, that.timeColumn) && Objects.equals(timeColumn, that.timeColumn) &&
Objects.equals(name, that.name); Objects.equals(name, that.name);
} }
@Override @Override

View File

@ -64,8 +64,9 @@ public class StringLastVectorAggregator implements VectorAggregator
if (timeSelector == null) { if (timeSelector == null) {
return; return;
} }
long[] times = timeSelector.getLongVector(); final long[] times = timeSelector.getLongVector();
Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector(); final boolean[] nullTimeVector = timeSelector.getNullVector();
final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
lastTime = buf.getLong(position); lastTime = buf.getLong(position);
int index; int index;
@ -76,6 +77,9 @@ public class StringLastVectorAggregator implements VectorAggregator
if (times[i] <= lastTime) { if (times[i] <= lastTime) {
continue; continue;
} }
if (nullTimeVector != null && nullTimeVector[i]) {
continue;
}
index = i; index = i;
final boolean foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeStrings[index]); final boolean foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(objectsWhichMightBeStrings[index]);
if (foldNeeded) { if (foldNeeded) {
@ -127,22 +131,24 @@ public class StringLastVectorAggregator implements VectorAggregator
if (timeSelector == null) { if (timeSelector == null) {
return; return;
} }
long[] timeVector = timeSelector.getLongVector(); final long[] timeVector = timeSelector.getLongVector();
Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector(); final boolean[] nullTimeVector = timeSelector.getNullVector();
final Object[] objectsWhichMightBeStrings = valueSelector.getObjectVector();
// iterate once over the object vector to find first non null element and // iterate once over the object vector to find first non null element and
// determine if the type is Pair or not // determine if the type is Pair or not
boolean foldNeeded = false; boolean foldNeeded = false;
for (Object obj : objectsWhichMightBeStrings) { for (Object obj : objectsWhichMightBeStrings) {
if (obj == null) { if (obj != null) {
continue;
} else {
foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(obj); foldNeeded = StringFirstLastUtils.objectNeedsFoldCheck(obj);
break; break;
} }
} }
for (int i = 0; i < numRows; i++) { for (int i = 0; i < numRows; i++) {
if (nullTimeVector != null && nullTimeVector[i]) {
continue;
}
int position = positions[i] + positionOffset; int position = positions[i] + positionOffset;
int row = rows == null ? i : rows[i]; int row = rows == null ? i : rows[i];
long lastTime = buf.getLong(position); long lastTime = buf.getLong(position);

View File

@ -23,7 +23,9 @@ import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.aggregation.SerializablePairLongString; import org.apache.druid.query.aggregation.SerializablePairLongString;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.IdLookup;
import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl; import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
@ -49,11 +51,13 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
{ {
private static final double EPSILON = 1e-5; private static final double EPSILON = 1e-5;
private static final String[] VALUES = new String[]{"a", "b", null, "c"}; private static final String[] VALUES = new String[]{"a", "b", null, "c"};
private static final int[] DICT_VALUES = new int[]{1, 2, 0, 3};
private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L}; private static final long[] LONG_VALUES = new long[]{1L, 2L, 3L, 4L};
private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"}; private static final String[] STRING_VALUES = new String[]{"1", "2", "3", "4"};
private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f}; private static final float[] FLOAT_VALUES = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0}; private static final double[] DOUBLE_VALUES = new double[]{1.0, 2.0, 3.0, 4.0};
private static final boolean[] NULLS = new boolean[]{false, false, true, false}; private static final boolean[] NULLS = new boolean[]{false, false, true, false};
private static final boolean[] NULLS1 = new boolean[]{false, false};
private static final String NAME = "NAME"; private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME"; private static final String FIELD_NAME = "FIELD_NAME";
private static final String FIELD_NAME_LONG = "LONG_NAME"; private static final String FIELD_NAME_LONG = "LONG_NAME";
@ -74,6 +78,7 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
private StringLastAggregatorFactory stringLastAggregatorFactory; private StringLastAggregatorFactory stringLastAggregatorFactory;
private StringLastAggregatorFactory stringLastAggregatorFactory1; private StringLastAggregatorFactory stringLastAggregatorFactory1;
private SingleStringLastDimensionVectorAggregator targetSingleDim;
private VectorColumnSelectorFactory selectorFactory; private VectorColumnSelectorFactory selectorFactory;
@ -96,7 +101,7 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
@Override @Override
public boolean[] getNullVector() public boolean[] getNullVector()
{ {
return NULLS; return null;
} }
}; };
nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset( nonStringValueSelector = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
@ -163,9 +168,9 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
} }
}; };
BaseLongVectorValueSelector timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset( BaseLongVectorValueSelector timeSelectorForPairs = new BaseLongVectorValueSelector(new NoFilterVectorOffset(
times.length, timesSame.length,
0, 0,
times.length timesSame.length
)) ))
{ {
@Override @Override
@ -178,7 +183,7 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
@Override @Override
public boolean[] getNullVector() public boolean[] getNullVector()
{ {
return new boolean[0]; return NULLS1;
} }
}; };
VectorObjectSelector selectorForPairs = new VectorObjectSelector() VectorObjectSelector selectorForPairs = new VectorObjectSelector()
@ -212,7 +217,61 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
@Override @Override
public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec) public SingleValueDimensionVectorSelector makeSingleValueDimensionSelector(DimensionSpec dimensionSpec)
{ {
return null; return new SingleValueDimensionVectorSelector()
{
@Override
public int[] getRowVector()
{
return DICT_VALUES;
}
@Override
public int getValueCardinality()
{
return DICT_VALUES.length;
}
@Nullable
@Override
public String lookupName(int id)
{
switch (id) {
case 1:
return "a";
case 2:
return "b";
case 3:
return "c";
default:
return null;
}
}
@Override
public boolean nameLookupPossibleInAdvance()
{
return false;
}
@Nullable
@Override
public IdLookup idLookup()
{
return null;
}
@Override
public int getMaxVectorSize()
{
return DICT_VALUES.length;
}
@Override
public int getCurrentVectorSize()
{
return DICT_VALUES.length;
}
};
} }
@Override @Override
@ -257,6 +316,8 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
target = new StringLastVectorAggregator(timeSelector, selector, 10); target = new StringLastVectorAggregator(timeSelector, selector, 10);
targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10); targetWithPairs = new StringLastVectorAggregator(timeSelectorForPairs, selectorForPairs, 10);
targetSingleDim = new SingleStringLastDimensionVectorAggregator(timeSelector, selectorFactory.makeSingleValueDimensionSelector(
DefaultDimensionSpec.of(FIELD_NAME)), 10);
clearBufferForPositions(0, 0); clearBufferForPositions(0, 0);
@ -361,6 +422,44 @@ public class StringLastVectorAggregatorTest extends InitializedNullHandlingTest
} }
} }
@Test
public void aggregateSingleDim()
{
  // Aggregating the whole vector into one slot must yield the entry at the final row.
  final int lastRow = VALUES.length - 1;
  targetSingleDim.aggregate(buf, 0, 0, VALUES.length);
  final Pair<Long, String> latest = (Pair<Long, String>) targetSingleDim.get(buf, 0);
  Assert.assertEquals(times[lastRow], latest.lhs.longValue());
  Assert.assertEquals(VALUES[lastRow], latest.rhs);
}
@Test
public void aggregateBatchWithoutRowsSingleDim()
{
  // With no rows mapping, entry i of the batch reads vector row i.
  final int positionOffset = 2;
  final int[] positions = {0, 43, 70};
  clearBufferForPositions(positionOffset, positions);
  targetSingleDim.aggregate(buf, 3, positions, null, positionOffset);
  int row = 0;
  for (int basePosition : positions) {
    final Pair<Long, String> latest = (Pair<Long, String>) targetSingleDim.get(buf, basePosition + positionOffset);
    Assert.assertEquals(times[row], latest.lhs.longValue());
    Assert.assertEquals(VALUES[row], latest.rhs);
    row++;
  }
}
@Test
public void aggregateBatchWithRowsSingleDim()
{
  // With a rows mapping, entry i of the batch reads vector row rows[i].
  final int positionOffset = 2;
  final int[] rows = {3, 2, 0};
  final int[] positions = {0, 43, 70};
  clearBufferForPositions(positionOffset, positions);
  targetSingleDim.aggregate(buf, 3, positions, rows, positionOffset);
  for (int slot = 0; slot < positions.length; slot++) {
    final int sourceRow = rows[slot];
    final Pair<Long, String> latest =
        (Pair<Long, String>) targetSingleDim.get(buf, positions[slot] + positionOffset);
    Assert.assertEquals(times[sourceRow], latest.lhs.longValue());
    Assert.assertEquals(VALUES[sourceRow], latest.rhs);
  }
}
private void clearBufferForPositions(int offset, int... positions) private void clearBufferForPositions(int offset, int... positions)
{ {
for (int position : positions) { for (int position : positions) {