diff --git a/docs/querying/sql-aggregations.md b/docs/querying/sql-aggregations.md
index d005ce1fd11..b2df640a68f 100644
--- a/docs/querying/sql-aggregations.md
+++ b/docs/querying/sql-aggregations.md
@@ -90,7 +90,7 @@ In the aggregation functions supported by Druid, only `COUNT`, `ARRAY_AGG`, and
|`EARLIEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the earliest value of `expr`. The earliest value of `expr` is taken from the row with the overall earliest non-null value of `timestampExpr`. If the earliest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows. If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory. If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
|`LATEST(expr, [maxBytesPerValue])`|Returns the latest value of `expr` The `expr` must come from a relation with a timestamp column (like `__time` in a Druid datasource) and the "latest" is taken from the row with the overall latest non-null value of the timestamp column. If the latest non-null value of the timestamp column appears in multiple rows, the `expr` may be taken from any of those rows. If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory. If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
|`LATEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the latest value of `expr`. The latest value of `expr` is taken from the row with the overall latest non-null value of `timestampExpr`. If the overall latest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows. If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory. If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
-|`ANY_VALUE(expr, [maxBytesPerValue])`|Returns any value of `expr` including null. This aggregator can simplify and optimize the performance by returning the first encountered value (including `null`). If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory. If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
+|`ANY_VALUE(expr, [maxBytesPerValue, [aggregateMultipleValues]])`|Returns any value of `expr` including null. This aggregator can simplify and optimize the performance by returning the first encountered value (including `null`). If `expr` is a string or complex type, `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory. If `maxBytesPerValue` is omitted, it defaults to `1024`. `aggregateMultipleValues` is an optional boolean flag that controls how a [multi-value dimension](./multi-value-dimensions.md) is aggregated. `aggregateMultipleValues` defaults to `true`, in which case the function returns the stringified array for a multi-value dimension. When it is set to `false`, the function returns the first value instead. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)|
|`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.|N/A|
|`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.|`null`|
|`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results will be based on the default for the element type.|`null`|
diff --git a/docs/querying/sql-functions.md b/docs/querying/sql-functions.md
index 8e43076518d..47b8ca90434 100644
--- a/docs/querying/sql-functions.md
+++ b/docs/querying/sql-functions.md
@@ -50,7 +50,7 @@ Calculates the arc cosine of a numeric expression.
## ANY_VALUE
-`ANY_VALUE(expr, [maxBytesPerValue])`
+`ANY_VALUE(expr, [maxBytesPerValue, [aggregateMultipleValues]])`
**Function type:** [Aggregation](sql-aggregations.md)
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java
index 352b65e5646..aae267364c4 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java
@@ -24,19 +24,23 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.segment.BaseObjectColumnValueSelector;
import org.apache.druid.segment.DimensionHandlerUtils;
+import java.util.List;
+
public class StringAnyAggregator implements Aggregator
{
private final BaseObjectColumnValueSelector valueSelector;
private final int maxStringBytes;
private boolean isFound;
private String foundValue;
+ private final boolean aggregateMultipleValues;
- public StringAnyAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes)
+ public StringAnyAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes, boolean aggregateMultipleValues)
{
this.valueSelector = valueSelector;
this.maxStringBytes = maxStringBytes;
this.foundValue = null;
this.isFound = false;
+ this.aggregateMultipleValues = aggregateMultipleValues;
}
@Override
@@ -44,18 +48,36 @@ public class StringAnyAggregator implements Aggregator
{
if (!isFound) {
final Object object = valueSelector.getObject();
- foundValue = DimensionHandlerUtils.convertObjectToString(object);
- if (foundValue != null && foundValue.length() > maxStringBytes) {
- foundValue = foundValue.substring(0, maxStringBytes);
- }
+ foundValue = StringUtils.fastLooseChop(readValue(object), maxStringBytes);
isFound = true;
}
}
+  /**
+   * Converts the object produced by the value selector into the string this aggregator keeps.
+   *
+   * <p>Non-list objects are converted directly. A {@link List} is treated as a multi-value row:
+   * an empty list yields {@code null}, a single-element list yields that element, and a longer
+   * list yields either the stringified list (when {@code aggregateMultipleValues} is set) or only
+   * its first element.
+   *
+   * @param object current selector value, may be {@code null}
+   * @return {@code null} for a null input or an empty list, otherwise the result of
+   *         {@link DimensionHandlerUtils#convertObjectToString}
+   */
+  private String readValue(final Object object)
+  {
+    if (object == null) {
+      return null;
+    }
+    if (!(object instanceof List)) {
+      return DimensionHandlerUtils.convertObjectToString(object);
+    }
+    final List<?> values = (List<?>) object;
+    if (values.isEmpty()) {
+      return null;
+    }
+    // A one-element list is unwrapped so it reads exactly like a single-valued row.
+    final boolean stringifyWholeList = aggregateMultipleValues && values.size() > 1;
+    return DimensionHandlerUtils.convertObjectToString(stringifyWholeList ? values : values.get(0));
+  }
+
@Override
public Object get()
{
- return StringUtils.chop(foundValue, maxStringBytes);
+ return foundValue;
}
@Override
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java
index 307de0650c3..67682bfde9d 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java
@@ -48,13 +48,15 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
private final String fieldName;
private final String name;
- protected final int maxStringBytes;
+ private final int maxStringBytes;
+ private final boolean aggregateMultipleValues;
@JsonCreator
public StringAnyAggregatorFactory(
@JsonProperty("name") String name,
@JsonProperty("fieldName") final String fieldName,
- @JsonProperty("maxStringBytes") Integer maxStringBytes
+ @JsonProperty("maxStringBytes") Integer maxStringBytes,
+ @JsonProperty("aggregateMultipleValues") @Nullable final Boolean aggregateMultipleValues
)
{
Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name");
@@ -67,18 +69,19 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
this.maxStringBytes = maxStringBytes == null
? StringFirstAggregatorFactory.DEFAULT_MAX_STRING_SIZE
: maxStringBytes;
+ this.aggregateMultipleValues = aggregateMultipleValues == null ? true : aggregateMultipleValues;
}
@Override
public Aggregator factorize(ColumnSelectorFactory metricFactory)
{
- return new StringAnyAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes);
+ return new StringAnyAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes, aggregateMultipleValues);
}
@Override
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
{
- return new StringAnyBufferAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes);
+ return new StringAnyBufferAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes, aggregateMultipleValues);
}
@Override
@@ -90,13 +93,15 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
return new StringAnyVectorAggregator(
null,
selectorFactory.makeMultiValueDimensionSelector(DefaultDimensionSpec.of(fieldName)),
- maxStringBytes
+ maxStringBytes,
+ aggregateMultipleValues
);
} else {
return new StringAnyVectorAggregator(
selectorFactory.makeSingleValueDimensionSelector(DefaultDimensionSpec.of(fieldName)),
null,
- maxStringBytes
+ maxStringBytes,
+ aggregateMultipleValues
);
}
}
@@ -122,7 +127,7 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
@Override
public AggregatorFactory getCombiningFactory()
{
- return new StringAnyAggregatorFactory(name, name, maxStringBytes);
+ return new StringAnyAggregatorFactory(name, name, maxStringBytes, aggregateMultipleValues);
}
@Override
@@ -155,6 +160,11 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
{
return maxStringBytes;
}
+ /**
+  * Returns the {@code aggregateMultipleValues} setting of this factory. The constructor defaults
+  * it to {@code true} when the JSON property is absent, and the value is handed to every
+  * aggregator this factory creates.
+  */
+ @JsonProperty
+ public boolean getAggregateMultipleValues()
+ {
+ return aggregateMultipleValues;
+ }
@Override
public List requiredFields()
@@ -192,7 +202,7 @@ public class StringAnyAggregatorFactory extends AggregatorFactory
@Override
public AggregatorFactory withName(String newName)
{
- return new StringAnyAggregatorFactory(newName, getFieldName(), getMaxStringBytes());
+ return new StringAnyAggregatorFactory(newName, getFieldName(), getMaxStringBytes(), getAggregateMultipleValues());
}
@Override
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java
index 32bb3153fa2..86b8c51a469 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java
@@ -25,6 +25,7 @@ import org.apache.druid.segment.BaseObjectColumnValueSelector;
import org.apache.druid.segment.DimensionHandlerUtils;
import java.nio.ByteBuffer;
+import java.util.List;
public class StringAnyBufferAggregator implements BufferAggregator
{
@@ -34,11 +35,13 @@ public class StringAnyBufferAggregator implements BufferAggregator
private final BaseObjectColumnValueSelector valueSelector;
private final int maxStringBytes;
+ private final boolean aggregateMultipleValues;
- public StringAnyBufferAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes)
+ public StringAnyBufferAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes, boolean aggregateMultipleValues)
{
this.valueSelector = valueSelector;
this.maxStringBytes = maxStringBytes;
+ this.aggregateMultipleValues = aggregateMultipleValues;
}
@Override
@@ -51,8 +54,7 @@ public class StringAnyBufferAggregator implements BufferAggregator
public void aggregate(ByteBuffer buf, int position)
{
if (buf.getInt(position) == NOT_FOUND_FLAG_VALUE) {
- final Object object = valueSelector.getObject();
- String foundValue = DimensionHandlerUtils.convertObjectToString(object);
+ String foundValue = readValue(valueSelector.getObject());
if (foundValue != null) {
ByteBuffer mutationBuffer = buf.duplicate();
mutationBuffer.position(position + FOUND_VALUE_OFFSET);
@@ -65,6 +67,27 @@ public class StringAnyBufferAggregator implements BufferAggregator
}
}
+  /**
+   * Converts the object produced by the value selector into the string written to the buffer.
+   *
+   * <p>Non-list objects are converted directly. A {@link List} is treated as a multi-value row:
+   * an empty list yields {@code null}, a single-element list yields that element, and a longer
+   * list yields either the stringified list (when {@code aggregateMultipleValues} is set) or only
+   * its first element.
+   *
+   * @param object current selector value, may be {@code null}
+   * @return {@code null} for a null input or an empty list, otherwise the result of
+   *         {@link DimensionHandlerUtils#convertObjectToString}
+   */
+  private String readValue(Object object)
+  {
+    if (object == null) {
+      return null;
+    }
+    if (!(object instanceof List)) {
+      return DimensionHandlerUtils.convertObjectToString(object);
+    }
+    final List<?> values = (List<?>) object;
+    if (values.isEmpty()) {
+      return null;
+    }
+    // A one-element list is unwrapped so it reads exactly like a single-valued row.
+    final boolean stringifyWholeList = aggregateMultipleValues && values.size() > 1;
+    return DimensionHandlerUtils.convertObjectToString(stringifyWholeList ? values : values.get(0));
+  }
+
@Override
public Object get(ByteBuffer buf, int position)
{
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java
index 620801bafa3..104ee7a77bf 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java
@@ -23,12 +23,15 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.VectorAggregator;
+import org.apache.druid.segment.DimensionHandlerUtils;
import org.apache.druid.segment.data.IndexedInts;
import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
public class StringAnyVectorAggregator implements VectorAggregator
{
@@ -43,11 +46,13 @@ public class StringAnyVectorAggregator implements VectorAggregator
@Nullable
private final MultiValueDimensionVectorSelector multiValueSelector;
private final int maxStringBytes;
+ private final boolean aggregateMultipleValues;
public StringAnyVectorAggregator(
SingleValueDimensionVectorSelector singleValueSelector,
MultiValueDimensionVectorSelector multiValueSelector,
- int maxStringBytes
+ int maxStringBytes,
+ final boolean aggregateMultipleValues
)
{
Preconditions.checkState(
@@ -61,6 +66,7 @@ public class StringAnyVectorAggregator implements VectorAggregator
this.multiValueSelector = multiValueSelector;
this.singleValueSelector = singleValueSelector;
this.maxStringBytes = maxStringBytes;
+ this.aggregateMultipleValues = aggregateMultipleValues;
}
@Override
@@ -78,7 +84,7 @@ public class StringAnyVectorAggregator implements VectorAggregator
if (startRow < rows.length) {
IndexedInts row = rows[startRow];
@Nullable
- String foundValue = row.size() == 0 ? null : multiValueSelector.lookupName(row.get(0));
+ String foundValue = readValue(row);
putValue(buf, position, foundValue);
}
} else if (singleValueSelector != null) {
@@ -93,6 +99,24 @@ public class StringAnyVectorAggregator implements VectorAggregator
}
}
+  /**
+   * Resolves one row of the multi-value selector to the string stored for it.
+   *
+   * <p>An empty row yields {@code null}. A row with one value, or any row when
+   * {@code aggregateMultipleValues} is off, yields the name of its first value. Otherwise all
+   * values are looked up in row order and the resulting list is stringified.
+   *
+   * @param row dictionary ids of the values in the row
+   * @return the string to store, or {@code null} for an empty row
+   */
+  @Nullable
+  private String readValue(IndexedInts row)
+  {
+    final int numValues = row.size();
+    if (numValues == 0) {
+      return null;
+    }
+    if (numValues == 1 || !aggregateMultipleValues) {
+      return multiValueSelector.lookupName(row.get(0));
+    }
+    final List<String> names = new ArrayList<>(numValues);
+    for (int i = 0; i < numValues; i++) {
+      names.add(multiValueSelector.lookupName(row.get(i)));
+    }
+    return DimensionHandlerUtils.convertObjectToString(names);
+  }
+
@Override
public void aggregate(
ByteBuffer buf,
@@ -142,4 +166,9 @@ public class StringAnyVectorAggregator implements VectorAggregator
buf.putInt(position, FOUND_AND_NULL_FLAG_VALUE);
}
}
+
+ /**
+  * Returns the {@code aggregateMultipleValues} flag this aggregator was constructed with.
+  */
+ public boolean isAggregateMultipleValues()
+ {
+ return aggregateMultipleValues;
+ }
}
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java
index 87d0e3dfdd8..1fea653c6f7 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java
@@ -148,7 +148,7 @@ public class AggregatorFactoryTest extends InitializedNullHandlingTest
// string aggregators
new StringFirstAggregatorFactory("stringFirst", "col", null, 1024),
new StringLastAggregatorFactory("stringLast", "col", null, 1024),
- new StringAnyAggregatorFactory("stringAny", "col", 1024),
+ new StringAnyAggregatorFactory("stringAny", "col", 1024, true),
// sketch aggs
new CardinalityAggregatorFactory("cardinality", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false),
new HyperUniquesAggregatorFactory("hyperUnique", "hyperunique"),
@@ -307,7 +307,8 @@ public class AggregatorFactoryTest extends InitializedNullHandlingTest
// string aggregators
new StringFirstAggregatorFactory("col", "col", null, 1024),
new StringLastAggregatorFactory("col", "col", null, 1024),
- new StringAnyAggregatorFactory("col", "col", 1024),
+ new StringAnyAggregatorFactory("col", "col", 1024, true),
+ new StringAnyAggregatorFactory("col", "col", 1024, false),
// sketch aggs
new CardinalityAggregatorFactory("col", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false),
new HyperUniquesAggregatorFactory("col", "hyperunique"),
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java
index 208cfeb052d..b728049d625 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java
@@ -49,7 +49,7 @@ public class StringAnyAggregationTest
@Before
public void setup()
{
- stringAnyAggFactory = new StringAnyAggregatorFactory("billy", "nilly", MAX_STRING_SIZE);
+ stringAnyAggFactory = new StringAnyAggregatorFactory("billy", "nilly", MAX_STRING_SIZE, true);
combiningAggFactory = stringAnyAggFactory.getCombiningFactory();
valueSelector = new TestObjectColumnSelector<>(strings);
objectSelector = new TestObjectColumnSelector<>(strings);
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java
index c480b9bb2ef..88351125f55 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java
@@ -19,48 +19,44 @@
package org.apache.druid.query.aggregation.any;
-import org.apache.druid.segment.ColumnInspector;
+import com.google.common.collect.Lists;
+import org.apache.druid.query.aggregation.Aggregator;
+import org.apache.druid.query.aggregation.TestObjectColumnSelector;
+import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.DimensionSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
-import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector;
-import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
-import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
+import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
+import org.apache.druid.segment.vector.TestVectorColumnSelectorFactory;
+import org.apache.druid.segment.virtual.FallbackVirtualColumnTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.Mockito;
-import org.mockito.junit.MockitoJUnitRunner;
-import static org.mockito.ArgumentMatchers.any;
+import java.util.List;
-@RunWith(MockitoJUnitRunner.class)
public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
{
private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME";
private static final int MAX_STRING_BYTES = 10;
- @Mock
- private ColumnInspector columnInspector;
- @Mock
+ private TestColumnSelectorFactory columnInspector;
private ColumnCapabilities capabilities;
- @Mock
- private VectorColumnSelectorFactory vectorSelectorFactory;
- @Mock
- private SingleValueDimensionVectorSelector singleValueDimensionVectorSelector;
- @Mock
- private MultiValueDimensionVectorSelector multiValueDimensionVectorSelector;
-
+ private TestVectorColumnSelectorFactory vectorSelectorFactory;
private StringAnyAggregatorFactory target;
@Before
public void setUp()
{
- Mockito.doReturn(capabilities).when(vectorSelectorFactory).getColumnCapabilities(FIELD_NAME);
- Mockito.doReturn(ColumnCapabilities.Capable.UNKNOWN).when(capabilities).hasMultipleValues();
- target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES);
+ target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES, true);
+ columnInspector = new TestColumnSelectorFactory();
+ vectorSelectorFactory = new TestVectorColumnSelectorFactory();
+ capabilities = ColumnCapabilitiesImpl.createDefault().setHasMultipleValues(true);
+ vectorSelectorFactory.addCapabilities(FIELD_NAME, capabilities);
+ vectorSelectorFactory.addMVDVS(FIELD_NAME, new FallbackVirtualColumnTest.SameMultiVectorSelector());
}
@Test
@@ -72,10 +68,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test
public void factorizeVectorWithoutCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
{
- Mockito.doReturn(null).when(vectorSelectorFactory).getColumnCapabilities(FIELD_NAME);
- Mockito.doReturn(singleValueDimensionVectorSelector)
- .when(vectorSelectorFactory)
- .makeSingleValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator);
}
@@ -83,9 +75,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test
public void factorizeVectorWithUnknownCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
{
- Mockito.doReturn(multiValueDimensionVectorSelector)
- .when(vectorSelectorFactory)
- .makeMultiValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator);
}
@@ -93,10 +82,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test
public void factorizeVectorWithMultipleValuesCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector()
{
- Mockito.doReturn(ColumnCapabilities.Capable.TRUE).when(capabilities).hasMultipleValues();
- Mockito.doReturn(multiValueDimensionVectorSelector)
- .when(vectorSelectorFactory)
- .makeMultiValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator);
}
@@ -104,11 +89,86 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest
@Test
public void factorizeVectorWithoutMultipleValuesCapabilitiesShouldReturnAggregatorWithSingleDimensionSelector()
{
- Mockito.doReturn(ColumnCapabilities.Capable.FALSE).when(capabilities).hasMultipleValues();
- Mockito.doReturn(singleValueDimensionVectorSelector)
- .when(vectorSelectorFactory)
- .makeSingleValueDimensionSelector(any());
StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory);
Assert.assertNotNull(aggregator);
}
+
+  /**
+   * The factory must build a {@link StringAnyAggregator}, and must pass its
+   * {@code aggregateMultipleValues} setting on to the vector aggregator.
+   */
+  @Test
+  public void testFactorize()
+  {
+    final Aggregator res = target.factorize(new TestColumnSelectorFactory());
+    Assert.assertTrue(res instanceof StringAnyAggregator);
+    // The test selector's first row is null, so that is the value picked up.
+    res.aggregate();
+    Assert.assertNull(res.get());
+
+    final StringAnyVectorAggregator vectorAggregator = target.factorizeVector(vectorSelectorFactory);
+    Assert.assertTrue(vectorAggregator.isAggregateMultipleValues());
+  }
+
+  /**
+   * A single-valued row is returned unchanged.
+   */
+  @Test
+  public void testSvdStringAnyAggregator()
+  {
+    final TestColumnSelectorFactory selectorFactory = new TestColumnSelectorFactory();
+    final Aggregator aggregator = target.factorize(selectorFactory);
+    Assert.assertTrue(aggregator instanceof StringAnyAggregator);
+
+    // Step past the leading null row so the selector sits on the "CCCC" row.
+    selectorFactory.moveSelectorCursorToNext();
+    aggregator.aggregate();
+
+    Assert.assertEquals("CCCC", aggregator.get());
+  }
+
+  /**
+   * With {@code aggregateMultipleValues} enabled, a multi-value row is stringified as a list and
+   * then cut down to {@code MAX_STRING_BYTES}.
+   */
+  @Test
+  public void testMvdStringAnyAggregator()
+  {
+    final TestColumnSelectorFactory selectorFactory = new TestColumnSelectorFactory();
+    final Aggregator aggregator = target.factorize(selectorFactory);
+    Assert.assertTrue(aggregator instanceof StringAnyAggregator);
+
+    // Two steps: past the null row and the single-valued row, onto the multi-value row.
+    selectorFactory.moveSelectorCursorToNext();
+    selectorFactory.moveSelectorCursorToNext();
+    aggregator.aggregate();
+
+    Assert.assertEquals("[AAAA, AAA", aggregator.get());
+  }
+
+  /**
+   * With {@code aggregateMultipleValues} disabled, only the first value of a multi-value row is kept.
+   */
+  @Test
+  public void testMvdStringAnyAggregatorWithAggregateMultipleToFalse()
+  {
+    final StringAnyAggregatorFactory firstValueFactory =
+        new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES, false);
+    final TestColumnSelectorFactory selectorFactory = new TestColumnSelectorFactory();
+    final Aggregator aggregator = firstValueFactory.factorize(selectorFactory);
+    Assert.assertTrue(aggregator instanceof StringAnyAggregator);
+
+    // Two steps: past the null row and the single-valued row, onto the multi-value row.
+    selectorFactory.moveSelectorCursorToNext();
+    selectorFactory.moveSelectorCursorToNext();
+    aggregator.aggregate();
+
+    // picks up first value in mvd list
+    Assert.assertEquals("AAAA", aggregator.get());
+  }
+
+  /**
+   * Minimal {@link ColumnSelectorFactory} for these tests. Every column name resolves to the same
+   * object selector over a fixed set of rows, and the cursor is advanced by hand through
+   * {@link #moveSelectorCursorToNext()}.
+   */
+  static class TestColumnSelectorFactory implements ColumnSelectorFactory
+  {
+    final List<String> mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
+    // Row 0 is null, row 1 is single-valued and row 2 is the multi-value list.
+    final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
+    // NOTE(review): not read anywhere in the visible part of this change; confirm it is needed.
+    Integer maxStringBytes = 1024;
+    final TestObjectColumnSelector<Object> objectColumnSelector = new TestObjectColumnSelector<>(mvds);
+
+    @Override
+    public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec)
+    {
+      // The factorize() calls in this test class only ask for column value selectors.
+      return null;
+    }
+
+    @Override
+    public ColumnValueSelector<?> makeColumnValueSelector(String columnName)
+    {
+      return objectColumnSelector;
+    }
+
+    @Override
+    public ColumnCapabilities getColumnCapabilities(String columnName)
+    {
+      return ColumnCapabilitiesImpl.createDefault().setHasMultipleValues(true);
+    }
+
+    /**
+     * Advances the shared selector to the next row.
+     */
+    public void moveSelectorCursorToNext()
+    {
+      objectColumnSelector.increment();
+    }
+  }
}
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java
index 658db6f7eec..1db6cbe2b1d 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java
@@ -19,15 +19,22 @@
package org.apache.druid.query.aggregation.any;
+import com.google.common.collect.Lists;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.TestObjectColumnSelector;
import org.junit.Assert;
import org.junit.Test;
import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
public class StringAnyBufferAggregatorTest
{
+ StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
+ "billy", "billy", 1024, true
+ );
+
private void aggregateBuffer(
TestObjectColumnSelector valueSelector,
BufferAggregator agg,
@@ -44,17 +51,14 @@ public class StringAnyBufferAggregatorTest
{
final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"};
- Integer maxStringBytes = 1024;
+ int maxStringBytes = 1024;
TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(strings);
- StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
- "billy", "billy", maxStringBytes
- );
-
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
- maxStringBytes
+ maxStringBytes,
+ true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -75,17 +79,15 @@ public class StringAnyBufferAggregatorTest
public void testBufferAggregateWithFoldCheck()
{
final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"};
- Integer maxStringBytes = 1024;
+ int maxStringBytes = 1024;
TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(strings);
- StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
- "billy", "billy", maxStringBytes
- );
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
- maxStringBytes
+ maxStringBytes,
+ true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -108,17 +110,14 @@ public class StringAnyBufferAggregatorTest
{
final String[] strings = {"CCCC", "AAAA", "BBBB", null, "EEEE"};
- Integer maxStringBytes = 1024;
+ int maxStringBytes = 1024;
TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(strings);
- StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
- "billy", "billy", maxStringBytes
- );
-
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
- maxStringBytes
+ maxStringBytes,
+ true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -140,17 +139,13 @@ public class StringAnyBufferAggregatorTest
{
final String[] strings = {null, "CCCC", "AAAA", "BBBB", "EEEE"};
- Integer maxStringBytes = 1024;
+ int maxStringBytes = 1024;
TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(strings);
- StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
- "billy", "billy", maxStringBytes
- );
-
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
- maxStringBytes
+ maxStringBytes, true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -170,19 +165,15 @@ public class StringAnyBufferAggregatorTest
@Test
public void testNonStringValue()
{
-
final Double[] doubles = {1.00, 2.00};
- Integer maxStringBytes = 1024;
+ int maxStringBytes = 1024;
TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(doubles);
- StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory(
- "billy", "billy", maxStringBytes
- );
-
StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
objectColumnSelector,
- maxStringBytes
+ maxStringBytes,
+ true
);
ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize());
@@ -198,4 +189,77 @@ public class StringAnyBufferAggregatorTest
Assert.assertEquals("1.0", result);
}
+
+ @Test
+ public void testMvds()
+ {
+ List mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
+ final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
+ int maxStringBytes = 1024;
+
+ TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(mvds);
+
+ StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
+ objectColumnSelector,
+ maxStringBytes, true
+ );
+
+ ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize() * 2);
+ int position = 0;
+
+ int[] positions = new int[]{0, 1, 43, 100, 189};
+ Arrays.stream(positions).forEach(i -> agg.init(buf, i));
+
+ //noinspection ForLoopReplaceableByForEach
+ for (int i = 0; i < mvds.length; i++) {
+ aggregateBuffer(objectColumnSelector, agg, buf, positions[i]);
+ }
+ String result = ((String) agg.get(buf, position));
+ Assert.assertNull(result);
+
+ for (int i = 0; i < positions.length; i++) {
+ if (i == 2) {
+ Assert.assertEquals(mvd.toString(), agg.get(buf, positions[2]));
+ } else {
+ Assert.assertEquals(mvds[i], agg.get(buf, positions[i]));
+ }
+ }
+ }
+
+ @Test
+ public void testMvdsWithCustomAggregate()
+ {
+ List mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC");
+ final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"};
+ final int maxStringBytes = 1024;
+
+ TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(mvds);
+
+ StringAnyBufferAggregator agg = new StringAnyBufferAggregator(
+ objectColumnSelector,
+ maxStringBytes, false
+ );
+
+ ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize() * 2);
+ int position = 0;
+
+ int[] positions = new int[]{0, 1, 43, 100, 189};
+ Arrays.stream(positions).forEach(i -> agg.init(buf, i));
+
+ //noinspection ForLoopReplaceableByForEach
+ for (int i = 0; i < mvds.length; i++) {
+ aggregateBuffer(objectColumnSelector, agg, buf, positions[i]);
+ }
+ String result = ((String) agg.get(buf, position));
+ Assert.assertNull(result);
+
+ for (int i = 0; i < positions.length; i++) {
+ if (i == 2) {
+ // when multi-value aggregation is disabled, only the first value of the MVD is kept
+ Assert.assertEquals(mvd.get(0), agg.get(buf, positions[2]));
+ } else {
+ Assert.assertEquals(mvds[i], agg.get(buf, positions[i]));
+ }
+ }
+ }
}
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java
index bb9ca74dfb4..b6555f6d2af 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java
@@ -33,6 +33,8 @@ import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import static org.apache.druid.query.aggregation.any.StringAnyVectorAggregator.NOT_FOUND_FLAG_VALUE;
@@ -61,6 +63,7 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
private StringAnyVectorAggregator singleValueTarget;
private StringAnyVectorAggregator multiValueTarget;
+ private StringAnyVectorAggregator customMultiValueTarget;
@Before
public void setUp()
@@ -74,20 +77,22 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
return index >= DICTIONARY.length ? null : DICTIONARY[index];
}).when(singleValueSelector).lookupName(anyInt());
initializeRandomBuffer();
- singleValueTarget = new StringAnyVectorAggregator(singleValueSelector, null, MAX_STRING_BYTES);
- multiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES);
+ singleValueTarget = new StringAnyVectorAggregator(singleValueSelector, null, MAX_STRING_BYTES, true);
+ multiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES, true);
+ // customMultiValueTarget keeps only a single value (the first) for multi-value dimension (MVD) rows
+ customMultiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES, false);
}
@Test(expected = IllegalStateException.class)
public void initWithBothSingleAndMultiValueSelectorShouldThrowException()
{
- new StringAnyVectorAggregator(singleValueSelector, multiValueSelector, MAX_STRING_BYTES);
+ new StringAnyVectorAggregator(singleValueSelector, multiValueSelector, MAX_STRING_BYTES, true);
}
@Test(expected = IllegalStateException.class)
public void initWithNeitherSingleNorMultiValueSelectorShouldThrowException()
{
- new StringAnyVectorAggregator(null, null, MAX_STRING_BYTES);
+ new StringAnyVectorAggregator(null, null, MAX_STRING_BYTES, true);
}
@Test
@@ -122,7 +127,7 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
public void aggregateMultiValuePositionNotFoundShouldPutFirstValue()
{
multiValueTarget.aggregate(buf, POSITION, 0, 2);
- Assert.assertEquals(DICTIONARY[1], multiValueTarget.get(buf, POSITION));
+ Assert.assertEquals("[One, Zero]", multiValueTarget.get(buf, POSITION));
}
@Test
@@ -155,9 +160,9 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
@Test
public void aggregateBatchWithRowsShouldAggregateAllRows()
{
- int[] positions = new int[] {0, 43, 100};
+ int[] positions = new int[]{0, 43, 100};
int positionOffset = 2;
- int[] rows = new int[] {2, 1, 0};
+ int[] rows = new int[]{2, 1, 0};
clearBufferForPositions(positionOffset, positions);
multiValueTarget.aggregate(buf, 3, positions, rows, positionOffset);
for (int i = 0; i < positions.length; i++) {
@@ -166,8 +171,32 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest
IndexedInts rowIndex = MULTI_VALUE_ROWS[row];
if (rowIndex.size() == 0) {
Assert.assertNull(multiValueTarget.get(buf, position));
- } else {
+ } else if (rowIndex.size() == 1) {
Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)), multiValueTarget.get(buf, position));
+ } else {
+ List res = new ArrayList<>();
+ rowIndex.forEach(index -> res.add(multiValueSelector.lookupName(index)));
+ Assert.assertEquals(res.toString(), multiValueTarget.get(buf, position));
+ }
+ }
+ }
+
+ @Test
+ public void aggregateBatchWithRowsShouldAggregateAllRowsWithAggregateMVDFalse()
+ {
+ int[] positions = new int[]{0, 43, 100};
+ int positionOffset = 2;
+ int[] rows = new int[]{2, 1, 0};
+ clearBufferForPositions(positionOffset, positions);
+ customMultiValueTarget.aggregate(buf, 3, positions, rows, positionOffset);
+ for (int i = 0; i < positions.length; i++) {
+ int position = positions[i] + positionOffset;
+ int row = rows[i];
+ IndexedInts rowIndex = MULTI_VALUE_ROWS[row];
+ if (rowIndex.size() == 0) {
+ Assert.assertNull(customMultiValueTarget.get(buf, position));
+ } else {
+ Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)), customMultiValueTarget.get(buf, position));
}
}
}
diff --git a/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java b/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java
index ac6b5446144..72de5466ff6 100644
--- a/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java
@@ -489,7 +489,7 @@ public class FallbackVirtualColumnTest
}
}
- private static class SameMultiVectorSelector implements MultiValueDimensionVectorSelector
+ public static class SameMultiVectorSelector implements MultiValueDimensionVectorSelector
{
@Override
public int getValueCardinality()
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
index efa3a9e7e32..21bcc833e04 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java
@@ -95,7 +95,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
String fieldName,
String timeColumn,
ColumnType type,
- Integer maxStringBytes
+ Integer maxStringBytes,
+ Boolean aggregateMultipleValues
)
{
switch (type.getType()) {
@@ -121,7 +122,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
String fieldName,
String timeColumn,
ColumnType type,
- Integer maxStringBytes
+ Integer maxStringBytes,
+ Boolean aggregateMultipleValues
)
{
switch (type.getType()) {
@@ -147,7 +149,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
String fieldName,
String timeColumn,
ColumnType type,
- Integer maxStringBytes
+ Integer maxStringBytes,
+ Boolean aggregateMultipleValues
)
{
switch (type.getType()) {
@@ -158,7 +161,7 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
case DOUBLE:
return new DoubleAnyAggregatorFactory(name, fieldName);
case STRING:
- return new StringAnyAggregatorFactory(name, fieldName, maxStringBytes);
+ return new StringAnyAggregatorFactory(name, fieldName, maxStringBytes, aggregateMultipleValues);
default:
throw SimpleSqlAggregator.badTypeException(fieldName, "ANY", type);
}
@@ -170,7 +173,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
String fieldName,
String timeColumn,
ColumnType outputType,
- Integer maxStringBytes
+ Integer maxStringBytes,
+ Boolean aggregateMultipleValues
);
}
@@ -244,37 +248,38 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
final AggregatorFactory theAggFactory;
switch (args.size()) {
case 1:
- theAggFactory = aggregatorType.createAggregatorFactory(aggregatorName, fieldName, null, outputType, null);
+ theAggFactory = aggregatorType.createAggregatorFactory(aggregatorName, fieldName, null, outputType, null, true);
break;
case 2:
- int maxStringBytes;
- try {
- maxStringBytes = RexLiteral.intValue(rexNodes.get(1));
- }
- catch (AssertionError ae) {
- plannerContext.setPlanningError(
- "The second argument '%s' to function '%s' is not a number",
- rexNodes.get(1),
- aggregateCall.getName()
- );
- return null;
- }
+ Integer maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); // a NULL literal is rejected earlier by the function's operand type checker (notNullOperands)
theAggFactory = aggregatorType.createAggregatorFactory(
aggregatorName,
fieldName,
null,
outputType,
- maxStringBytes
+ maxStringBytes.intValue(),
+ true
+ );
+ break;
+ case 3:
+ maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); // NULL literals for operands 1 and 2 are rejected earlier by the function's operand type checker (notNullOperands)
+ boolean aggregateMultipleValues = RexLiteral.booleanValue(rexNodes.get(2));
+ theAggFactory = aggregatorType.createAggregatorFactory(
+ aggregatorName,
+ fieldName,
+ null,
+ outputType,
+ maxStringBytes,
+ aggregateMultipleValues
);
break;
default:
throw InvalidSqlInput.exception(
- "Function [%s] expects 1 or 2 arguments but found [%s]",
+ "Function [%s] expects 1 or 2 or 3 arguments but found [%s]",
aggregateCall.getName(),
args.size()
);
}
-
return Aggregation.create(
Collections.singletonList(theAggFactory),
finalizeAggregations ? new FinalizingFieldAccessPostAggregator(name, aggregatorName) : null
@@ -372,10 +377,11 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
InferTypes.RETURN_TYPE,
DefaultOperandTypeChecker
.builder()
- .operandNames("expr", "maxBytesPerString")
- .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
+ .operandNames("expr", "maxBytesPerStringInt", "aggregateMultipleValuesBoolean")
+ .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.BOOLEAN)
.requiredOperandCount(1)
- .literalOperands(1)
+ .literalOperands(1, 2)
+ .notNullOperands(1, 2)
.build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
@@ -402,9 +408,9 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
SqlParserPos pos = call.getParserPosition();
- if (operands.isEmpty() || operands.size() > 2) {
+ if (operands.isEmpty() || operands.size() > 3) {
throw InvalidSqlInput.exception(
- "Function [%s] expects 1 or 2 arguments but found [%s]",
+ "Function [%s] expects 1 or 2 or 3 arguments but found [%s]",
getName(),
operands.size()
);
@@ -417,6 +423,9 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator
if (operands.size() == 2) {
newOperands.add(operands.get(1));
}
+ if (operands.size() == 3) {
+ newOperands.add(operands.get(2));
+ }
return replacementAggFunc.createCall(pos, newOperands);
}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java
index 03e23503a81..fac88d853e1 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java
@@ -119,7 +119,8 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator
rexNodes.get(1)
),
outputType,
- null
+ null,
+ true
);
break;
case 3:
@@ -145,7 +146,8 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator
rexNodes.get(1)
),
outputType,
- maxStringBytes
+ maxStringBytes,
+ true
);
break;
default:
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java
index f43fde3a935..a52ce5707c1 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java
@@ -188,6 +188,7 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
@Nullable
private Integer requiredOperandCount;
private int[] literalOperands;
+ private IntSet notNullOperands = new IntArraySet();
private Builder()
{
@@ -229,6 +230,12 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
return this;
}
+ public Builder notNullOperands(final int... notNullOperands)
+ {
+ Arrays.stream(notNullOperands).forEach(this.notNullOperands::add);
+ return this;
+ }
+
public DefaultOperandTypeChecker build()
{
int computedRequiredOperandCount = requiredOperandCount == null ? operandTypes.size() : requiredOperandCount;
@@ -236,16 +243,18 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker
operandNames,
operandTypes,
computedRequiredOperandCount,
- DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size()),
+ DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size(), notNullOperands),
literalOperands
);
}
}
- public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount)
+ public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount, IntSet notNullOperands)
{
final IntSet nullableOperands = new IntArraySet();
- IntStream.range(requiredOperandCount, totalOperandCount).forEach(nullableOperands::add);
+ IntStream.range(requiredOperandCount, totalOperandCount)
+ .filter(i -> !notNullOperands.contains(i))
+ .forEach(nullableOperands::add);
return nullableOperands;
}
}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
index d655e12f799..450d6208240 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
@@ -593,7 +593,7 @@ public class OperatorConversions
{
final IntSet nullableOperands = requiredOperandCount == null
? new IntArraySet()
- : DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size());
+ : DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size(), new IntArraySet());
if (operandTypeInference == null) {
SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands);
return (callBinding, returnType, types) -> {
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
index a8e16de2675..3ead14a05a3 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
@@ -497,7 +497,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
cannotVectorize();
testQuery(
- "SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100) FROM foo WHERE (TIME_FLOOR(__time, 'PT1H'), m1) IN\n"
+ "SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100, true) FROM foo WHERE (TIME_FLOOR(__time, 'PT1H'), m1) IN\n"
+ " (\n"
+ " SELECT TIME_FLOOR(__time, 'PT1H') AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n"
+ " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n"
@@ -532,7 +532,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
)
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
- new StringAnyAggregatorFactory("a0", "dim3", 100)
+ new StringAnyAggregatorFactory("a0", "dim3", 100, true)
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
@@ -598,7 +598,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
)
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(aggregators(
- new StringAnyAggregatorFactory("a0", "dim3", 100)
+ new StringAnyAggregatorFactory("a0", "dim3", 100, true)
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
@@ -609,6 +609,71 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
)
);
}
+ @Test
+ public void testJoinOnGroupByInsteadOfTimeseriesWithFloorOnTimeWithNoAggregateMultipleValues()
+ {
+ // Cannot vectorize JOIN operator.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100, false) FROM foo WHERE (CAST(TIME_FLOOR(__time, 'PT1H') AS BIGINT) + 1, m1) IN\n"
+ + " (\n"
+ + " SELECT CAST(TIME_FLOOR(__time, 'PT1H') AS BIGINT) + 1 AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n"
+ + " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n"
+ + " )\n"
+ + "GROUP BY 1, 2\n",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(
+ join(
+ new TableDataSource(CalciteTests.DATASOURCE1),
+ new QueryDataSource(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Intervals.of(
+ "1994-04-29/2020-01-11T00:00:00.001Z")))
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "(timestamp_floor(\"__time\",'PT1H',null,'UTC') + 1)",
+ ColumnType.LONG
+ )
+ )
+ .setDimFilter(equality("dim3", "b", ColumnType.STRING))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(dimensions(new DefaultDimensionSpec(
+ "v0",
+ "d0",
+ ColumnType.LONG
+ )))
+ .setAggregatorSpecs(aggregators(
+ new FloatMinAggregatorFactory("a0", "m1")
+ ))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()),
+ "j0.",
+ "(((timestamp_floor(\"__time\",'PT1H',null,'UTC') + 1) == \"j0.d0\") && (\"m1\" == \"j0.a0\"))",
+ JoinType.INNER
+ )
+ )
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setDimensions(
+ new DefaultDimensionSpec("__time", "d0", ColumnType.LONG),
+ new DefaultDimensionSpec("m1", "d1", ColumnType.FLOAT)
+ )
+ .setGranularity(Granularities.ALL)
+ .setAggregatorSpecs(aggregators(
+ new StringAnyAggregatorFactory("a0", "dim3", 100, false)
+ ))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{946684800000L, 1.0f, "a"}, // picks up first from [a, b]
+ new Object[]{946771200000L, 2.0f, "b"} // picks up first from [b, c]
+ )
+ );
+ }
@Test
@Parameters(source = QueryContextForJoinProvider.class)
@@ -1480,7 +1545,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
new SubstringDimExtractionFn(0, 1)
)
)
- .setAggregatorSpecs(new StringAnyAggregatorFactory("a0", "v", 10))
+ .setAggregatorSpecs(new StringAnyAggregatorFactory("a0", "v", 10, true))
.build()
),
"j0.",
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 00ea933bb1e..1bca129757b 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -844,10 +844,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new LongAnyAggregatorFactory("a0", "cnt"),
new FloatAnyAggregatorFactory("a1", "m1"),
new DoubleAnyAggregatorFactory("a2", "m2"),
- new StringAnyAggregatorFactory("a3", "dim1", 10),
+ new StringAnyAggregatorFactory("a3", "dim1", 10, true),
new LongAnyAggregatorFactory("a4", "v0"),
new FloatAnyAggregatorFactory("a5", "v1"),
- new StringAnyAggregatorFactory("a6", "v2", 10)
+ new StringAnyAggregatorFactory("a6", "v2", 10, true)
)
)
.context(QUERY_CONTEXT_DEFAULT)
@@ -1420,7 +1420,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setAggregatorSpecs(aggregators(new StringAnyAggregatorFactory(
"a0:a",
"dim1",
- 10
+ 10, true
)))
.setPostAggregatorSpecs(
ImmutableList.of(
@@ -1565,7 +1565,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.granularity(Granularities.ALL)
.aggregators(
aggregators(
- new StringAnyAggregatorFactory("a0", "dim1", 32),
+ new StringAnyAggregatorFactory("a0", "dim1", 32, true),
new LongAnyAggregatorFactory("a1", "l2"),
new DoubleAnyAggregatorFactory("a2", "d2"),
new FloatAnyAggregatorFactory("a3", "f2")
@@ -1607,7 +1607,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.filters(filter)
.aggregators(
aggregators(
- new StringAnyAggregatorFactory("a0", "dim1", 32),
+ new StringAnyAggregatorFactory("a0", "dim1", 32, true),
new LongAnyAggregatorFactory("a1", "l2"),
new DoubleAnyAggregatorFactory("a2", "d2"),
new FloatAnyAggregatorFactory("a3", "f2")
@@ -9422,7 +9422,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.granularity(Granularities.ALL)
.aggregators(
aggregators(
- new StringAnyAggregatorFactory("a0", "dim1", 1024),
+ new StringAnyAggregatorFactory("a0", "dim1", 1024, true),
new LongAnyAggregatorFactory("a1", "l1"),
new StringFirstAggregatorFactory("a2", "dim1", null, 1024),
new LongFirstAggregatorFactory("a3", "l1", null),
@@ -9741,7 +9741,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
.setAggregatorSpecs(
aggregators(
new FilteredAggregatorFactory(
- new StringAnyAggregatorFactory("a0", "dim1", 1024),
+ new StringAnyAggregatorFactory("a0", "dim1", 1024, true),
equality("dim1", "nonexistent", ColumnType.STRING)
),
new FilteredAggregatorFactory(
@@ -13533,6 +13533,24 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
+ @Test
+ public void testStringAnyAggArgValidation()
+ {
+ DruidException e = assertThrows(DruidException.class, () -> testBuilder()
+ .sql("SELECT ANY_VALUE(dim3, 1000, 'true') FROM foo")
+ .queryContext(ImmutableMap.of())
+ .run());
+ assertThat(e, invalidSqlIs(
+ "Cannot apply 'ANY_VALUE' to arguments of type 'ANY_VALUE(, , )'. Supported form(s): 'ANY_VALUE(, [, []])' (line [1], column [8])"));
+ DruidException e1 = assertThrows(DruidException.class, () -> testBuilder()
+ .sql("SELECT ANY_VALUE(dim3, 1000, null) FROM foo")
+ .queryContext(ImmutableMap.of()).run());
+ Assert.assertEquals("Illegal use of 'NULL' (line [1], column [30])", e1.getMessage());
+ DruidException e2 = assertThrows(DruidException.class, () -> testBuilder()
+ .sql("SELECT ANY_VALUE(dim3, null, true) FROM foo")
+ .queryContext(ImmutableMap.of()).run());
+ Assert.assertEquals("Illegal use of 'NULL' (line [1], column [24])", e2.getMessage());
+ }
@Test
public void testStringAggMaxBytes()
{
@@ -14367,7 +14385,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
new StringLastAggregatorFactory("a1", "dim1", "__time", 1024),
new StringFirstAggregatorFactory("a2", "dim3", "__time", 1024),
new StringFirstAggregatorFactory("a3", "dim1", "__time", 1024),
- new StringAnyAggregatorFactory("a4", "dim3", 1024)))
+ new StringAnyAggregatorFactory("a4", "dim3", 1024, true)))
.build()
),
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java b/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java
index bb76488ccc4..eb8e7e67da7 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java
@@ -323,4 +323,10 @@ public class QueryTestBuilder
return build().resultsOnly();
}
+ public boolean isDecoupledMode()
+ {
+ String mode = (String) queryContext.getOrDefault(PlannerConfig.CTX_NATIVE_QUERY_SQL_PLANNING_MODE, "");
+ return PlannerConfig.NATIVE_QUERY_SQL_PLANNING_MODE_DECOUPLED.equalsIgnoreCase(mode);
+ }
+
}
diff --git a/website/.spelling b/website/.spelling
index 002998c442c..14233798fef 100644
--- a/website/.spelling
+++ b/website/.spelling
@@ -2334,3 +2334,4 @@ LAST_VALUE
markUnused
markUsed
segmentId
+aggregateMultipleValues