Enabling aggregateMultipleValues in all StringAnyAggregators (#15434)

* Enabling aggregateMultipleValues in all StringAnyAggregators

* Adding more tests

* More validation

* fix warning

* updating asserts in decoupled mode

* fix intellij inspection

* Addressing comments

* Addressing comments

* Adding early validations and make aggregate consistent across all

* fixing tests

* fixing tests

* Update docs/querying/sql-aggregations.md

Co-authored-by: Clint Wylie <cjwylie@gmail.com>

* fixing static check

---------

Co-authored-by: Clint Wylie <cjwylie@gmail.com>
Pranav 2023-11-29 14:32:49 -08:00 committed by GitHub
parent 64fcb32bcf
commit 93cd638645
20 changed files with 493 additions and 145 deletions

View File

@ -90,7 +90,7 @@ In the aggregation functions supported by Druid, only `COUNT`, `ARRAY_AGG`, and
|`EARLIEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the earliest value of `expr`.<br />The earliest value of `expr` is taken from the row with the overall earliest non-null value of `timestampExpr`. <br />If the earliest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)| |`EARLIEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the earliest value of `expr`.<br />The earliest value of `expr` is taken from the row with the overall earliest non-null value of `timestampExpr`. <br />If the earliest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
|`LATEST(expr, [maxBytesPerValue])`|Returns the latest value of `expr`<br />The `expr` must come from a relation with a timestamp column (like `__time` in a Druid datasource) and the "latest" is taken from the row with the overall latest non-null value of the timestamp column.<br />If the latest non-null value of the timestamp column appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)| |`LATEST(expr, [maxBytesPerValue])`|Returns the latest value of `expr`<br />The `expr` must come from a relation with a timestamp column (like `__time` in a Druid datasource) and the "latest" is taken from the row with the overall latest non-null value of the timestamp column.<br />If the latest non-null value of the timestamp column appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
|`LATEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the latest value of `expr`.<br />The latest value of `expr` is taken from the row with the overall latest non-null value of `timestampExpr`.<br />If the overall latest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)| |`LATEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the latest value of `expr`.<br />The latest value of `expr` is taken from the row with the overall latest non-null value of `timestampExpr`.<br />If the overall latest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
|`ANY_VALUE(expr, [maxBytesPerValue])`|Returns any value of `expr` including null. This aggregator can simplify and optimize the performance by returning the first encountered value (including `null`).<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)| |`ANY_VALUE(expr, [maxBytesPerValue, [aggregateMultipleValues]])`|Returns any value of `expr` including null. This aggregator can simplify and optimize the performance by returning the first encountered value (including `null`).<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue` is omitted; it defaults to `1024`. `aggregateMultipleValues` is an optional boolean flag that controls how a [multi-value dimension](./multi-value-dimensions.md) is aggregated. It defaults to `true`, in which case the function returns the stringified array for a multi-value dimension. When set to `false`, the function returns the first value instead. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
|`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.|N/A| |`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.|N/A|
|`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.|`null`| |`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.|`null`|
|`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results will be based on the default for the element type.|`null`| |`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results will be based on the default for the element type.|`null`|

View File

@ -50,7 +50,7 @@ Calculates the arc cosine of a numeric expression.
## ANY_VALUE ## ANY_VALUE
`ANY_VALUE(expr, [maxBytesPerValue])` `ANY_VALUE(expr, [maxBytesPerValue, [aggregateMultipleValues]])`
**Function type:** [Aggregation](sql-aggregations.md) **Function type:** [Aggregation](sql-aggregations.md)

View File

@ -24,19 +24,23 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.segment.BaseObjectColumnValueSelector; import org.apache.druid.segment.BaseObjectColumnValueSelector;
import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.DimensionHandlerUtils;
import java.util.List;
public class StringAnyAggregator implements Aggregator public class StringAnyAggregator implements Aggregator
{ {
private final BaseObjectColumnValueSelector valueSelector; private final BaseObjectColumnValueSelector valueSelector;
private final int maxStringBytes; private final int maxStringBytes;
private boolean isFound; private boolean isFound;
private String foundValue; private String foundValue;
private final boolean aggregateMultipleValues;
public StringAnyAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes) public StringAnyAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes, boolean aggregateMultipleValues)
{ {
this.valueSelector = valueSelector; this.valueSelector = valueSelector;
this.maxStringBytes = maxStringBytes; this.maxStringBytes = maxStringBytes;
this.foundValue = null; this.foundValue = null;
this.isFound = false; this.isFound = false;
this.aggregateMultipleValues = aggregateMultipleValues;
} }
@Override @Override
@ -44,18 +48,36 @@ public class StringAnyAggregator implements Aggregator
{ {
if (!isFound) { if (!isFound) {
final Object object = valueSelector.getObject(); final Object object = valueSelector.getObject();
foundValue = DimensionHandlerUtils.convertObjectToString(object); foundValue = StringUtils.fastLooseChop(readValue(object), maxStringBytes);
if (foundValue != null && foundValue.length() > maxStringBytes) {
foundValue = foundValue.substring(0, maxStringBytes);
}
isFound = true; isFound = true;
} }
} }
private String readValue(final Object object)
{
if (object == null) {
return null;
}
if (object instanceof List) {
List<Object> objectList = (List) object;
if (objectList.size() == 0) {
return null;
}
if (objectList.size() == 1) {
return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
}
if (aggregateMultipleValues) {
return DimensionHandlerUtils.convertObjectToString(objectList);
}
return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
}
return DimensionHandlerUtils.convertObjectToString(object);
}
@Override @Override
public Object get() public Object get()
{ {
return StringUtils.chop(foundValue, maxStringBytes); return foundValue;
} }
@Override @Override
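
The branching that the new `readValue` helper adds to `StringAnyAggregator` can be sketched in isolation. This is a minimal, JDK-only illustration rather than Druid's code: `AnyValueSketch`, `toStringOrNull`, and `chop` are hypothetical names, `toStringOrNull` stands in for `DimensionHandlerUtils.convertObjectToString` (assumed here to fall back to `toString()` for a list), and a plain character-count `substring` stands in for `StringUtils.fastLooseChop`.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for the selection logic in the diff; not part of Druid.
class AnyValueSketch
{
  // Assumed stand-in for DimensionHandlerUtils.convertObjectToString.
  static String toStringOrNull(Object o)
  {
    return o == null ? null : o.toString();
  }

  // Picks the string to keep for one row, following the same branches as the diff.
  static String readValue(Object object, boolean aggregateMultipleValues)
  {
    if (object == null) {
      return null;
    }
    if (object instanceof List) {
      List<?> values = (List<?>) object;
      if (values.isEmpty()) {
        return null;
      }
      // A single-element list, or the flag set to false, yields the first element.
      if (values.size() == 1 || !aggregateMultipleValues) {
        return toStringOrNull(values.get(0));
      }
      // Otherwise the whole list is stringified, e.g. "[a, b, c]".
      return toStringOrNull(values);
    }
    return toStringOrNull(object);
  }

  // Assumed stand-in for the truncation step: keep at most maxChars characters.
  static String chop(String value, int maxChars)
  {
    return value == null || value.length() <= maxChars ? value : value.substring(0, maxChars);
  }

  public static void main(String[] args)
  {
    List<String> mvd = Arrays.asList("AAAA", "AAAAB", "AAAC");
    System.out.println(chop(readValue(mvd, true), 10));   // prints "[AAAA, AAA"
    System.out.println(chop(readValue(mvd, false), 10));  // prints "AAAA"
    System.out.println(readValue(Collections.emptyList(), true)); // prints "null"
  }
}
```

The first two outputs match the values asserted in `testMvdStringAnyAggregator` and `testMvdStringAnyAggregatorWithAggregateMultipleToFalse` in this commit, which use the same list and a limit of 10.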

View File

@ -48,13 +48,15 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
private final String fieldName; private final String fieldName;
private final String name; private final String name;
protected final int maxStringBytes; private final int maxStringBytes;
private final boolean aggregateMultipleValues;
@JsonCreator @JsonCreator
public StringAnyAggregatorFactory( public StringAnyAggregatorFactory(
@JsonProperty("name") String name, @JsonProperty("name") String name,
@JsonProperty("fieldName") final String fieldName, @JsonProperty("fieldName") final String fieldName,
@JsonProperty("maxStringBytes") Integer maxStringBytes @JsonProperty("maxStringBytes") Integer maxStringBytes,
@JsonProperty("aggregateMultipleValues") @Nullable final Boolean aggregateMultipleValues
) )
{ {
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name"); Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
@ -67,18 +69,19 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
this.maxStringBytes = maxStringBytes == null this.maxStringBytes = maxStringBytes == null
? StringFirstAggregatorFactory.DEFAULT_MAX_STRING_SIZE ? StringFirstAggregatorFactory.DEFAULT_MAX_STRING_SIZE
: maxStringBytes; : maxStringBytes;
this.aggregateMultipleValues = aggregateMultipleValues == null ? true : aggregateMultipleValues;
} }
@Override @Override
public Aggregator factorize(ColumnSelectorFactory metricFactory) public Aggregator factorize(ColumnSelectorFactory metricFactory)
{ {
return new StringAnyAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes); return new StringAnyAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes, aggregateMultipleValues);
} }
@Override @Override
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
{ {
return new StringAnyBufferAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes); return new StringAnyBufferAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes, aggregateMultipleValues);
} }
@Override @Override
@ -90,13 +93,15 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
return new StringAnyVectorAggregator( return new StringAnyVectorAggregator(
null, null,
selectorFactory.makeMultiValueDimensionSelector(DefaultDimensionSpec.of(fieldName)), selectorFactory.makeMultiValueDimensionSelector(DefaultDimensionSpec.of(fieldName)),
maxStringBytes maxStringBytes,
aggregateMultipleValues
); );
} else { } else {
return new StringAnyVectorAggregator( return new StringAnyVectorAggregator(
selectorFactory.makeSingleValueDimensionSelector(DefaultDimensionSpec.of(fieldName)), selectorFactory.makeSingleValueDimensionSelector(DefaultDimensionSpec.of(fieldName)),
null, null,
maxStringBytes maxStringBytes,
aggregateMultipleValues
); );
} }
} }
@ -122,7 +127,7 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
@Override @Override
public AggregatorFactory getCombiningFactory() public AggregatorFactory getCombiningFactory()
{ {
return new StringAnyAggregatorFactory(name, name, maxStringBytes); return new StringAnyAggregatorFactory(name, name, maxStringBytes, aggregateMultipleValues);
} }
@Override @Override
@ -155,6 +160,11 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
{ {
return maxStringBytes; return maxStringBytes;
} }
@JsonProperty
public boolean getAggregateMultipleValues()
{
return aggregateMultipleValues;
}
@Override @Override
public List<String> requiredFields() public List<String> requiredFields()
@ -192,7 +202,7 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
@Override @Override
public AggregatorFactory withName(String newName) public AggregatorFactory withName(String newName)
{ {
return new StringAnyAggregatorFactory(newName, getFieldName(), getMaxStringBytes()); return new StringAnyAggregatorFactory(newName, getFieldName(), getMaxStringBytes(), getAggregateMultipleValues());
} }
@Override @Override
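
The factory changes above do two things: a missing (`null`) `aggregateMultipleValues` is treated as `true`, and the resolved flag is carried through every copy of the factory (`getCombiningFactory`, `withName`). A minimal sketch of that pattern follows. `AnySpecSketch` is a hypothetical, simplified class and not Druid's factory, and the `1024` default for the byte limit is taken from the documentation table in this commit.

```java
// Hypothetical, simplified stand-in for the factory; not Druid's class.
final class AnySpecSketch
{
  private final String name;
  private final int maxStringBytes;
  private final boolean aggregateMultipleValues;

  AnySpecSketch(String name, Integer maxStringBytes, Boolean aggregateMultipleValues)
  {
    this.name = name;
    // Nullable boxed arguments model optional JSON properties.
    this.maxStringBytes = maxStringBytes == null ? 1024 : maxStringBytes;
    this.aggregateMultipleValues = aggregateMultipleValues == null || aggregateMultipleValues;
  }

  // Copies must pass the resolved flag along, otherwise a renamed spec
  // would silently revert to the default behavior.
  AnySpecSketch withName(String newName)
  {
    return new AnySpecSketch(newName, maxStringBytes, aggregateMultipleValues);
  }

  String getName()
  {
    return name;
  }

  int getMaxStringBytes()
  {
    return maxStringBytes;
  }

  boolean isAggregateMultipleValues()
  {
    return aggregateMultipleValues;
  }

  public static void main(String[] args)
  {
    AnySpecSketch defaults = new AnySpecSketch("a", null, null);
    System.out.println(defaults.isAggregateMultipleValues()); // prints "true"

    AnySpecSketch renamed = new AnySpecSketch("a", 10, false).withName("b");
    System.out.println(renamed.isAggregateMultipleValues());  // prints "false"
  }
}
```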

View File

@ -25,6 +25,7 @@ import org.apache.druid.segment.BaseObjectColumnValueSelector;
import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.DimensionHandlerUtils;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.List;
public class StringAnyBufferAggregator implements BufferAggregator public class StringAnyBufferAggregator implements BufferAggregator
{ {
@ -34,11 +35,13 @@ public class StringAnyBufferAggregator implements BufferAggregator
private final BaseObjectColumnValueSelector valueSelector; private final BaseObjectColumnValueSelector valueSelector;
private final int maxStringBytes; private final int maxStringBytes;
private final boolean aggregateMultipleValues;
public StringAnyBufferAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes) public StringAnyBufferAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes, boolean aggregateMultipleValues)
{ {
this.valueSelector = valueSelector; this.valueSelector = valueSelector;
this.maxStringBytes = maxStringBytes; this.maxStringBytes = maxStringBytes;
this.aggregateMultipleValues = aggregateMultipleValues;
} }
@Override @Override
@ -51,8 +54,7 @@ public class StringAnyBufferAggregator implements BufferAggregator
public void aggregate(ByteBuffer buf, int position) public void aggregate(ByteBuffer buf, int position)
{ {
if (buf.getInt(position) == NOT_FOUND_FLAG_VALUE) { if (buf.getInt(position) == NOT_FOUND_FLAG_VALUE) {
final Object object = valueSelector.getObject(); String foundValue = readValue(valueSelector.getObject());
String foundValue = DimensionHandlerUtils.convertObjectToString(object);
if (foundValue != null) { if (foundValue != null) {
ByteBuffer mutationBuffer = buf.duplicate(); ByteBuffer mutationBuffer = buf.duplicate();
mutationBuffer.position(position + FOUND_VALUE_OFFSET); mutationBuffer.position(position + FOUND_VALUE_OFFSET);
@ -65,6 +67,27 @@ public class StringAnyBufferAggregator implements BufferAggregator
} }
} }
private String readValue(Object object)
{
if (object == null) {
return null;
}
if (object instanceof List) {
List<Object> objectList = (List) object;
if (objectList.size() == 0) {
return null;
}
if (objectList.size() == 1) {
return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
}
if (aggregateMultipleValues) {
return DimensionHandlerUtils.convertObjectToString(objectList);
}
return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
}
return DimensionHandlerUtils.convertObjectToString(object);
}
@Override @Override
public Object get(ByteBuffer buf, int position) public Object get(ByteBuffer buf, int position)
{ {

View File

@ -23,12 +23,15 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.VectorAggregator; import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.DimensionHandlerUtils;
import org.apache.druid.segment.data.IndexedInts; import org.apache.druid.segment.data.IndexedInts;
import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
public class StringAnyVectorAggregator implements VectorAggregator public class StringAnyVectorAggregator implements VectorAggregator
{ {
@ -43,11 +46,13 @@ public class StringAnyVectorAggregator implements VectorAggregator
@Nullable @Nullable
private final MultiValueDimensionVectorSelector multiValueSelector; private final MultiValueDimensionVectorSelector multiValueSelector;
private final int maxStringBytes; private final int maxStringBytes;
private final boolean aggregateMultipleValues;
public StringAnyVectorAggregator( public StringAnyVectorAggregator(
SingleValueDimensionVectorSelector singleValueSelector, SingleValueDimensionVectorSelector singleValueSelector,
MultiValueDimensionVectorSelector multiValueSelector, MultiValueDimensionVectorSelector multiValueSelector,
int maxStringBytes int maxStringBytes,
final boolean aggregateMultipleValues
) )
{ {
Preconditions.checkState( Preconditions.checkState(
@ -61,6 +66,7 @@ public class StringAnyVectorAggregator implements VectorAggregator
this.multiValueSelector = multiValueSelector; this.multiValueSelector = multiValueSelector;
this.singleValueSelector = singleValueSelector; this.singleValueSelector = singleValueSelector;
this.maxStringBytes = maxStringBytes; this.maxStringBytes = maxStringBytes;
this.aggregateMultipleValues = aggregateMultipleValues;
} }
@Override @Override
@ -78,7 +84,7 @@ public class StringAnyVectorAggregator implements VectorAggregator
if (startRow < rows.length) { if (startRow < rows.length) {
IndexedInts row = rows[startRow]; IndexedInts row = rows[startRow];
@Nullable @Nullable
String foundValue = row.size() == 0 ? null : multiValueSelector.lookupName(row.get(0)); String foundValue = readValue(row);
putValue(buf, position, foundValue); putValue(buf, position, foundValue);
} }
} else if (singleValueSelector != null) { } else if (singleValueSelector != null) {
@ -93,6 +99,24 @@ public class StringAnyVectorAggregator implements VectorAggregator
} }
} }
private String readValue(IndexedInts row)
{
if (row.size() == 0) {
return null;
}
if (aggregateMultipleValues) {
if (row.size() == 1) {
return multiValueSelector.lookupName(row.get(0));
}
List<String> arrayList = new ArrayList<>();
row.forEach(rowIndex -> {
arrayList.add(multiValueSelector.lookupName(rowIndex));
});
return DimensionHandlerUtils.convertObjectToString(arrayList);
}
return multiValueSelector.lookupName(row.get(0));
}
@Override @Override
public void aggregate( public void aggregate(
ByteBuffer buf, ByteBuffer buf,
@ -142,4 +166,9 @@ public class StringAnyVectorAggregator implements VectorAggregator
buf.putInt(position, FOUND_AND_NULL_FLAG_VALUE); buf.putInt(position, FOUND_AND_NULL_FLAG_VALUE);
} }
} }
public boolean isAggregateMultipleValues()
{
return aggregateMultipleValues;
}
} }
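
The vectorized path differs from the other two aggregators in that a row arrives as dictionary ids rather than as objects, so each id must be looked up before the values are joined. The sketch below illustrates that step under stated assumptions: `VectorRowSketch` is a hypothetical name, an `int[]` stands in for `IndexedInts`, a `String[]` dictionary stands in for `multiValueSelector.lookupName`, and `List.toString()` stands in for `DimensionHandlerUtils.convertObjectToString`.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for the vector aggregator's row handling; not part of Druid.
class VectorRowSketch
{
  static String readValue(int[] row, String[] dictionary, boolean aggregateMultipleValues)
  {
    if (row.length == 0) {
      return null;
    }
    // One id, or the flag set to false: resolve only the first id.
    if (row.length == 1 || !aggregateMultipleValues) {
      return dictionary[row[0]];
    }
    // Otherwise resolve every id and stringify the resulting list.
    List<String> values = new ArrayList<>(row.length);
    for (int id : row) {
      values.add(dictionary[id]);
    }
    return values.toString();
  }

  public static void main(String[] args)
  {
    String[] dictionary = {"AAAA", "AAAAB", "AAAC"};
    System.out.println(readValue(new int[]{0, 1, 2}, dictionary, true));  // prints "[AAAA, AAAAB, AAAC]"
    System.out.println(readValue(new int[]{0, 1, 2}, dictionary, false)); // prints "AAAA"
  }
}
```

The sketch collapses the single-id and flag-off branches into one condition; the diff keeps them separate but returns the same value in both.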

View File

@ -148,7 +148,7 @@ public class AggregatorFactoryTest extends InitializedNullHandlingTest
// string aggregators // string aggregators
new StringFirstAggregatorFactory("stringFirst", "col", null, 1024), new StringFirstAggregatorFactory("stringFirst", "col", null, 1024),
new StringLastAggregatorFactory("stringLast", "col", null, 1024), new StringLastAggregatorFactory("stringLast", "col", null, 1024),
new StringAnyAggregatorFactory("stringAny", "col", 1024), new StringAnyAggregatorFactory("stringAny", "col", 1024, true),
// sketch aggs // sketch aggs
new CardinalityAggregatorFactory("cardinality", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false), new CardinalityAggregatorFactory("cardinality", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false),
new HyperUniquesAggregatorFactory("hyperUnique", "hyperunique"), new HyperUniquesAggregatorFactory("hyperUnique", "hyperunique"),
@ -307,7 +307,8 @@ public class AggregatorFactoryTest extends InitializedNullHandlingTest
// string aggregators // string aggregators
new StringFirstAggregatorFactory("col", "col", null, 1024), new StringFirstAggregatorFactory("col", "col", null, 1024),
new StringLastAggregatorFactory("col", "col", null, 1024), new StringLastAggregatorFactory("col", "col", null, 1024),
new StringAnyAggregatorFactory("col", "col", 1024), new StringAnyAggregatorFactory("col", "col", 1024, true),
new StringAnyAggregatorFactory("col", "col", 1024, false),
// sketch aggs // sketch aggs
new CardinalityAggregatorFactory("col", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false), new CardinalityAggregatorFactory("col", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false),
new HyperUniquesAggregatorFactory("col", "hyperunique"), new HyperUniquesAggregatorFactory("col", "hyperunique"),

View File

@ -49,7 +49,7 @@ public class StringAnyAggregationTest
@Before @Before
public void setup() public void setup()
{ {
stringAnyAggFactory = new StringAnyAggregatorFactory("billy", "nilly", MAX_STRING_SIZE); stringAnyAggFactory = new StringAnyAggregatorFactory("billy", "nilly", MAX_STRING_SIZE, true);
combiningAggFactory = stringAnyAggFactory.getCombiningFactory(); combiningAggFactory = stringAnyAggFactory.getCombiningFactory();
valueSelector = new TestObjectColumnSelector<>(strings); valueSelector = new TestObjectColumnSelector<>(strings);
objectSelector = new TestObjectColumnSelector<>(strings); objectSelector = new TestObjectColumnSelector<>(strings);

View File

@ -19,48 +19,44 @@
package org.apache.druid.query.aggregation.any; package org.apache.druid.query.aggregation.any;
import org.apache.druid.segment.ColumnInspector; import com.google.common.collect.Lists;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.TestObjectColumnSelector;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.DimensionSelector;
import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; import org.apache.druid.segment.vector.TestVectorColumnSelectorFactory;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory; import org.apache.druid.segment.virtual.FallbackVirtualColumnTest;
import org.apache.druid.testing.InitializedNullHandlingTest; import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import static org.mockito.ArgumentMatchers.any; import java.util.List;
@RunWith(MockitoJUnitRunner.class)
public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
{ {
private static final String NAME = "NAME"; private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME"; private static final String FIELD_NAME = "FIELD_NAME";
private static final int MAX_STRING_BYTES = 10; private static final int MAX_STRING_BYTES = 10;
@Mock private TestColumnSelectorFactory columnInspector;
private ColumnInspector columnInspector;
@Mock
private ColumnCapabilities capabilities; private ColumnCapabilities capabilities;
@Mock private TestVectorColumnSelectorFactory vectorSelectorFactory;
private VectorColumnSelectorFactory vectorSelectorFactory;
@Mock
private SingleValueDimensionVectorSelector singleValueDimensionVectorSelector;
@Mock
private MultiValueDimensionVectorSelector multiValueDimensionVectorSelector;
private StringAnyAggregatorFactory target; private StringAnyAggregatorFactory target;
@Before @Before
public void setUp() public void setUp()
{ {
Mockito.doReturn(capabilities).when(vectorSelectorFactory).getColumnCapabilities(FIELD_NAME); target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES, true);
Mockito.doReturn(ColumnCapabilities.Capable.UNKNOWN).when(capabilities).hasMultipleValues(); columnInspector = new TestColumnSelectorFactory();
target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES); vectorSelectorFactory = new TestVectorColumnSelectorFactory();
capabilities = ColumnCapabilitiesImpl.createDefault().setHasMultipleValues(true);
vectorSelectorFactory.addCapabilities(FIELD_NAME, capabilities);
vectorSelectorFactory.addMVDVS(FIELD_NAME, new FallbackVirtualColumnTest.SameMultiVectorSelector());
} }
@Test @Test
@ -72,10 +68,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test @Test
public void factorizeVectorWithoutCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector() public void factorizeVectorWithoutCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
{ {
Mockito.doReturn(null).when(vectorSelectorFactory).getColumnCapabilities(FIELD_NAME);
Mockito.doReturn(singleValueDimensionVectorSelector)
.when(vectorSelectorFactory)
.makeSingleValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory); StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator); Assert.assertNotNull(aggregator);
} }
@ -83,9 +75,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test @Test
public void factorizeVectorWithUnknownCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector() public void factorizeVectorWithUnknownCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
{ {
Mockito.doReturn(multiValueDimensionVectorSelector)
.when(vectorSelectorFactory)
.makeMultiValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory); StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator); Assert.assertNotNull(aggregator);
} }
@ -93,10 +82,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test @Test
public void factorizeVectorWithMultipleValuesCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector() public void factorizeVectorWithMultipleValuesCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
{ {
Mockito.doReturn(ColumnCapabilities.Capable.TRUE).when(capabilities).hasMultipleValues();
Mockito.doReturn(multiValueDimensionVectorSelector)
.when(vectorSelectorFactory)
.makeMultiValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory); StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator); Assert.assertNotNull(aggregator);
} }
@ -104,11 +89,86 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test @Test
public void factorizeVectorWithoutMultipleValuesCapabilitiesShouldReturnAggregatorWithSingleDimensionSelector() public void factorizeVectorWithoutMultipleValuesCapabilitiesShouldReturnAggregatorWithSingleDimensionSelector()
{ {
Mockito.doReturn(ColumnCapabilities.Capable.FALSE).when(capabilities).hasMultipleValues();
Mockito.doReturn(singleValueDimensionVectorSelector)
.when(vectorSelectorFactory)
.makeSingleValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory); StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator); Assert.assertNotNull(aggregator);
} }
@Test
public void testFactorize()
{
Aggregator res = target.factorize(new TestColumnSelectorFactory());
Assert.assertTrue(res instanceof StringAnyAggregator);
res.aggregate();
Assert.assertEquals(null, res.get());
StringAnyVectorAggregator vectorAggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertTrue(vectorAggregator.isAggregateMultipleValues());
}
@Test
public void testSvdStringAnyAggregator()
{
TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory();
Aggregator res = target.factorize(columnSelectorFactory);
Assert.assertTrue(res instanceof StringAnyAggregator);
columnSelectorFactory.moveSelectorCursorToNext();
res.aggregate();
Assert.assertEquals("CCCC", res.get());
}
@Test
public void testMvdStringAnyAggregator()
{
TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory();
Aggregator res = target.factorize(columnSelectorFactory);
Assert.assertTrue(res instanceof StringAnyAggregator);
columnSelectorFactory.moveSelectorCursorToNext();
columnSelectorFactory.moveSelectorCursorToNext();
res.aggregate();
Assert.assertEquals("[AAAA, AAA", res.get());
}
@Test
public void testMvdStringAnyAggregatorWithAggregateMultipleToFalse()
{
StringAnyAggregatorFactory target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES, false);
TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory();
Aggregator res = target.factorize(columnSelectorFactory);
Assert.assertTrue(res instanceof StringAnyAggregator);
columnSelectorFactory.moveSelectorCursorToNext();
columnSelectorFactory.moveSelectorCursorToNext();
res.aggregate();
// picks up first value in mvd list
Assert.assertEquals("AAAA", res.get());
}
static class TestColumnSelectorFactory implements ColumnSelectorFactory
{
List<String> mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
Integer maxStringBytes = 1024;
TestObjectColumnSelector<Object> objectColumnSelector = new TestObjectColumnSelector<>(mvds);
@Override
public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
}
@Override
public ColumnValueSelector<?> makeColumnValueSelector(String columnName)
{
return objectColumnSelector;
}
@Override
public ColumnCapabilities getColumnCapabilities(String columnName)
{
return ColumnCapabilitiesImpl.createDefault().setHasMultipleValues(true);
}
public void moveSelectorCursorToNext()
{
objectColumnSelector.increment();
}
}
} }


@ -19,15 +19,22 @@
package org.apache.druid.query.aggregation.any; package org.apache.druid.query.aggregation.any;
import com.google.common.collect.Lists;
import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.TestObjectColumnSelector; import org.apache.druid.query.aggregation.TestObjectColumnSelector;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
public class StringAnyBufferAggregatorTest public class StringAnyBufferAggregatorTest
{ {
StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
"billy", "billy", 1024, true
);
private void aggregateBuffer( private void aggregateBuffer(
TestObjectColumnSelector valueSelector, TestObjectColumnSelector valueSelector,
BufferAggregator agg, BufferAggregator agg,
@ -44,17 +51,14 @@ public class StringAnyBufferAggregatorTest
{ {
final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"}; final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"};
Integer maxStringBytes = 1024; int maxStringBytes = 1024;
TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings); TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);
StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
"billy", "billy", maxStringBytes
);
StringAnyBufferAggregator agg = new StringAnyBufferAggregator( StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector, objectColumnSelector,
maxStringBytes maxStringBytes,
true
); );
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize()); ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@ -75,17 +79,15 @@ public class StringAnyBufferAggregatorTest
public void testBufferAggregateWithFoldCheck() public void testBufferAggregateWithFoldCheck()
{ {
final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"}; final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"};
Integer maxStringBytes = 1024; int maxStringBytes = 1024;
TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings); TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);
StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
"billy", "billy", maxStringBytes
);
StringAnyBufferAggregator agg = new StringAnyBufferAggregator( StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector, objectColumnSelector,
maxStringBytes maxStringBytes,
true
); );
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize()); ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@ -108,17 +110,14 @@ public class StringAnyBufferAggregatorTest
{ {
final String[] strings = {"CCCC", "AAAA", "BBBB", null, "EEEE"}; final String[] strings = {"CCCC", "AAAA", "BBBB", null, "EEEE"};
Integer maxStringBytes = 1024; int maxStringBytes = 1024;
TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings); TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);
StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
"billy", "billy", maxStringBytes
);
StringAnyBufferAggregator agg = new StringAnyBufferAggregator( StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector, objectColumnSelector,
maxStringBytes maxStringBytes,
true
); );
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize()); ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@ -140,17 +139,13 @@ public class StringAnyBufferAggregatorTest
{ {
final String[] strings = {null, "CCCC", "AAAA", "BBBB", "EEEE"}; final String[] strings = {null, "CCCC", "AAAA", "BBBB", "EEEE"};
Integer maxStringBytes = 1024; int maxStringBytes = 1024;
TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings); TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);
StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
"billy", "billy", maxStringBytes
);
StringAnyBufferAggregator agg = new StringAnyBufferAggregator( StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector, objectColumnSelector,
maxStringBytes maxStringBytes, true
); );
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize()); ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@ -170,19 +165,15 @@ public class StringAnyBufferAggregatorTest
@Test @Test
public void testNonStringValue() public void testNonStringValue()
{ {
final Double[] doubles = {1.00, 2.00}; final Double[] doubles = {1.00, 2.00};
Integer maxStringBytes = 1024; int maxStringBytes = 1024;
TestObjectColumnSelector<Double> objectColumnSelector = new TestObjectColumnSelector<>(doubles); TestObjectColumnSelector<Double> objectColumnSelector = new TestObjectColumnSelector<>(doubles);
StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
"billy", "billy", maxStringBytes
);
StringAnyBufferAggregator agg = new StringAnyBufferAggregator( StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector, objectColumnSelector,
maxStringBytes maxStringBytes,
true
); );
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize()); ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@ -198,4 +189,77 @@ public class StringAnyBufferAggregatorTest
Assert.assertEquals("1.0", result); Assert.assertEquals("1.0", result);
} }
@Test
public void testMvds()
{
List<String> mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
int maxStringBytes = 1024;
TestObjectColumnSelector<Object> objectColumnSelector = new TestObjectColumnSelector<>(mvds);
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
maxStringBytes, true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize() * 2);
int position = 0;
int[] positions = new int[]{0, 1, 43, 100, 189};
Arrays.stream(positions).forEach(i -> agg.init(buf, i));
//noinspection ForLoopReplaceableByForEach
for (int i = 0; i < mvds.length; i++) {
aggregateBuffer(objectColumnSelector, agg, buf, positions[i]);
}
String result = ((String) agg.get(buf, position));
Assert.assertNull(result);
for (int i = 0; i < positions.length; i++) {
if (i == 2) {
Assert.assertEquals(mvd.toString(), agg.get(buf, positions[2]));
} else {
Assert.assertEquals(mvds[i], agg.get(buf, positions[i]));
}
}
}
@Test
public void testMvdsWithCustomAggregate()
{
List<String> mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
final int maxStringBytes = 1024;
TestObjectColumnSelector<Object> objectColumnSelector = new TestObjectColumnSelector<>(mvds);
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
maxStringBytes, false
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize() * 2);
int position = 0;
int[] positions = new int[]{0, 1, 43, 100, 189};
Arrays.stream(positions).forEach(i -> agg.init(buf, i));
//noinspection ForLoopReplaceableByForEach
for (int i = 0; i < mvds.length; i++) {
aggregateBuffer(objectColumnSelector, agg, buf, positions[i]);
}
String result = ((String) agg.get(buf, position));
Assert.assertNull(result);
for (int i = 0; i < positions.length; i++) {
if (i == 2) {
// takes first in case of mvds
Assert.assertEquals(mvd.get(0), agg.get(buf, positions[2]));
} else {
Assert.assertEquals(mvds[i], agg.get(buf, positions[i]));
}
}
}
} }
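The expectations in the tests above suggest the following behaviour for a multi-value row: with `aggregateMultipleValues` set to `true` the whole list is stored in its `toString()` form, and with `false` only the first element is kept. The sketch below is a stand-alone illustration of that rule, not the Druid implementation. The class `AnyValueMvdSketch` and the method `pick` are hypothetical names, and the character-based truncation is an assumption inferred from the `"[AAAA, AAA"` expectation in the factory test.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class AnyValueMvdSketch
{
  // Hypothetical helper: turns one row value into the string an "any" aggregator would keep.
  static String pick(Object row, int maxStringBytes, boolean aggregateMultipleValues)
  {
    if (row == null) {
      return null;
    }
    final String value;
    if (row instanceof List) {
      final List<?> values = (List<?>) row;
      if (values.isEmpty()) {
        return null;
      }
      // true: keep the whole list as a string; false: keep only the first element.
      value = aggregateMultipleValues ? values.toString() : String.valueOf(values.get(0));
    } else {
      value = String.valueOf(row);
    }
    // Assumption: truncation counts characters, which matches the test data used here.
    return value.length() > maxStringBytes ? value.substring(0, maxStringBytes) : value;
  }

  public static void main(String[] args)
  {
    final List<String> mvd = Arrays.asList("AAAA", "AAAAB", "AAAC");
    System.out.println(pick(mvd, 1024, true));   // [AAAA, AAAAB, AAAC]
    System.out.println(pick(mvd, 1024, false));  // AAAA
    System.out.println(pick(mvd, 10, true));     // [AAAA, AAA
    System.out.println(pick(Collections.emptyList(), 1024, true)); // null
  }
}
```

Single-value rows and non-string values such as `1.0` pass through `String.valueOf` unchanged, which is consistent with `testNonStringValue` above.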


@ -33,6 +33,8 @@ import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
import static org.apache.druid.query.aggregation.any.StringAnyVectorAggregator.NOT_FOUND_FLAG_VALUE; import static org.apache.druid.query.aggregation.any.StringAnyVectorAggregator.NOT_FOUND_FLAG_VALUE;
@ -61,6 +63,7 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
private StringAnyVectorAggregator singleValueTarget; private StringAnyVectorAggregator singleValueTarget;
private StringAnyVectorAggregator multiValueTarget; private StringAnyVectorAggregator multiValueTarget;
private StringAnyVectorAggregator customMultiValueTarget;
@Before @Before
public void setUp() public void setUp()
@ -74,20 +77,22 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
return index >= DICTIONARY.length ? null : DICTIONARY[index]; return index >= DICTIONARY.length ? null : DICTIONARY[index];
}).when(singleValueSelector).lookupName(anyInt()); }).when(singleValueSelector).lookupName(anyInt());
initializeRandomBuffer(); initializeRandomBuffer();
singleValueTarget = new StringAnyVectorAggregator(singleValueSelector, null, MAX_STRING_BYTES); singleValueTarget = new StringAnyVectorAggregator(singleValueSelector, null, MAX_STRING_BYTES, true);
multiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES); multiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES, true);
// customMultiValueTarget aggregates to only single value in case of MVDs
customMultiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES, false);
} }
@Test(expected = IllegalStateException.class) @Test(expected = IllegalStateException.class)
public void initWithBothSingleAndMultiValueSelectorShouldThrowException() public void initWithBothSingleAndMultiValueSelectorShouldThrowException()
{ {
new StringAnyVectorAggregator(singleValueSelector, multiValueSelector, MAX_STRING_BYTES); new StringAnyVectorAggregator(singleValueSelector, multiValueSelector, MAX_STRING_BYTES, true);
} }
@Test(expected = IllegalStateException.class) @Test(expected = IllegalStateException.class)
public void initWithNeitherSingleNorMultiValueSelectorShouldThrowException() public void initWithNeitherSingleNorMultiValueSelectorShouldThrowException()
{ {
new StringAnyVectorAggregator(null, null, MAX_STRING_BYTES); new StringAnyVectorAggregator(null, null, MAX_STRING_BYTES, true);
} }
@Test @Test
@ -122,7 +127,7 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
public void aggregateMultiValuePositionNotFoundShouldPutFirstValue() public void aggregateMultiValuePositionNotFoundShouldPutFirstValue()
{ {
multiValueTarget.aggregate(buf, POSITION, 0, 2); multiValueTarget.aggregate(buf, POSITION, 0, 2);
Assert.assertEquals(DICTIONARY[1], multiValueTarget.get(buf, POSITION)); Assert.assertEquals("[One, Zero]", multiValueTarget.get(buf, POSITION));
} }
@Test @Test
@ -155,9 +160,9 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
@Test @Test
public void aggregateBatchWithRowsShouldAggregateAllRows() public void aggregateBatchWithRowsShouldAggregateAllRows()
{ {
int[] positions = new int[] {0, 43, 100}; int[] positions = new int[]{0, 43, 100};
int positionOffset = 2; int positionOffset = 2;
int[] rows = new int[] {2, 1, 0}; int[] rows = new int[]{2, 1, 0};
clearBufferForPositions(positionOffset, positions); clearBufferForPositions(positionOffset, positions);
multiValueTarget.aggregate(buf, 3, positions, rows, positionOffset); multiValueTarget.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) { for (int i = 0; i < positions.length; i++) {
@ -166,8 +171,32 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
IndexedInts rowIndex = MULTI_VALUE_ROWS[row]; IndexedInts rowIndex = MULTI_VALUE_ROWS[row];
if (rowIndex.size() == 0) { if (rowIndex.size() == 0) {
Assert.assertNull(multiValueTarget.get(buf, position)); Assert.assertNull(multiValueTarget.get(buf, position));
} else { } else if (rowIndex.size() == 1) {
Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)), multiValueTarget.get(buf, position)); Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)), multiValueTarget.get(buf, position));
} else {
List<String> res = new ArrayList<>();
rowIndex.forEach(index -> res.add(multiValueSelector.lookupName(index)));
Assert.assertEquals(res.toString(), multiValueTarget.get(buf, position));
}
}
}
@Test
public void aggregateBatchWithRowsShouldAggregateAllRowsWithAggregateMVDFalse()
{
int[] positions = new int[]{0, 43, 100};
int positionOffset = 2;
int[] rows = new int[]{2, 1, 0};
clearBufferForPositions(positionOffset, positions);
customMultiValueTarget.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
int position = positions[i] + positionOffset;
int row = rows[i];
IndexedInts rowIndex = MULTI_VALUE_ROWS[row];
if (rowIndex.size() == 0) {
Assert.assertNull(customMultiValueTarget.get(buf, position));
} else {
Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)), customMultiValueTarget.get(buf, position));
} }
} }
} }


@ -489,7 +489,7 @@ public class FallbackVirtualColumnTest
} }
} }
private static class SameMultiVectorSelector implements MultiValueDimensionVectorSelector public static class SameMultiVectorSelector implements MultiValueDimensionVectorSelector
{ {
@Override @Override
public int getValueCardinality() public int getValueCardinality()


@ -95,7 +95,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
String fieldName, String fieldName,
String timeColumn, String timeColumn,
ColumnType type, ColumnType type,
Integer maxStringBytes Integer maxStringBytes,
Boolean aggregateMultipleValues
) )
{ {
switch (type.getType()) { switch (type.getType()) {
@ -121,7 +122,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
String fieldName, String fieldName,
String timeColumn, String timeColumn,
ColumnType type, ColumnType type,
Integer maxStringBytes Integer maxStringBytes,
Boolean aggregateMultipleValues
) )
{ {
switch (type.getType()) { switch (type.getType()) {
@ -147,7 +149,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
String fieldName, String fieldName,
String timeColumn, String timeColumn,
ColumnType type, ColumnType type,
Integer maxStringBytes Integer maxStringBytes,
Boolean aggregateMultipleValues
) )
{ {
switch (type.getType()) { switch (type.getType()) {
@ -158,7 +161,7 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
case DOUBLE: case DOUBLE:
return new DoubleAnyAggregatorFactory(name, fieldName); return new DoubleAnyAggregatorFactory(name, fieldName);
case STRING: case STRING:
return new StringAnyAggregatorFactory(name, fieldName, maxStringBytes); return new StringAnyAggregatorFactory(name, fieldName, maxStringBytes, aggregateMultipleValues);
default: default:
throw SimpleSqlAggregator.badTypeException(fieldName, "ANY", type); throw SimpleSqlAggregator.badTypeException(fieldName, "ANY", type);
} }
@ -170,7 +173,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
String fieldName, String fieldName,
String timeColumn, String timeColumn,
ColumnType outputType, ColumnType outputType,
Integer maxStringBytes Integer maxStringBytes,
Boolean aggregateMultipleValues
); );
} }
@ -244,37 +248,38 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
final AggregatorFactory theAggFactory; final AggregatorFactory theAggFactory;
switch (args.size()) { switch (args.size()) {
case 1: case 1:
theAggFactory = aggregatorType.createAggregatorFactory(aggregatorName, fieldName, null, outputType, null); theAggFactory = aggregatorType.createAggregatorFactory(aggregatorName, fieldName, null, outputType, null, true);
break; break;
case 2: case 2:
int maxStringBytes; Integer maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); // added not null check at the function
try {
maxStringBytes = RexLiteral.intValue(rexNodes.get(1));
}
catch (AssertionError ae) {
plannerContext.setPlanningError(
"The second argument '%s' to function '%s' is not a number",
rexNodes.get(1),
aggregateCall.getName()
);
return null;
}
theAggFactory = aggregatorType.createAggregatorFactory( theAggFactory = aggregatorType.createAggregatorFactory(
aggregatorName, aggregatorName,
fieldName, fieldName,
null, null,
outputType, outputType,
maxStringBytes maxStringBytes.intValue(),
true
);
break;
case 3:
maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); // added not null check at the function for rexNode 1,2
boolean aggregateMultipleValues = RexLiteral.booleanValue(rexNodes.get(2));
theAggFactory = aggregatorType.createAggregatorFactory(
aggregatorName,
fieldName,
null,
outputType,
maxStringBytes,
aggregateMultipleValues
); );
break; break;
default: default:
throw InvalidSqlInput.exception( throw InvalidSqlInput.exception(
"Function [%s] expects 1 or 2 arguments but found [%s]", "Function [%s] expects 1 or 2 or 3 arguments but found [%s]",
aggregateCall.getName(), aggregateCall.getName(),
args.size() args.size()
); );
} }
return Aggregation.create( return Aggregation.create(
Collections.singletonList(theAggFactory), Collections.singletonList(theAggFactory),
finalizeAggregations ? new FinalizingFieldAccessPostAggregator(name, aggregatorName) : null finalizeAggregations ? new FinalizingFieldAccessPostAggregator(name, aggregatorName) : null
@ -372,10 +377,11 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
InferTypes.RETURN_TYPE, InferTypes.RETURN_TYPE,
DefaultOperandTypeChecker DefaultOperandTypeChecker
.builder() .builder()
.operandNames("expr", "maxBytesPerString") .operandNames("expr", "maxBytesPerStringInt", "aggregateMultipleValuesBoolean")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.BOOLEAN)
.requiredOperandCount(1) .requiredOperandCount(1)
.literalOperands(1) .literalOperands(1, 2)
.notNullOperands(1, 2)
.build(), .build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
false, false,
@ -402,9 +408,9 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
SqlParserPos pos = call.getParserPosition(); SqlParserPos pos = call.getParserPosition();
if (operands.isEmpty() || operands.size() > 2) { if (operands.isEmpty() || operands.size() > 3) {
throw InvalidSqlInput.exception( throw InvalidSqlInput.exception(
"Function [%s] expects 1 or 2 arguments but found [%s]", "Function [%s] expects 1 or 2 or 3 arguments but found [%s]",
getName(), getName(),
operands.size() operands.size()
); );
@ -417,6 +423,9 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
if (operands.size() == 2) { if (operands.size() == 2) {
newOperands.add(operands.get(1)); newOperands.add(operands.get(1));
} }
if (operands.size() == 3) {
newOperands.add(operands.get(2));
}
return replacementAggFunc.createCall(pos, newOperands); return replacementAggFunc.createCall(pos, newOperands);
} }
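The argument handling in the `switch (args.size())` block above can be summarised as: one argument leaves `maxStringBytes` as `null` and defaults `aggregateMultipleValues` to `true`, two arguments set `maxStringBytes`, and three arguments set both. The sketch below restates that dispatch without Calcite types. `AnyValueArgsSketch`, `ResolvedArgs` and `resolve` are hypothetical names, and plain Java objects stand in for the `RexLiteral` operands.

```java
import java.util.Arrays;
import java.util.List;

public class AnyValueArgsSketch
{
  // Hypothetical holder for the two optional ANY_VALUE arguments.
  static final class ResolvedArgs
  {
    final Integer maxStringBytes;          // null when the caller did not supply it
    final boolean aggregateMultipleValues;

    ResolvedArgs(Integer maxStringBytes, boolean aggregateMultipleValues)
    {
      this.maxStringBytes = maxStringBytes;
      this.aggregateMultipleValues = aggregateMultipleValues;
    }
  }

  // args.get(0) is the expression; args.get(1) and args.get(2) are the optional literals.
  static ResolvedArgs resolve(List<Object> args)
  {
    switch (args.size()) {
      case 1:
        return new ResolvedArgs(null, true);
      case 2:
        return new ResolvedArgs((Integer) args.get(1), true);
      case 3:
        return new ResolvedArgs((Integer) args.get(1), (Boolean) args.get(2));
      default:
        throw new IllegalArgumentException(
            "expects 1 or 2 or 3 arguments but found [" + args.size() + "]"
        );
    }
  }

  public static void main(String[] args)
  {
    ResolvedArgs r = resolve(Arrays.<Object>asList("dim3", 100, false));
    System.out.println(r.maxStringBytes + " " + r.aggregateMultipleValues); // 100 false
  }
}
```

In the real code the null and type checks happen earlier, in the operand type checker, which is why the diff can drop the old `try`/`catch` around `RexLiteral.intValue`.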


@ -119,7 +119,8 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator
rexNodes.get(1) rexNodes.get(1)
), ),
outputType, outputType,
null null,
true
); );
break; break;
case 3: case 3:
@ -145,7 +146,8 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator
rexNodes.get(1) rexNodes.get(1)
), ),
outputType, outputType,
maxStringBytes maxStringBytes,
true
); );
break; break;
default: default:


@ -188,6 +188,7 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
@Nullable @Nullable
private Integer requiredOperandCount; private Integer requiredOperandCount;
private int[] literalOperands; private int[] literalOperands;
private IntSet notNullOperands = new IntArraySet();
private Builder() private Builder()
{ {
@ -229,6 +230,12 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
return this; return this;
} }
public Builder notNullOperands(final int... notNullOperands)
{
Arrays.stream(notNullOperands).forEach(this.notNullOperands::add);
return this;
}
public DefaultOperandTypeChecker build() public DefaultOperandTypeChecker build()
{ {
int computedRequiredOperandCount = requiredOperandCount == null ? operandTypes.size() : requiredOperandCount; int computedRequiredOperandCount = requiredOperandCount == null ? operandTypes.size() : requiredOperandCount;
@ -236,16 +243,18 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
operandNames, operandNames,
operandTypes, operandTypes,
computedRequiredOperandCount, computedRequiredOperandCount,
DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size()), DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size(), notNullOperands),
literalOperands literalOperands
); );
} }
} }
public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount) public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount, IntSet notNullOperands)
{ {
final IntSet nullableOperands = new IntArraySet(); final IntSet nullableOperands = new IntArraySet();
IntStream.range(requiredOperandCount, totalOperandCount).forEach(nullableOperands::add); IntStream.range(requiredOperandCount, totalOperandCount)
.filter(i -> !notNullOperands.contains(i))
.forEach(nullableOperands::add);
return nullableOperands; return nullableOperands;
} }
} }
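The change to `buildNullableOperands` above means optional operands are still nullable by default, unless their index was registered through `notNullOperands(...)`. A minimal stand-alone sketch of the same filter follows. It uses `java.util.Set` in place of fastutil's `IntSet`, and the class name `NullableOperandsSketch` is hypothetical.

```java
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.IntStream;

public class NullableOperandsSketch
{
  // Operands at index >= requiredOperandCount are optional; they are treated as
  // nullable unless they were explicitly marked not-null.
  static Set<Integer> buildNullableOperands(
      int requiredOperandCount,
      int totalOperandCount,
      Set<Integer> notNullOperands
  )
  {
    final Set<Integer> nullableOperands = new TreeSet<>();
    IntStream.range(requiredOperandCount, totalOperandCount)
             .filter(i -> !notNullOperands.contains(i))
             .forEach(nullableOperands::add);
    return nullableOperands;
  }

  public static void main(String[] args)
  {
    // ANY_VALUE(expr, maxBytesPerStringInt, aggregateMultipleValuesBoolean):
    // one required operand, three in total.
    System.out.println(buildNullableOperands(1, 3, Set.of()));     // [1, 2]
    System.out.println(buildNullableOperands(1, 3, Set.of(1, 2))); // []
  }
}
```

With operands 1 and 2 marked not-null, a literal `null` in either position is rejected, which lines up with the `Illegal use of 'NULL'` errors asserted in `testStringAnyAggArgValidation`.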


@ -593,7 +593,7 @@ public class OperatorConversions
{ {
final IntSet nullableOperands = requiredOperandCount == null final IntSet nullableOperands = requiredOperandCount == null
? new IntArraySet() ? new IntArraySet()
: DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size()); : DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size(), new IntArraySet());
if (operandTypeInference == null) { if (operandTypeInference == null) {
SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands); SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands);
return (callBinding, returnType, types) -> { return (callBinding, returnType, types) -> {


@ -497,7 +497,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
cannotVectorize(); cannotVectorize();
testQuery( testQuery(
"SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100) FROM foo WHERE (TIME_FLOOR(__time, 'PT1H'), m1) IN\n" "SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100, true) FROM foo WHERE (TIME_FLOOR(__time, 'PT1H'), m1) IN\n"
+ " (\n" + " (\n"
+ " SELECT TIME_FLOOR(__time, 'PT1H') AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n" + " SELECT TIME_FLOOR(__time, 'PT1H') AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n"
+ " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n" + " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n"
@ -532,7 +532,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
) )
.setGranularity(Granularities.ALL) .setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators( .setAggregatorSpecs(aggregators(
new StringAnyAggregatorFactory("a0", "dim3", 100) new StringAnyAggregatorFactory("a0", "dim3", 100, true)
)) ))
.setContext(QUERY_CONTEXT_DEFAULT) .setContext(QUERY_CONTEXT_DEFAULT)
.build() .build()
@ -598,7 +598,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
) )
.setGranularity(Granularities.ALL) .setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators( .setAggregatorSpecs(aggregators(
new StringAnyAggregatorFactory("a0", "dim3", 100) new StringAnyAggregatorFactory("a0", "dim3", 100, true)
)) ))
.setContext(QUERY_CONTEXT_DEFAULT) .setContext(QUERY_CONTEXT_DEFAULT)
.build() .build()
@ -609,6 +609,71 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
) )
); );
} }
@Test
public void testJoinOnGroupByInsteadOfTimeseriesWithFloorOnTimeWithNoAggregateMultipleValues()
{
// Cannot vectorize JOIN operator.
cannotVectorize();
testQuery(
"SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100, false) FROM foo WHERE (CAST(TIME_FLOOR(__time, 'PT1H') AS BIGINT) + 1, m1) IN\n"
+ " (\n"
+ " SELECT CAST(TIME_FLOOR(__time, 'PT1H') AS BIGINT) + 1 AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n"
+ " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n"
+ " )\n"
+ "GROUP BY 1, 2\n",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Intervals.of(
"1994-04-29/2020-01-11T00:00:00.001Z")))
.setVirtualColumns(
expressionVirtualColumn(
"v0",
"(timestamp_floor(\"__time\",'PT1H',null,'UTC') + 1)",
ColumnType.LONG
)
)
.setDimFilter(equality("dim3", "b", ColumnType.STRING))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(new DefaultDimensionSpec(
"v0",
"d0",
ColumnType.LONG
)))
.setAggregatorSpecs(aggregators(
new FloatMinAggregatorFactory("a0", "m1")
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()),
"j0.",
"(((timestamp_floor(\"__time\",'PT1H',null,'UTC') + 1) == \"j0.d0\") && (\"m1\" == \"j0.a0\"))",
JoinType.INNER
)
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(
new DefaultDimensionSpec("__time", "d0", ColumnType.LONG),
new DefaultDimensionSpec("m1", "d1", ColumnType.FLOAT)
)
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
new StringAnyAggregatorFactory("a0", "dim3", 100, false)
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{946684800000L, 1.0f, "a"}, // picks up first from [a, b]
new Object[]{946771200000L, 2.0f, "b"} // picks up first from [b, c]
)
);
}
@Test @Test
@Parameters(source = QueryContextForJoinProvider.class) @Parameters(source = QueryContextForJoinProvider.class)
@ -1480,7 +1545,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
new SubstringDimExtractionFn(0, 1) new SubstringDimExtractionFn(0, 1)
) )
) )
.setAggregatorSpecs(new StringAnyAggregatorFactory("a0", "v", 10)) .setAggregatorSpecs(new StringAnyAggregatorFactory("a0", "v", 10, true))
.build() .build()
), ),
"j0.", "j0.",


@ -844,10 +844,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new LongAnyAggregatorFactory("a0", "cnt"), new LongAnyAggregatorFactory("a0", "cnt"),
new FloatAnyAggregatorFactory("a1", "m1"), new FloatAnyAggregatorFactory("a1", "m1"),
new DoubleAnyAggregatorFactory("a2", "m2"), new DoubleAnyAggregatorFactory("a2", "m2"),
new StringAnyAggregatorFactory("a3", "dim1", 10), new StringAnyAggregatorFactory("a3", "dim1", 10, true),
new LongAnyAggregatorFactory("a4", "v0"), new LongAnyAggregatorFactory("a4", "v0"),
new FloatAnyAggregatorFactory("a5", "v1"), new FloatAnyAggregatorFactory("a5", "v1"),
new StringAnyAggregatorFactory("a6", "v2", 10) new StringAnyAggregatorFactory("a6", "v2", 10, true)
) )
) )
.context(QUERY_CONTEXT_DEFAULT) .context(QUERY_CONTEXT_DEFAULT)
@ -1420,7 +1420,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setAggregatorSpecs(aggregators(new StringAnyAggregatorFactory( .setAggregatorSpecs(aggregators(new StringAnyAggregatorFactory(
"a0:a", "a0:a",
"dim1", "dim1",
10 10, true
))) )))
.setPostAggregatorSpecs( .setPostAggregatorSpecs(
ImmutableList.of( ImmutableList.of(
@ -1565,7 +1565,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.granularity(Granularities.ALL) .granularity(Granularities.ALL)
.aggregators( .aggregators(
aggregators( aggregators(
new StringAnyAggregatorFactory("a0", "dim1", 32), new StringAnyAggregatorFactory("a0", "dim1", 32, true),
new LongAnyAggregatorFactory("a1", "l2"), new LongAnyAggregatorFactory("a1", "l2"),
new DoubleAnyAggregatorFactory("a2", "d2"), new DoubleAnyAggregatorFactory("a2", "d2"),
new FloatAnyAggregatorFactory("a3", "f2") new FloatAnyAggregatorFactory("a3", "f2")
@ -1607,7 +1607,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.filters(filter) .filters(filter)
.aggregators( .aggregators(
aggregators( aggregators(
new StringAnyAggregatorFactory("a0", "dim1", 32), new StringAnyAggregatorFactory("a0", "dim1", 32, true),
new LongAnyAggregatorFactory("a1", "l2"), new LongAnyAggregatorFactory("a1", "l2"),
new DoubleAnyAggregatorFactory("a2", "d2"), new DoubleAnyAggregatorFactory("a2", "d2"),
new FloatAnyAggregatorFactory("a3", "f2") new FloatAnyAggregatorFactory("a3", "f2")
@ -9422,7 +9422,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.granularity(Granularities.ALL) .granularity(Granularities.ALL)
.aggregators( .aggregators(
aggregators( aggregators(
new StringAnyAggregatorFactory("a0", "dim1", 1024), new StringAnyAggregatorFactory("a0", "dim1", 1024, true),
new LongAnyAggregatorFactory("a1", "l1"), new LongAnyAggregatorFactory("a1", "l1"),
new StringFirstAggregatorFactory("a2", "dim1", null, 1024), new StringFirstAggregatorFactory("a2", "dim1", null, 1024),
new LongFirstAggregatorFactory("a3", "l1", null), new LongFirstAggregatorFactory("a3", "l1", null),
@ -9741,7 +9741,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setAggregatorSpecs( .setAggregatorSpecs(
aggregators( aggregators(
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
new StringAnyAggregatorFactory("a0", "dim1", 1024), new StringAnyAggregatorFactory("a0", "dim1", 1024, true),
equality("dim1", "nonexistent", ColumnType.STRING) equality("dim1", "nonexistent", ColumnType.STRING)
), ),
new FilteredAggregatorFactory( new FilteredAggregatorFactory(
@ -13533,6 +13533,24 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
); );
} }
@Test
public void testStringAnyAggArgValidation()
{
DruidException e = assertThrows(DruidException.class, () -> testBuilder()
.sql("SELECT ANY_VALUE(dim3, 1000, 'true') FROM foo")
.queryContext(ImmutableMap.of())
.run());
assertThat(e, invalidSqlIs(
"Cannot apply 'ANY_VALUE' to arguments of type 'ANY_VALUE(<VARCHAR>, <INTEGER>, <CHAR(4)>)'. Supported form(s): 'ANY_VALUE(<expr>, [<maxBytesPerStringInt>, [<aggregateMultipleValuesBoolean>]])' (line [1], column [8])"));
DruidException e1 = assertThrows(DruidException.class, () -> testBuilder()
.sql("SELECT ANY_VALUE(dim3, 1000, null) FROM foo")
.queryContext(ImmutableMap.of()).run());
Assert.assertEquals("Illegal use of 'NULL' (line [1], column [30])", e1.getMessage());
DruidException e2 = assertThrows(DruidException.class, () -> testBuilder()
.sql("SELECT ANY_VALUE(dim3, null, true) FROM foo")
.queryContext(ImmutableMap.of()).run());
Assert.assertEquals("Illegal use of 'NULL' (line [1], column [24])", e2.getMessage());
}
@Test @Test
public void testStringAggMaxBytes() public void testStringAggMaxBytes()
{ {
@ -14367,7 +14385,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new StringLastAggregatorFactory("a1", "dim1", "__time", 1024), new StringLastAggregatorFactory("a1", "dim1", "__time", 1024),
new StringFirstAggregatorFactory("a2", "dim3", "__time", 1024), new StringFirstAggregatorFactory("a2", "dim3", "__time", 1024),
new StringFirstAggregatorFactory("a3", "dim1", "__time", 1024), new StringFirstAggregatorFactory("a3", "dim1", "__time", 1024),
new StringAnyAggregatorFactory("a4", "dim3", 1024))) new StringAnyAggregatorFactory("a4", "dim3", 1024, true)))
.build() .build()
), ),

@ -323,4 +323,10 @@ public class QueryTestBuilder
return build().resultsOnly(); return build().resultsOnly();
} }
public boolean isDecoupledMode()
{
String mode = (String) queryContext.getOrDefault(PlannerConfig.CTX_NATIVE_QUERY_SQL_PLANNING_MODE, "");
return PlannerConfig.NATIVE_QUERY_SQL_PLANNING_MODE_DECOUPLED.equalsIgnoreCase(mode);
}
} }

@ -2334,3 +2334,4 @@ LAST_VALUE
markUnused markUnused
markUsed markUsed
segmentId segmentId
aggregateMultipleValues