mirror of https://github.com/apache/druid.git
Enabling aggregateMultipleValues in all StringAnyAggregators (#15434)
* Enabling aggregateMultipleValues in all StringAnyAggregators
* Adding more tests
* More validation
* fix warning
* updating asserts in decoupled mode
* fix intellij inspection
* Addressing comments
* Addressing comments
* Adding early validations and make aggregate consistent across all
* fixing tests
* fixing tests
* Update docs/querying/sql-aggregations.md

Co-authored-by: Clint Wylie <cjwylie@gmail.com>

* fixing static check

---------

Co-authored-by: Clint Wylie <cjwylie@gmail.com>
parent 64fcb32bcf
commit 93cd638645
@@ -90,7 +90,7 @@ In the aggregation functions supported by Druid, only `COUNT`, `ARRAY_AGG`, and
 |`EARLIEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the earliest value of `expr`.<br />The earliest value of `expr` is taken from the row with the overall earliest non-null value of `timestampExpr`. <br />If the earliest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
 |`LATEST(expr, [maxBytesPerValue])`|Returns the latest value of `expr`<br />The `expr` must come from a relation with a timestamp column (like `__time` in a Druid datasource) and the "latest" is taken from the row with the overall latest non-null value of the timestamp column.<br />If the latest non-null value of the timestamp column appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
 |`LATEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the latest value of `expr`.<br />The latest value of `expr` is taken from the row with the overall latest non-null value of `timestampExpr`.<br />If the overall latest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows.<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
-|`ANY_VALUE(expr, [maxBytesPerValue])`|Returns any value of `expr` including null. This aggregator can simplify and optimize the performance by returning the first encountered value (including `null`).<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
+|`ANY_VALUE(expr, [maxBytesPerValue, [aggregateMultipleValues]])`|Returns any value of `expr` including null. This aggregator can simplify and optimize the performance by returning the first encountered value (including `null`).<br /><br />If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.<br/>If `maxBytesPerValue` is omitted; it defaults to `1024`. `aggregateMultipleValues` is an optional boolean flag controls the behavior of aggregating a [multi-value dimension](./multi-value-dimensions.md). `aggregateMultipleValues` is set as true by default and returns the stringified array in case of a multi-value dimension. By setting it to false, function will return first value instead. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
 |`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.|N/A|
 |`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.|`null`|
 |`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results will be based on the default for the element type.|`null`|
|
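The `aggregateMultipleValues` behaviour described in the `ANY_VALUE` row above can be sketched in plain Java. This is a minimal illustration under stated assumptions, not Druid code: `AnyValueSketch` and its methods are hypothetical names, and `String.valueOf` is assumed to approximate how Druid stringifies a multi-value row.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for the multi-value handling; not part of Druid.
class AnyValueSketch
{
  // Returns the string that would be kept for one row value, before any truncation.
  static String readValue(Object value, boolean aggregateMultipleValues)
  {
    if (value == null) {
      return null;
    }
    if (value instanceof List) {
      List<?> values = (List<?>) value;
      if (values.isEmpty()) {
        return null;
      }
      if (values.size() == 1 || !aggregateMultipleValues) {
        // Single element, or flag disabled: keep only the first element.
        return stringify(values.get(0));
      }
      // Flag enabled: keep the whole row as a stringified array.
      return stringify(values);
    }
    return stringify(value);
  }

  // Assumption: plain toString-style conversion is close enough for illustration.
  static String stringify(Object o)
  {
    return o == null ? null : String.valueOf(o);
  }

  public static void main(String[] args)
  {
    List<String> row = Arrays.asList("AAAA", "AAAAB", "AAAC");
    System.out.println(readValue(row, true));   // [AAAA, AAAAB, AAAC]
    System.out.println(readValue(row, false));  // AAAA
  }
}
```

With the flag left at its default of true, a multi-value row comes back as one stringified array; with false, only the first element is returned.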
@@ -50,7 +50,7 @@ Calculates the arc cosine of a numeric expression.

 ## ANY_VALUE

-`ANY_VALUE(expr, [maxBytesPerValue])`
+`ANY_VALUE(expr, [maxBytesPerValue, [aggregateMultipleValues]])`

 **Function type:** [Aggregation](sql-aggregations.md)

|
@@ -24,19 +24,23 @@ import org.apache.druid.query.aggregation.Aggregator;
 import org.apache.druid.segment.BaseObjectColumnValueSelector;
 import org.apache.druid.segment.DimensionHandlerUtils;

+import java.util.List;
+
 public class StringAnyAggregator implements Aggregator
 {
   private final BaseObjectColumnValueSelector valueSelector;
   private final int maxStringBytes;
   private boolean isFound;
   private String foundValue;
+  private final boolean aggregateMultipleValues;

-  public StringAnyAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes)
+  public StringAnyAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes, boolean aggregateMultipleValues)
   {
     this.valueSelector = valueSelector;
     this.maxStringBytes = maxStringBytes;
     this.foundValue = null;
     this.isFound = false;
+    this.aggregateMultipleValues = aggregateMultipleValues;
   }

   @Override
@@ -44,18 +48,36 @@ public class StringAnyAggregator implements Aggregator
   {
     if (!isFound) {
       final Object object = valueSelector.getObject();
-      foundValue = DimensionHandlerUtils.convertObjectToString(object);
-      if (foundValue != null && foundValue.length() > maxStringBytes) {
-        foundValue = foundValue.substring(0, maxStringBytes);
-      }
+      foundValue = StringUtils.fastLooseChop(readValue(object), maxStringBytes);
       isFound = true;
     }
   }

+  private String readValue(final Object object)
+  {
+    if (object == null) {
+      return null;
+    }
+    if (object instanceof List) {
+      List<Object> objectList = (List) object;
+      if (objectList.size() == 0) {
+        return null;
+      }
+      if (objectList.size() == 1) {
+        return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
+      }
+      if (aggregateMultipleValues) {
+        return DimensionHandlerUtils.convertObjectToString(objectList);
+      }
+      return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
+    }
+    return DimensionHandlerUtils.convertObjectToString(object);
+  }
+
   @Override
   public Object get()
   {
-    return StringUtils.chop(foundValue, maxStringBytes);
+    return foundValue;
   }

   @Override
|
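The first-value-wins flow of `StringAnyAggregator.aggregate()` after this change can be modelled in a self-contained sketch. `FirstSeenSketch` is a hypothetical class, not the Druid one; truncation is modelled as a character-count cut (the rule the removed `substring` code applied, and an assumption about what `StringUtils.fastLooseChop` does), and `String.valueOf` stands in for `DimensionHandlerUtils.convertObjectToString`.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical model of a "first value wins" string aggregator; not the Druid class.
class FirstSeenSketch
{
  private final int maxChars;
  private final boolean aggregateMultipleValues;
  private boolean isFound;
  private String foundValue;

  FirstSeenSketch(int maxChars, boolean aggregateMultipleValues)
  {
    this.maxChars = maxChars;
    this.aggregateMultipleValues = aggregateMultipleValues;
  }

  void aggregate(Object rowValue)
  {
    if (isFound) {
      return; // later rows are ignored, even when the first value was null
    }
    String value = readValue(rowValue);
    // Assumed truncation rule: cut by character count.
    if (value != null && value.length() > maxChars) {
      value = value.substring(0, maxChars);
    }
    foundValue = value;
    isFound = true;
  }

  String get()
  {
    return foundValue;
  }

  private String readValue(Object rowValue)
  {
    Object kept = rowValue;
    if (rowValue instanceof List) {
      List<?> values = (List<?>) rowValue;
      if (values.isEmpty()) {
        return null;
      }
      // Keep the whole list only when it has several elements and the flag is on.
      kept = (values.size() > 1 && aggregateMultipleValues) ? values : values.get(0);
    }
    return kept == null ? null : String.valueOf(kept);
  }

  public static void main(String[] args)
  {
    FirstSeenSketch agg = new FirstSeenSketch(10, true);
    agg.aggregate(Arrays.asList("AAAA", "AAAAB", "AAAC"));
    agg.aggregate("BBBB"); // ignored: a value was already found
    System.out.println(agg.get()); // [AAAA, AAA
  }
}
```

Truncating once at aggregation time means `get()` can return the stored value unchanged, which mirrors the simplified `get()` in the hunk above.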
@@ -48,13 +48,15 @@ public class StringAnyAggregatorFactory extends AggregatorFactory

   private final String fieldName;
   private final String name;
-  protected final int maxStringBytes;
+  private final int maxStringBytes;
+  private final boolean aggregateMultipleValues;

   @JsonCreator
   public StringAnyAggregatorFactory(
       @JsonProperty("name") String name,
       @JsonProperty("fieldName") final String fieldName,
-      @JsonProperty("maxStringBytes") Integer maxStringBytes
+      @JsonProperty("maxStringBytes") Integer maxStringBytes,
+      @JsonProperty("aggregateMultipleValues") @Nullable final Boolean aggregateMultipleValues
   )
   {
     Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
@@ -67,18 +69,19 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
     this.maxStringBytes = maxStringBytes == null
                           ? StringFirstAggregatorFactory.DEFAULT_MAX_STRING_SIZE
                           : maxStringBytes;
+    this.aggregateMultipleValues = aggregateMultipleValues == null ? true : aggregateMultipleValues;
   }

   @Override
   public Aggregator factorize(ColumnSelectorFactory metricFactory)
   {
-    return new StringAnyAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes);
+    return new StringAnyAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes, aggregateMultipleValues);
   }

   @Override
   public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
   {
-    return new StringAnyBufferAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes);
+    return new StringAnyBufferAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes, aggregateMultipleValues);
   }

   @Override
@@ -90,13 +93,15 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
       return new StringAnyVectorAggregator(
           null,
           selectorFactory.makeMultiValueDimensionSelector(DefaultDimensionSpec.of(fieldName)),
-          maxStringBytes
+          maxStringBytes,
+          aggregateMultipleValues
       );
     } else {
       return new StringAnyVectorAggregator(
           selectorFactory.makeSingleValueDimensionSelector(DefaultDimensionSpec.of(fieldName)),
           null,
-          maxStringBytes
+          maxStringBytes,
+          aggregateMultipleValues
       );
     }
   }
@@ -122,7 +127,7 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
   @Override
   public AggregatorFactory getCombiningFactory()
   {
-    return new StringAnyAggregatorFactory(name, name, maxStringBytes);
+    return new StringAnyAggregatorFactory(name, name, maxStringBytes, aggregateMultipleValues);
   }

   @Override
@@ -155,6 +160,11 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
   {
     return maxStringBytes;
   }
+  @JsonProperty
+  public boolean getAggregateMultipleValues()
+  {
+    return aggregateMultipleValues;
+  }

   @Override
   public List<String> requiredFields()
@@ -192,7 +202,7 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
   @Override
   public AggregatorFactory withName(String newName)
   {
-    return new StringAnyAggregatorFactory(newName, getFieldName(), getMaxStringBytes());
+    return new StringAnyAggregatorFactory(newName, getFieldName(), getMaxStringBytes(), getAggregateMultipleValues());
   }

   @Override
|
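The factory constructor above accepts nullable boxed arguments and resolves them to defaults. A minimal sketch of that pattern follows; `AnySpecDefaults` is a hypothetical class, and the value 1024 is taken from the documentation text in this commit rather than from `StringFirstAggregatorFactory.DEFAULT_MAX_STRING_SIZE` itself.

```java
// Hypothetical holder showing how nullable constructor arguments fall back to defaults.
class AnySpecDefaults
{
  // Assumption: 1024 is the documented default for maxBytesPerValue.
  static final int DEFAULT_MAX_STRING_BYTES = 1024;

  final int maxStringBytes;
  final boolean aggregateMultipleValues;

  AnySpecDefaults(Integer maxStringBytes, Boolean aggregateMultipleValues)
  {
    // A missing size falls back to the default size.
    this.maxStringBytes = maxStringBytes == null ? DEFAULT_MAX_STRING_BYTES : maxStringBytes;
    // A missing flag is treated as true, so multi-value rows are stringified by default.
    this.aggregateMultipleValues = aggregateMultipleValues == null || aggregateMultipleValues;
  }

  public static void main(String[] args)
  {
    AnySpecDefaults spec = new AnySpecDefaults(null, null);
    System.out.println(spec.maxStringBytes + " " + spec.aggregateMultipleValues); // 1024 true
  }
}
```

Taking `Boolean` rather than `boolean` lets a spec that omits the property deserialize into the default instead of failing.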
@@ -25,6 +25,7 @@ import org.apache.druid.segment.BaseObjectColumnValueSelector;
 import org.apache.druid.segment.DimensionHandlerUtils;

 import java.nio.ByteBuffer;
+import java.util.List;

 public class StringAnyBufferAggregator implements BufferAggregator
 {
@@ -34,11 +35,13 @@ public class StringAnyBufferAggregator implements BufferAggregator

   private final BaseObjectColumnValueSelector valueSelector;
   private final int maxStringBytes;
+  private final boolean aggregateMultipleValues;

-  public StringAnyBufferAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes)
+  public StringAnyBufferAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes, boolean aggregateMultipleValues)
   {
     this.valueSelector = valueSelector;
     this.maxStringBytes = maxStringBytes;
+    this.aggregateMultipleValues = aggregateMultipleValues;
   }

   @Override
@@ -51,8 +54,7 @@ public class StringAnyBufferAggregator implements BufferAggregator
   public void aggregate(ByteBuffer buf, int position)
   {
     if (buf.getInt(position) == NOT_FOUND_FLAG_VALUE) {
-      final Object object = valueSelector.getObject();
-      String foundValue = DimensionHandlerUtils.convertObjectToString(object);
+      String foundValue = readValue(valueSelector.getObject());
       if (foundValue != null) {
         ByteBuffer mutationBuffer = buf.duplicate();
         mutationBuffer.position(position + FOUND_VALUE_OFFSET);
@@ -65,6 +67,27 @@ public class StringAnyBufferAggregator implements BufferAggregator
     }
   }

+  private String readValue(Object object)
+  {
+    if (object == null) {
+      return null;
+    }
+    if (object instanceof List) {
+      List<Object> objectList = (List) object;
+      if (objectList.size() == 0) {
+        return null;
+      }
+      if (objectList.size() == 1) {
+        return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
+      }
+      if (aggregateMultipleValues) {
+        return DimensionHandlerUtils.convertObjectToString(objectList);
+      }
+      return DimensionHandlerUtils.convertObjectToString(objectList.get(0));
+    }
+    return DimensionHandlerUtils.convertObjectToString(object);
+  }
+
   @Override
   public Object get(ByteBuffer buf, int position)
   {
|
@@ -23,12 +23,15 @@ import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.DimensionHandlerUtils;
 import org.apache.druid.segment.data.IndexedInts;
 import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
 import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;

 import javax.annotation.Nullable;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;

 public class StringAnyVectorAggregator implements VectorAggregator
 {
@@ -43,11 +46,13 @@ public class StringAnyVectorAggregator implements VectorAggregator
   @Nullable
   private final MultiValueDimensionVectorSelector multiValueSelector;
   private final int maxStringBytes;
+  private final boolean aggregateMultipleValues;

   public StringAnyVectorAggregator(
       SingleValueDimensionVectorSelector singleValueSelector,
       MultiValueDimensionVectorSelector multiValueSelector,
-      int maxStringBytes
+      int maxStringBytes,
+      final boolean aggregateMultipleValues
   )
   {
     Preconditions.checkState(
@@ -61,6 +66,7 @@ public class StringAnyVectorAggregator implements VectorAggregator
     this.multiValueSelector = multiValueSelector;
     this.singleValueSelector = singleValueSelector;
     this.maxStringBytes = maxStringBytes;
+    this.aggregateMultipleValues = aggregateMultipleValues;
   }

   @Override
@@ -78,7 +84,7 @@ public class StringAnyVectorAggregator implements VectorAggregator
         if (startRow < rows.length) {
           IndexedInts row = rows[startRow];
           @Nullable
-          String foundValue = row.size() == 0 ? null : multiValueSelector.lookupName(row.get(0));
+          String foundValue = readValue(row);
           putValue(buf, position, foundValue);
         }
       } else if (singleValueSelector != null) {
@@ -93,6 +99,24 @@ public class StringAnyVectorAggregator implements VectorAggregator
     }
   }

+  private String readValue(IndexedInts row)
+  {
+    if (row.size() == 0) {
+      return null;
+    }
+    if (aggregateMultipleValues) {
+      if (row.size() == 1) {
+        return multiValueSelector.lookupName(row.get(0));
+      }
+      List<String> arrayList = new ArrayList<>();
+      row.forEach(rowIndex -> {
+        arrayList.add(multiValueSelector.lookupName(rowIndex));
+      });
+      return DimensionHandlerUtils.convertObjectToString(arrayList);
+    }
+    return multiValueSelector.lookupName(row.get(0));
+  }
+
   @Override
   public void aggregate(
       ByteBuffer buf,
@@ -142,4 +166,9 @@ public class StringAnyVectorAggregator implements VectorAggregator
       buf.putInt(position, FOUND_AND_NULL_FLAG_VALUE);
     }
   }
+
+  public boolean isAggregateMultipleValues()
+  {
+    return aggregateMultipleValues;
+  }
 }
|
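In the vectorized path, a multi-value row is a set of dictionary ids that must be resolved to strings before being joined. The sketch below illustrates that idea only: `VectorRowSketch` is hypothetical, an `int[]` stands in for `IndexedInts`, an `IntFunction<String>` stands in for `multiValueSelector.lookupName`, and `List.toString` is assumed to approximate the stringified-array format.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntFunction;

// Hypothetical illustration of resolving dictionary ids for one multi-value row.
class VectorRowSketch
{
  static String readValue(int[] row, IntFunction<String> lookupName, boolean aggregateMultipleValues)
  {
    if (row.length == 0) {
      return null; // empty row: nothing to report
    }
    if (row.length == 1 || !aggregateMultipleValues) {
      // One id, or flag disabled: resolve only the first id.
      return lookupName.apply(row[0]);
    }
    // Flag enabled: resolve every id, then stringify the resulting list.
    List<String> names = new ArrayList<>();
    for (int id : row) {
      names.add(lookupName.apply(id));
    }
    return names.toString();
  }

  public static void main(String[] args)
  {
    String[] dictionary = {"AAAA", "AAAAB", "AAAC"};
    System.out.println(readValue(new int[]{0, 1, 2}, id -> dictionary[id], true));  // [AAAA, AAAAB, AAAC]
    System.out.println(readValue(new int[]{0, 1, 2}, id -> dictionary[id], false)); // AAAA
  }
}
```

Resolving ids lazily keeps the flag-disabled path to a single lookup, the same cost as before this change.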
@@ -148,7 +148,7 @@ public class AggregatorFactoryTest extends InitializedNullHandlingTest
         // string aggregators
         new StringFirstAggregatorFactory("stringFirst", "col", null, 1024),
         new StringLastAggregatorFactory("stringLast", "col", null, 1024),
-        new StringAnyAggregatorFactory("stringAny", "col", 1024),
+        new StringAnyAggregatorFactory("stringAny", "col", 1024, true),
         // sketch aggs
         new CardinalityAggregatorFactory("cardinality", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false),
         new HyperUniquesAggregatorFactory("hyperUnique", "hyperunique"),
@@ -307,7 +307,8 @@ public class AggregatorFactoryTest extends InitializedNullHandlingTest
         // string aggregators
         new StringFirstAggregatorFactory("col", "col", null, 1024),
         new StringLastAggregatorFactory("col", "col", null, 1024),
-        new StringAnyAggregatorFactory("col", "col", 1024),
+        new StringAnyAggregatorFactory("col", "col", 1024, true),
+        new StringAnyAggregatorFactory("col", "col", 1024, false),
         // sketch aggs
         new CardinalityAggregatorFactory("col", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false),
         new HyperUniquesAggregatorFactory("col", "hyperunique"),
|
@@ -49,7 +49,7 @@ public class StringAnyAggregationTest
   @Before
   public void setup()
   {
-    stringAnyAggFactory = new StringAnyAggregatorFactory("billy", "nilly", MAX_STRING_SIZE);
+    stringAnyAggFactory = new StringAnyAggregatorFactory("billy", "nilly", MAX_STRING_SIZE, true);
     combiningAggFactory = stringAnyAggFactory.getCombiningFactory();
     valueSelector = new TestObjectColumnSelector<>(strings);
     objectSelector = new TestObjectColumnSelector<>(strings);
|
@@ -19,48 +19,44 @@

 package org.apache.druid.query.aggregation.any;

-import org.apache.druid.segment.ColumnInspector;
+import com.google.common.collect.Lists;
+import org.apache.druid.query.aggregation.Aggregator;
+import org.apache.druid.query.aggregation.TestObjectColumnSelector;
+import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.DimensionSelector;
 import org.apache.druid.segment.column.ColumnCapabilities;
-import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
-import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
-import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
+import org.apache.druid.segment.vector.TestVectorColumnSelectorFactory;
+import org.apache.druid.segment.virtual.FallbackVirtualColumnTest;
 import org.apache.druid.testing.InitializedNullHandlingTest;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.Mockito;
-import org.mockito.junit.MockitoJUnitRunner;

-import static org.mockito.ArgumentMatchers.any;
+import java.util.List;

-@RunWith(MockitoJUnitRunner.class)
 public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
 {
   private static final String NAME = "NAME";
   private static final String FIELD_NAME = "FIELD_NAME";
   private static final int MAX_STRING_BYTES = 10;

-  @Mock
-  private ColumnInspector columnInspector;
-  @Mock
+  private TestColumnSelectorFactory columnInspector;
   private ColumnCapabilities capabilities;
-  @Mock
-  private VectorColumnSelectorFactory vectorSelectorFactory;
-  @Mock
-  private SingleValueDimensionVectorSelector singleValueDimensionVectorSelector;
-  @Mock
-  private MultiValueDimensionVectorSelector multiValueDimensionVectorSelector;
-
+  private TestVectorColumnSelectorFactory vectorSelectorFactory;
   private StringAnyAggregatorFactory target;

   @Before
   public void setUp()
   {
-    Mockito.doReturn(capabilities).when(vectorSelectorFactory).getColumnCapabilities(FIELD_NAME);
-    Mockito.doReturn(ColumnCapabilities.Capable.UNKNOWN).when(capabilities).hasMultipleValues();
-    target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES);
+    target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES, true);
+    columnInspector = new TestColumnSelectorFactory();
+    vectorSelectorFactory = new TestVectorColumnSelectorFactory();
+    capabilities = ColumnCapabilitiesImpl.createDefault().setHasMultipleValues(true);
+    vectorSelectorFactory.addCapabilities(FIELD_NAME, capabilities);
+    vectorSelectorFactory.addMVDVS(FIELD_NAME, new FallbackVirtualColumnTest.SameMultiVectorSelector());
   }

   @Test
@@ -72,10 +68,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
   @Test
   public void factorizeVectorWithoutCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
   {
-    Mockito.doReturn(null).when(vectorSelectorFactory).getColumnCapabilities(FIELD_NAME);
-    Mockito.doReturn(singleValueDimensionVectorSelector)
-           .when(vectorSelectorFactory)
-           .makeSingleValueDimensionSelector(any());
     StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
     Assert.assertNotNull(aggregator);
   }
@@ -83,9 +75,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
   @Test
   public void factorizeVectorWithUnknownCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
   {
-    Mockito.doReturn(multiValueDimensionVectorSelector)
-           .when(vectorSelectorFactory)
-           .makeMultiValueDimensionSelector(any());
     StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
     Assert.assertNotNull(aggregator);
   }
@@ -93,10 +82,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
   @Test
   public void factorizeVectorWithMultipleValuesCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
   {
-    Mockito.doReturn(ColumnCapabilities.Capable.TRUE).when(capabilities).hasMultipleValues();
-    Mockito.doReturn(multiValueDimensionVectorSelector)
-           .when(vectorSelectorFactory)
-           .makeMultiValueDimensionSelector(any());
     StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
     Assert.assertNotNull(aggregator);
   }
@@ -104,11 +89,86 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
   @Test
   public void factorizeVectorWithoutMultipleValuesCapabilitiesShouldReturnAggregatorWithSingleDimensionSelector()
   {
-    Mockito.doReturn(ColumnCapabilities.Capable.FALSE).when(capabilities).hasMultipleValues();
-    Mockito.doReturn(singleValueDimensionVectorSelector)
-           .when(vectorSelectorFactory)
-           .makeSingleValueDimensionSelector(any());
     StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
     Assert.assertNotNull(aggregator);
   }
+
+  @Test
+  public void testFactorize()
+  {
+    Aggregator res = target.factorize(new TestColumnSelectorFactory());
+    Assert.assertTrue(res instanceof StringAnyAggregator);
+    res.aggregate();
+    Assert.assertEquals(null, res.get());
+    StringAnyVectorAggregator vectorAggregator = target.factorizeVector(vectorSelectorFactory);
+    Assert.assertTrue(vectorAggregator.isAggregateMultipleValues());
+  }
+
+  @Test
+  public void testSvdStringAnyAggregator()
+  {
+    TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory();
+    Aggregator res = target.factorize(columnSelectorFactory);
+    Assert.assertTrue(res instanceof StringAnyAggregator);
+    columnSelectorFactory.moveSelectorCursorToNext();
+    res.aggregate();
+    Assert.assertEquals("CCCC", res.get());
+  }
+
+  @Test
+  public void testMvdStringAnyAggregator()
+  {
+    TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory();
+    Aggregator res = target.factorize(columnSelectorFactory);
+    Assert.assertTrue(res instanceof StringAnyAggregator);
+    columnSelectorFactory.moveSelectorCursorToNext();
+    columnSelectorFactory.moveSelectorCursorToNext();
+    res.aggregate();
+    Assert.assertEquals("[AAAA, AAA", res.get());
+  }
+
+  @Test
+  public void testMvdStringAnyAggregatorWithAggregateMultipleToFalse()
+  {
+    StringAnyAggregatorFactory target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES, false);
+    TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory();
+    Aggregator res = target.factorize(columnSelectorFactory);
+    Assert.assertTrue(res instanceof StringAnyAggregator);
+    columnSelectorFactory.moveSelectorCursorToNext();
+    columnSelectorFactory.moveSelectorCursorToNext();
+    res.aggregate();
+    // picks up first value in mvd list
+    Assert.assertEquals("AAAA", res.get());
+  }
+
+  static class TestColumnSelectorFactory implements ColumnSelectorFactory
+  {
+    List<String> mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
+    final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
+    Integer maxStringBytes = 1024;
+    TestObjectColumnSelector<Object> objectColumnSelector = new TestObjectColumnSelector<>(mvds);
+
+    @Override
+    public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec)
+    {
+      return null;
+    }
+
+    @Override
+    public ColumnValueSelector<?> makeColumnValueSelector(String columnName)
+    {
+      return objectColumnSelector;
+    }
+
+    @Override
+    public ColumnCapabilities getColumnCapabilities(String columnName)
+    {
+      return ColumnCapabilitiesImpl.createDefault().setHasMultipleValues(true);
+    }
+
+    public void moveSelectorCursorToNext()
+    {
+      objectColumnSelector.increment();
+    }
+  }
 }
|
@@ -19,15 +19,22 @@

 package org.apache.druid.query.aggregation.any;

+import com.google.common.collect.Lists;
 import org.apache.druid.query.aggregation.BufferAggregator;
 import org.apache.druid.query.aggregation.TestObjectColumnSelector;
 import org.junit.Assert;
 import org.junit.Test;

 import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;

 public class StringAnyBufferAggregatorTest
 {
+  StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
+      "billy", "billy", 1024, true
+  );
+
   private void aggregateBuffer(
       TestObjectColumnSelector valueSelector,
       BufferAggregator agg,
@@ -44,17 +51,14 @@ public class StringAnyBufferAggregatorTest
   {

     final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"};
-    Integer maxStringBytes = 1024;
+    int maxStringBytes = 1024;

     TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);

-    StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
-        "billy", "billy", maxStringBytes
-    );
-
     StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
         objectColumnSelector,
-        maxStringBytes
+        maxStringBytes,
+        true
     );

     ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -75,17 +79,15 @@ public class StringAnyBufferAggregatorTest
   public void testBufferAggregateWithFoldCheck()
   {
     final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"};
-    Integer maxStringBytes = 1024;
+    int maxStringBytes = 1024;

     TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);

-    StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
-        "billy", "billy", maxStringBytes
-    );

     StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
         objectColumnSelector,
-        maxStringBytes
+        maxStringBytes,
+        true
     );

     ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -108,17 +110,14 @@ public class StringAnyBufferAggregatorTest
   {

     final String[] strings = {"CCCC", "AAAA", "BBBB", null, "EEEE"};
-    Integer maxStringBytes = 1024;
+    int maxStringBytes = 1024;

     TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);

-    StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
-        "billy", "billy", maxStringBytes
-    );
-
     StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
         objectColumnSelector,
-        maxStringBytes
+        maxStringBytes,
+        true
     );

     ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -140,17 +139,13 @@ public class StringAnyBufferAggregatorTest
   {

     final String[] strings = {null, "CCCC", "AAAA", "BBBB", "EEEE"};
-    Integer maxStringBytes = 1024;
+    int maxStringBytes = 1024;

     TestObjectColumnSelector<String> objectColumnSelector = new TestObjectColumnSelector<>(strings);

-    StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
-        "billy", "billy", maxStringBytes
-    );
-
     StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
         objectColumnSelector,
-        maxStringBytes
+        maxStringBytes, true
     );

     ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@ -170,19 +165,15 @@ public class StringAnyBufferAggregatorTest
|
|||
@Test
|
||||
public void testNonStringValue()
|
||||
{
|
||||
|
||||
final Double[] doubles = {1.00, 2.00};
|
||||
Integer maxStringBytes = 1024;
|
||||
int maxStringBytes = 1024;
|
||||
|
||||
TestObjectColumnSelector<Double> objectColumnSelector = new TestObjectColumnSelector<>(doubles);
|
||||
|
||||
StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
|
||||
"billy", "billy", maxStringBytes
|
||||
);
|
||||
|
||||
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
|
||||
objectColumnSelector,
|
||||
maxStringBytes
|
||||
maxStringBytes,
|
||||
true
|
||||
);
|
||||
|
||||
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
|
||||
|
@ -198,4 +189,77 @@ public class StringAnyBufferAggregatorTest
|
|||
|
||||
Assert.assertEquals("1.0", result);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMvds()
|
||||
{
|
||||
List<String> mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
|
||||
final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
|
||||
int maxStringBytes = 1024;
|
||||
|
||||
TestObjectColumnSelector<Object> objectColumnSelector = new TestObjectColumnSelector<>(mvds);
|
||||
|
||||
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
|
||||
objectColumnSelector,
|
||||
maxStringBytes, true
|
||||
);
|
||||
|
||||
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize() * 2);
|
||||
int position = 0;
|
||||
|
||||
int[] positions = new int[]{0, 1, 43, 100, 189};
|
||||
Arrays.stream(positions).forEach(i -> agg.init(buf, i));
|
||||
|
||||
//noinspection ForLoopReplaceableByForEach
|
||||
for (int i = 0; i < mvds.length; i++) {
|
||||
aggregateBuffer(objectColumnSelector, agg, buf, positions[i]);
|
||||
}
|
||||
String result = ((String) agg.get(buf, position));
|
||||
Assert.assertNull(result);
|
||||
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
if (i == 2) {
|
||||
Assert.assertEquals(mvd.toString(), agg.get(buf, positions[2]));
|
||||
} else {
|
||||
Assert.assertEquals(mvds[i], agg.get(buf, positions[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMvdsWithCustomAggregate()
|
||||
{
|
||||
List<String> mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
|
||||
final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
|
||||
final int maxStringBytes = 1024;
|
||||
|
||||
TestObjectColumnSelector<Object> objectColumnSelector = new TestObjectColumnSelector<>(mvds);
|
||||
|
||||
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
|
||||
objectColumnSelector,
|
||||
maxStringBytes, false
|
||||
);
|
||||
|
||||
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize() * 2);
|
||||
int position = 0;
|
||||
|
||||
int[] positions = new int[]{0, 1, 43, 100, 189};
|
||||
Arrays.stream(positions).forEach(i -> agg.init(buf, i));
|
||||
|
||||
//noinspection ForLoopReplaceableByForEach
|
||||
for (int i = 0; i < mvds.length; i++) {
|
||||
aggregateBuffer(objectColumnSelector, agg, buf, positions[i]);
|
||||
}
|
||||
String result = ((String) agg.get(buf, position));
|
||||
Assert.assertNull(result);
|
||||
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
if (i == 2) {
|
||||
// takes first in case of mvds
|
||||
Assert.assertEquals(mvd.get(0), agg.get(buf, positions[2]));
|
||||
} else {
|
||||
Assert.assertEquals(mvds[i], agg.get(buf, positions[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
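The two new buffer-aggregator tests above pin down what the added boolean constructor argument does to a multi-value row: with `true` the stored value is the list's string form, and with `false` it is the first element. Here is a minimal sketch of that selection rule, assuming nothing beyond what the tests assert. `AnyValueSketch` and `pick` are hypothetical names, not Druid classes. Truncation to `maxStringBytes` is not modeled, and returning `null` for an empty list is an assumption borrowed from the vector aggregator test.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not part of Druid. It mirrors only the behavior the tests assert.
final class AnyValueSketch
{
  private AnyValueSketch() {}

  // Returns the string an "any" aggregation would keep for a single row value.
  static String pick(Object value, boolean aggregateMultipleValues)
  {
    if (value == null) {
      return null;
    }
    if (value instanceof List) {
      List<?> values = (List<?>) value;
      if (values.isEmpty()) {
        return null; // assumption: an empty multi-value row yields null
      }
      if (aggregateMultipleValues) {
        return values.toString(); // whole row, e.g. "[AAAA, AAAAB, AAAC]"
      }
      Object first = values.get(0); // only the first element of the row
      return first == null ? null : first.toString();
    }
    return String.valueOf(value); // non-string scalars are stringified, e.g. 1.0 -> "1.0"
  }

  public static void main(String[] args)
  {
    List<String> mvd = Arrays.asList("AAAA", "AAAAB", "AAAC");
    System.out.println(pick(mvd, true));  // prints [AAAA, AAAAB, AAAC]
    System.out.println(pick(mvd, false)); // prints AAAA
  }
}
```

The single-value and `null` cases are unaffected by the flag, which matches the `else` branches of both tests.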
|
||||
|
|
|
@ -33,6 +33,8 @@ import org.mockito.Mockito;
|
|||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
import static org.apache.druid.query.aggregation.any.StringAnyVectorAggregator.NOT_FOUND_FLAG_VALUE;
|
||||
|
@ -61,6 +63,7 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
|
|||
|
||||
private StringAnyVectorAggregator singleValueTarget;
|
||||
private StringAnyVectorAggregator multiValueTarget;
|
||||
private StringAnyVectorAggregator customMultiValueTarget;
|
||||
|
||||
@Before
|
||||
public void setUp()
|
||||
|
@ -74,20 +77,22 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
|
|||
return index >= DICTIONARY.length ? null : DICTIONARY[index];
|
||||
}).when(singleValueSelector).lookupName(anyInt());
|
||||
initializeRandomBuffer();
|
||||
singleValueTarget = new StringAnyVectorAggregator(singleValueSelector, null, MAX_STRING_BYTES);
|
||||
multiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES);
|
||||
singleValueTarget = new StringAnyVectorAggregator(singleValueSelector, null, MAX_STRING_BYTES, true);
|
||||
multiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES, true);
|
||||
// customMultiValueTarget aggregates to only single value in case of MVDs
|
||||
customMultiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES, false);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void initWithBothSingleAndMultiValueSelectorShouldThrowException()
|
||||
{
|
||||
new StringAnyVectorAggregator(singleValueSelector, multiValueSelector, MAX_STRING_BYTES);
|
||||
new StringAnyVectorAggregator(singleValueSelector, multiValueSelector, MAX_STRING_BYTES, true);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void initWithNeitherSingleNorMultiValueSelectorShouldThrowException()
|
||||
{
|
||||
new StringAnyVectorAggregator(null, null, MAX_STRING_BYTES);
|
||||
new StringAnyVectorAggregator(null, null, MAX_STRING_BYTES, true);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -122,7 +127,7 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
|
|||
public void aggregateMultiValuePositionNotFoundShouldPutFirstValue()
|
||||
{
|
||||
multiValueTarget.aggregate(buf, POSITION, 0, 2);
|
||||
Assert.assertEquals(DICTIONARY[1], multiValueTarget.get(buf, POSITION));
|
||||
Assert.assertEquals("[One, Zero]", multiValueTarget.get(buf, POSITION));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -155,9 +160,9 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
|
|||
@Test
|
||||
public void aggregateBatchWithRowsShouldAggregateAllRows()
|
||||
{
|
||||
int[] positions = new int[] {0, 43, 100};
|
||||
int[] positions = new int[]{0, 43, 100};
|
||||
int positionOffset = 2;
|
||||
int[] rows = new int[] {2, 1, 0};
|
||||
int[] rows = new int[]{2, 1, 0};
|
||||
clearBufferForPositions(positionOffset, positions);
|
||||
multiValueTarget.aggregate(buf, 3, positions, rows, positionOffset);
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
|
@ -166,8 +171,32 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
|
|||
IndexedInts rowIndex = MULTI_VALUE_ROWS[row];
|
||||
if (rowIndex.size() == 0) {
|
||||
Assert.assertNull(multiValueTarget.get(buf, position));
|
||||
} else {
|
||||
} else if (rowIndex.size() == 1) {
|
||||
Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)), multiValueTarget.get(buf, position));
|
||||
} else {
|
||||
List<String> res = new ArrayList<>();
|
||||
rowIndex.forEach(index -> res.add(multiValueSelector.lookupName(index)));
|
||||
Assert.assertEquals(res.toString(), multiValueTarget.get(buf, position));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void aggregateBatchWithRowsShouldAggregateAllRowsWithAggregateMVDFalse()
|
||||
{
|
||||
int[] positions = new int[]{0, 43, 100};
|
||||
int positionOffset = 2;
|
||||
int[] rows = new int[]{2, 1, 0};
|
||||
clearBufferForPositions(positionOffset, positions);
|
||||
customMultiValueTarget.aggregate(buf, 3, positions, rows, positionOffset);
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
int position = positions[i] + positionOffset;
|
||||
int row = rows[i];
|
||||
IndexedInts rowIndex = MULTI_VALUE_ROWS[row];
|
||||
if (rowIndex.size() == 0) {
|
||||
Assert.assertNull(customMultiValueTarget.get(buf, position));
|
||||
} else {
|
||||
Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)), customMultiValueTarget.get(buf, position));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -489,7 +489,7 @@ public class FallbackVirtualColumnTest
|
|||
}
|
||||
}
|
||||
|
||||
private static class SameMultiVectorSelector implements MultiValueDimensionVectorSelector
|
||||
public static class SameMultiVectorSelector implements MultiValueDimensionVectorSelector
|
||||
{
|
||||
@Override
|
||||
public int getValueCardinality()
|
||||
|
|
|
@ -95,7 +95,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
String fieldName,
|
||||
String timeColumn,
|
||||
ColumnType type,
|
||||
Integer maxStringBytes
|
||||
Integer maxStringBytes,
|
||||
Boolean aggregateMultipleValues
|
||||
)
|
||||
{
|
||||
switch (type.getType()) {
|
||||
|
@ -121,7 +122,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
String fieldName,
|
||||
String timeColumn,
|
||||
ColumnType type,
|
||||
Integer maxStringBytes
|
||||
Integer maxStringBytes,
|
||||
Boolean aggregateMultipleValues
|
||||
)
|
||||
{
|
||||
switch (type.getType()) {
|
||||
|
@ -147,7 +149,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
String fieldName,
|
||||
String timeColumn,
|
||||
ColumnType type,
|
||||
Integer maxStringBytes
|
||||
Integer maxStringBytes,
|
||||
Boolean aggregateMultipleValues
|
||||
)
|
||||
{
|
||||
switch (type.getType()) {
|
||||
|
@ -158,7 +161,7 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
case DOUBLE:
|
||||
return new DoubleAnyAggregatorFactory(name, fieldName);
|
||||
case STRING:
|
||||
return new StringAnyAggregatorFactory(name, fieldName, maxStringBytes);
|
||||
return new StringAnyAggregatorFactory(name, fieldName, maxStringBytes, aggregateMultipleValues);
|
||||
default:
|
||||
throw SimpleSqlAggregator.badTypeException(fieldName, "ANY", type);
|
||||
}
|
||||
|
@ -170,7 +173,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
String fieldName,
|
||||
String timeColumn,
|
||||
ColumnType outputType,
|
||||
Integer maxStringBytes
|
||||
Integer maxStringBytes,
|
||||
Boolean aggregateMultipleValues
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -244,37 +248,38 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
final AggregatorFactory theAggFactory;
|
||||
switch (args.size()) {
|
||||
case 1:
|
||||
theAggFactory = aggregatorType.createAggregatorFactory(aggregatorName, fieldName, null, outputType, null);
|
||||
theAggFactory = aggregatorType.createAggregatorFactory(aggregatorName, fieldName, null, outputType, null, true);
|
||||
break;
|
||||
case 2:
|
||||
int maxStringBytes;
|
||||
try {
|
||||
maxStringBytes = RexLiteral.intValue(rexNodes.get(1));
|
||||
}
|
||||
catch (AssertionError ae) {
|
||||
plannerContext.setPlanningError(
|
||||
"The second argument '%s' to function '%s' is not a number",
|
||||
rexNodes.get(1),
|
||||
aggregateCall.getName()
|
||||
);
|
||||
return null;
|
||||
}
|
||||
Integer maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); // added not null check at the function
|
||||
theAggFactory = aggregatorType.createAggregatorFactory(
|
||||
aggregatorName,
|
||||
fieldName,
|
||||
null,
|
||||
outputType,
|
||||
maxStringBytes
|
||||
maxStringBytes.intValue(),
|
||||
true
|
||||
);
|
||||
break;
|
||||
case 3:
|
||||
maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); // added not null check at the function for rexNode 1,2
|
||||
boolean aggregateMultipleValues = RexLiteral.booleanValue(rexNodes.get(2));
|
||||
theAggFactory = aggregatorType.createAggregatorFactory(
|
||||
aggregatorName,
|
||||
fieldName,
|
||||
null,
|
||||
outputType,
|
||||
maxStringBytes,
|
||||
aggregateMultipleValues
|
||||
);
|
||||
break;
|
||||
default:
|
||||
throw InvalidSqlInput.exception(
|
||||
"Function [%s] expects 1 or 2 arguments but found [%s]",
|
||||
"Function [%s] expects 1 or 2 or 3 arguments but found [%s]",
|
||||
aggregateCall.getName(),
|
||||
args.size()
|
||||
);
|
||||
}
|
||||
|
||||
return Aggregation.create(
|
||||
Collections.singletonList(theAggFactory),
|
||||
finalizeAggregations ? new FinalizingFieldAccessPostAggregator(name, aggregatorName) : null
|
||||
|
@ -372,10 +377,11 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
InferTypes.RETURN_TYPE,
|
||||
DefaultOperandTypeChecker
|
||||
.builder()
|
||||
.operandNames("expr", "maxBytesPerString")
|
||||
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
|
||||
.operandNames("expr", "maxBytesPerStringInt", "aggregateMultipleValuesBoolean")
|
||||
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.BOOLEAN)
|
||||
.requiredOperandCount(1)
|
||||
.literalOperands(1)
|
||||
.literalOperands(1, 2)
|
||||
.notNullOperands(1, 2)
|
||||
.build(),
|
||||
SqlFunctionCategory.USER_DEFINED_FUNCTION,
|
||||
false,
|
||||
|
@ -402,9 +408,9 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
|
||||
SqlParserPos pos = call.getParserPosition();
|
||||
|
||||
if (operands.isEmpty() || operands.size() > 2) {
|
||||
if (operands.isEmpty() || operands.size() > 3) {
|
||||
throw InvalidSqlInput.exception(
|
||||
"Function [%s] expects 1 or 2 arguments but found [%s]",
|
||||
"Function [%s] expects 1 or 2 or 3 arguments but found [%s]",
|
||||
getName(),
|
||||
operands.size()
|
||||
);
|
||||
|
@ -417,6 +423,9 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
|
|||
if (operands.size() == 2) {
|
||||
newOperands.add(operands.get(1));
|
||||
}
|
||||
if (operands.size() == 3) {
|
||||
newOperands.add(operands.get(2));
|
||||
}
|
||||
|
||||
return replacementAggFunc.createCall(pos, newOperands);
|
||||
}
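The argument-count handling in the `EarliestLatestAnySqlAggregator` hunks above amounts to a small defaulting rule: one argument leaves `maxStringBytes` unset and turns multi-value aggregation on, two arguments set the byte limit only, and three set both. The sketch below illustrates that rule in isolation. `AnyValueArgsSketch` and `fromArgs` are hypothetical names, and `IllegalArgumentException` stands in for the `InvalidSqlInput.exception(...)` call used in the actual code.

```java
import java.util.Objects;

// Hypothetical illustration of the defaulting rule, not the Druid implementation.
final class AnyValueArgsSketch
{
  final Integer maxStringBytes;          // null means "let the factory choose its default"
  final boolean aggregateMultipleValues;

  private AnyValueArgsSketch(Integer maxStringBytes, boolean aggregateMultipleValues)
  {
    this.maxStringBytes = maxStringBytes;
    this.aggregateMultipleValues = aggregateMultipleValues;
  }

  static AnyValueArgsSketch fromArgs(String functionName, int argCount, Integer maxBytes, Boolean aggregate)
  {
    switch (argCount) {
      case 1:
        return new AnyValueArgsSketch(null, true);
      case 2:
        return new AnyValueArgsSketch(Objects.requireNonNull(maxBytes), true);
      case 3:
        return new AnyValueArgsSketch(Objects.requireNonNull(maxBytes), Objects.requireNonNull(aggregate));
      default:
        // Stand-in for InvalidSqlInput.exception(...) in the diff.
        throw new IllegalArgumentException(
            String.format("Function [%s] expects 1 or 2 or 3 arguments but found [%s]", functionName, argCount)
        );
    }
  }

  public static void main(String[] args)
  {
    AnyValueArgsSketch parsed = fromArgs("ANY_VALUE", 3, 100, false);
    System.out.println(parsed.maxStringBytes + " " + parsed.aggregateMultipleValues); // prints 100 false
  }
}
```

In the diff, the null checks are not done at this point; they are delegated to the operand type checker through `notNullOperands(1, 2)`, which is why the `try`/`catch` around `RexLiteral.intValue` could be dropped.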
|
||||
|
|
|
@ -119,7 +119,8 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator
|
|||
rexNodes.get(1)
|
||||
),
|
||||
outputType,
|
||||
null
|
||||
null,
|
||||
true
|
||||
);
|
||||
break;
|
||||
case 3:
|
||||
|
@ -145,7 +146,8 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator
|
|||
rexNodes.get(1)
|
||||
),
|
||||
outputType,
|
||||
maxStringBytes
|
||||
maxStringBytes,
|
||||
true
|
||||
);
|
||||
break;
|
||||
default:
|
||||
|
|
|
@ -188,6 +188,7 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
|
|||
@Nullable
|
||||
private Integer requiredOperandCount;
|
||||
private int[] literalOperands;
|
||||
private IntSet notNullOperands = new IntArraySet();
|
||||
|
||||
private Builder()
|
||||
{
|
||||
|
@ -229,6 +230,12 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder notNullOperands(final int... notNullOperands)
|
||||
{
|
||||
Arrays.stream(notNullOperands).forEach(this.notNullOperands::add);
|
||||
return this;
|
||||
}
|
||||
|
||||
public DefaultOperandTypeChecker build()
|
||||
{
|
||||
int computedRequiredOperandCount = requiredOperandCount == null ? operandTypes.size() : requiredOperandCount;
|
||||
|
@ -236,16 +243,18 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
|
|||
operandNames,
|
||||
operandTypes,
|
||||
computedRequiredOperandCount,
|
||||
DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size()),
|
||||
DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size(), notNullOperands),
|
||||
literalOperands
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount)
|
||||
public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount, IntSet notNullOperands)
|
||||
{
|
||||
final IntSet nullableOperands = new IntArraySet();
|
||||
IntStream.range(requiredOperandCount, totalOperandCount).forEach(nullableOperands::add);
|
||||
IntStream.range(requiredOperandCount, totalOperandCount)
|
||||
.filter(i -> !notNullOperands.contains(i))
|
||||
.forEach(nullableOperands::add);
|
||||
return nullableOperands;
|
||||
}
|
||||
}
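The change to `buildNullableOperands` above can be read as: every optional operand (index at or beyond `requiredOperandCount`) is nullable unless it was listed via `notNullOperands`. The following sketch reproduces that filter with plain `java.util` sets in place of the fastutil `IntSet` used in the diff. The class name `NullableOperandsSketch` is hypothetical.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.IntStream;

// Hypothetical stand-alone version of the filter; uses java.util sets instead of fastutil.
final class NullableOperandsSketch
{
  private NullableOperandsSketch() {}

  static Set<Integer> buildNullableOperands(
      int requiredOperandCount,
      int totalOperandCount,
      Set<Integer> notNullOperands
  )
  {
    final Set<Integer> nullableOperands = new TreeSet<>();
    // Optional operands are the zero-based indexes in [requiredOperandCount, totalOperandCount).
    IntStream.range(requiredOperandCount, totalOperandCount)
             .filter(i -> !notNullOperands.contains(i))
             .forEach(i -> nullableOperands.add(i));
    return nullableOperands;
  }

  public static void main(String[] args)
  {
    // ANY_VALUE(expr, maxBytes, aggregateMultipleValues): 1 required, 3 total.
    Set<Integer> notNull = new HashSet<>(Arrays.asList(1, 2));
    System.out.println(buildNullableOperands(1, 3, notNull));                          // prints []
    System.out.println(buildNullableOperands(1, 3, Collections.<Integer>emptySet()));  // prints [1, 2]
  }
}
```

With operands 1 and 2 marked not-null, no optional operand accepts `NULL`, which is consistent with the `Illegal use of 'NULL'` errors asserted in `testStringAnyAggArgValidation`.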
|
||||
|
|
|
@ -593,7 +593,7 @@ public class OperatorConversions
|
|||
{
|
||||
final IntSet nullableOperands = requiredOperandCount == null
|
||||
? new IntArraySet()
|
||||
: DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size());
|
||||
: DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size(), new IntArraySet());
|
||||
if (operandTypeInference == null) {
|
||||
SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands);
|
||||
return (callBinding, returnType, types) -> {
|
||||
|
|
|
@ -497,7 +497,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
|
|||
cannotVectorize();
|
||||
|
||||
testQuery(
|
||||
"SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100) FROM foo WHERE (TIME_FLOOR(__time, 'PT1H'), m1) IN\n"
|
||||
"SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100, true) FROM foo WHERE (TIME_FLOOR(__time, 'PT1H'), m1) IN\n"
|
||||
+ " (\n"
|
||||
+ " SELECT TIME_FLOOR(__time, 'PT1H') AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n"
|
||||
+ " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n"
|
||||
|
@ -532,7 +532,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
|
|||
)
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setAggregatorSpecs(aggregators(
|
||||
new StringAnyAggregatorFactory("a0", "dim3", 100)
|
||||
new StringAnyAggregatorFactory("a0", "dim3", 100, true)
|
||||
))
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
|
@ -598,7 +598,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
|
|||
)
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setAggregatorSpecs(aggregators(
|
||||
new StringAnyAggregatorFactory("a0", "dim3", 100)
|
||||
new StringAnyAggregatorFactory("a0", "dim3", 100, true)
|
||||
))
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
|
@ -609,6 +609,71 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
|
|||
)
|
||||
);
|
||||
}
|
||||
@Test
|
||||
public void testJoinOnGroupByInsteadOfTimeseriesWithFloorOnTimeWithNoAggregateMultipleValues()
|
||||
{
|
||||
// Cannot vectorize JOIN operator.
|
||||
cannotVectorize();
|
||||
|
||||
testQuery(
|
||||
"SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100, false) FROM foo WHERE (CAST(TIME_FLOOR(__time, 'PT1H') AS BIGINT) + 1, m1) IN\n"
|
||||
+ " (\n"
|
||||
+ " SELECT CAST(TIME_FLOOR(__time, 'PT1H') AS BIGINT) + 1 AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n"
|
||||
+ " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n"
|
||||
+ " )\n"
|
||||
+ "GROUP BY 1, 2\n",
|
||||
ImmutableList.of(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(
|
||||
join(
|
||||
new TableDataSource(CalciteTests.DATASOURCE1),
|
||||
new QueryDataSource(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE1)
|
||||
.setInterval(querySegmentSpec(Intervals.of(
|
||||
"1994-04-29/2020-01-11T00:00:00.001Z")))
|
||||
.setVirtualColumns(
|
||||
expressionVirtualColumn(
|
||||
"v0",
|
||||
"(timestamp_floor(\"__time\",'PT1H',null,'UTC') + 1)",
|
||||
ColumnType.LONG
|
||||
)
|
||||
)
|
||||
.setDimFilter(equality("dim3", "b", ColumnType.STRING))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setDimensions(dimensions(new DefaultDimensionSpec(
|
||||
"v0",
|
||||
"d0",
|
||||
ColumnType.LONG
|
||||
)))
|
||||
.setAggregatorSpecs(aggregators(
|
||||
new FloatMinAggregatorFactory("a0", "m1")
|
||||
))
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()),
|
||||
"j0.",
|
||||
"(((timestamp_floor(\"__time\",'PT1H',null,'UTC') + 1) == \"j0.d0\") && (\"m1\" == \"j0.a0\"))",
|
||||
JoinType.INNER
|
||||
)
|
||||
)
|
||||
.setInterval(querySegmentSpec(Filtration.eternity()))
|
||||
.setDimensions(
|
||||
new DefaultDimensionSpec("__time", "d0", ColumnType.LONG),
|
||||
new DefaultDimensionSpec("m1", "d1", ColumnType.FLOAT)
|
||||
)
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setAggregatorSpecs(aggregators(
|
||||
new StringAnyAggregatorFactory("a0", "dim3", 100, false)
|
||||
))
|
||||
.setContext(QUERY_CONTEXT_DEFAULT)
|
||||
.build()
|
||||
),
|
||||
ImmutableList.of(
|
||||
new Object[]{946684800000L, 1.0f, "a"}, // picks up first from [a, b]
|
||||
new Object[]{946771200000L, 2.0f, "b"} // picks up first from [b, c]
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Parameters(source = QueryContextForJoinProvider.class)
|
||||
|
@ -1480,7 +1545,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
|
|||
new SubstringDimExtractionFn(0, 1)
|
||||
)
|
||||
)
|
||||
.setAggregatorSpecs(new StringAnyAggregatorFactory("a0", "v", 10))
|
||||
.setAggregatorSpecs(new StringAnyAggregatorFactory("a0", "v", 10, true))
|
||||
.build()
|
||||
),
|
||||
"j0.",
|
||||
|
|
|
@ -844,10 +844,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new LongAnyAggregatorFactory("a0", "cnt"),
|
||||
new FloatAnyAggregatorFactory("a1", "m1"),
|
||||
new DoubleAnyAggregatorFactory("a2", "m2"),
|
||||
new StringAnyAggregatorFactory("a3", "dim1", 10),
|
||||
new StringAnyAggregatorFactory("a3", "dim1", 10, true),
|
||||
new LongAnyAggregatorFactory("a4", "v0"),
|
||||
new FloatAnyAggregatorFactory("a5", "v1"),
|
||||
new StringAnyAggregatorFactory("a6", "v2", 10)
|
||||
new StringAnyAggregatorFactory("a6", "v2", 10, true)
|
||||
)
|
||||
)
|
||||
.context(QUERY_CONTEXT_DEFAULT)
|
||||
|
@ -1420,7 +1420,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.setAggregatorSpecs(aggregators(new StringAnyAggregatorFactory(
|
||||
"a0:a",
|
||||
"dim1",
|
||||
10
|
||||
10, true
|
||||
)))
|
||||
.setPostAggregatorSpecs(
|
||||
ImmutableList.of(
|
||||
|
@ -1565,7 +1565,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
aggregators(
|
||||
new StringAnyAggregatorFactory("a0", "dim1", 32),
|
||||
new StringAnyAggregatorFactory("a0", "dim1", 32, true),
|
||||
new LongAnyAggregatorFactory("a1", "l2"),
|
||||
new DoubleAnyAggregatorFactory("a2", "d2"),
|
||||
new FloatAnyAggregatorFactory("a3", "f2")
|
||||
|
@ -1607,7 +1607,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.filters(filter)
|
||||
.aggregators(
|
||||
aggregators(
|
||||
new StringAnyAggregatorFactory("a0", "dim1", 32),
|
||||
new StringAnyAggregatorFactory("a0", "dim1", 32, true),
|
||||
new LongAnyAggregatorFactory("a1", "l2"),
|
||||
new DoubleAnyAggregatorFactory("a2", "d2"),
|
||||
new FloatAnyAggregatorFactory("a3", "f2")
|
||||
|
@ -9422,7 +9422,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.granularity(Granularities.ALL)
|
||||
.aggregators(
|
||||
aggregators(
|
||||
new StringAnyAggregatorFactory("a0", "dim1", 1024),
|
||||
new StringAnyAggregatorFactory("a0", "dim1", 1024, true),
|
||||
new LongAnyAggregatorFactory("a1", "l1"),
|
||||
new StringFirstAggregatorFactory("a2", "dim1", null, 1024),
|
||||
new LongFirstAggregatorFactory("a3", "l1", null),
|
||||
|
@ -9741,7 +9741,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
.setAggregatorSpecs(
|
||||
aggregators(
|
||||
new FilteredAggregatorFactory(
|
||||
new StringAnyAggregatorFactory("a0", "dim1", 1024),
|
||||
new StringAnyAggregatorFactory("a0", "dim1", 1024, true),
|
||||
equality("dim1", "nonexistent", ColumnType.STRING)
|
||||
),
|
||||
new FilteredAggregatorFactory(
|
||||
|
@ -13533,6 +13533,24 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStringAnyAggArgValidation()
|
||||
{
|
||||
DruidException e = assertThrows(DruidException.class, () -> testBuilder()
|
||||
.sql("SELECT ANY_VALUE(dim3, 1000, 'true') FROM foo")
|
||||
.queryContext(ImmutableMap.of())
|
||||
.run());
|
||||
assertThat(e, invalidSqlIs(
|
||||
"Cannot apply 'ANY_VALUE' to arguments of type 'ANY_VALUE(<VARCHAR>, <INTEGER>, <CHAR(4)>)'. Supported form(s): 'ANY_VALUE(<expr>, [<maxBytesPerStringInt>, [<aggregateMultipleValuesBoolean>]])' (line [1], column [8])"));
|
||||
DruidException e1 = assertThrows(DruidException.class, () -> testBuilder()
|
||||
.sql("SELECT ANY_VALUE(dim3, 1000, null) FROM foo")
|
||||
.queryContext(ImmutableMap.of()).run());
|
||||
Assert.assertEquals("Illegal use of 'NULL' (line [1], column [30])", e1.getMessage());
|
||||
DruidException e2 = assertThrows(DruidException.class, () -> testBuilder()
|
||||
.sql("SELECT ANY_VALUE(dim3, null, true) FROM foo")
|
||||
.queryContext(ImmutableMap.of()).run());
|
||||
Assert.assertEquals("Illegal use of 'NULL' (line [1], column [24])", e2.getMessage());
|
||||
}
|
||||
@Test
|
||||
public void testStringAggMaxBytes()
|
||||
{
|
||||
|
@ -14367,7 +14385,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
|
|||
new StringLastAggregatorFactory("a1", "dim1", "__time", 1024),
|
||||
new StringFirstAggregatorFactory("a2", "dim3", "__time", 1024),
|
||||
new StringFirstAggregatorFactory("a3", "dim1", "__time", 1024),
|
||||
new StringAnyAggregatorFactory("a4", "dim3", 1024)))
|
||||
new StringAnyAggregatorFactory("a4", "dim3", 1024, true)))
|
||||
.build()
|
||||
|
||||
),
|
||||
|
|
|
@ -323,4 +323,10 @@ public class QueryTestBuilder
|
|||
return build().resultsOnly();
|
||||
}
|
||||
|
||||
public boolean isDecoupledMode()
|
||||
{
|
||||
String mode = (String) queryContext.getOrDefault(PlannerConfig.CTX_NATIVE_QUERY_SQL_PLANNING_MODE, "");
|
||||
return PlannerConfig.NATIVE_QUERY_SQL_PLANNING_MODE_DECOUPLED.equalsIgnoreCase(mode);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -2334,3 +2334,4 @@ LAST_VALUE
|
|||
markUnused
|
||||
markUsed
|
||||
segmentId
|
||||
aggregateMultipleValues
|
||||
|
|