diff --git a/docs/querying/sql-aggregations.md b/docs/querying/sql-aggregations.md index d005ce1fd11..b2df640a68f 100644 --- a/docs/querying/sql-aggregations.md +++ b/docs/querying/sql-aggregations.md @@ -90,7 +90,7 @@ In the aggregation functions supported by Druid, only `COUNT`, `ARRAY_AGG`, and |`EARLIEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the earliest value of `expr`.
The earliest value of `expr` is taken from the row with the overall earliest non-null value of `timestampExpr`.
If the earliest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows.

If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.
If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)| |`LATEST(expr, [maxBytesPerValue])`|Returns the latest value of `expr`
The `expr` must come from a relation with a timestamp column (like `__time` in a Druid datasource) and the "latest" is taken from the row with the overall latest non-null value of the timestamp column.
If the latest non-null value of the timestamp column appears in multiple rows, the `expr` may be taken from any of those rows.

If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.
If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)| |`LATEST_BY(expr, timestampExpr, [maxBytesPerValue])`|Returns the latest value of `expr`.
The latest value of `expr` is taken from the row with the overall latest non-null value of `timestampExpr`.
If the overall latest non-null value of `timestampExpr` appears in multiple rows, the `expr` may be taken from any of those rows.

If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.
If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)| -|`ANY_VALUE(expr, [maxBytesPerValue])`|Returns any value of `expr` including null. This aggregator can simplify and optimize the performance by returning the first encountered value (including `null`).

If `expr` is a string or complex type `maxBytesPerValue` amount of space is allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.
If `maxBytesPerValue`is omitted; it defaults to `1024`. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)| +|`ANY_VALUE(expr, [maxBytesPerValue, [aggregateMultipleValues]])`|Returns any value of `expr` including null. This aggregator can simplify and optimize the performance by returning the first encountered value (including `null`).

If `expr` is a string or complex type, `maxBytesPerValue` bytes of space are allocated for the aggregation. Strings longer than this limit are truncated. The `maxBytesPerValue` parameter should be set as low as possible, since high values will lead to wasted memory.
If `maxBytesPerValue` is omitted, it defaults to `1024`. `aggregateMultipleValues` is an optional boolean flag that controls how a [multi-value dimension](./multi-value-dimensions.md) is aggregated. It defaults to `true`, in which case the function returns the stringified array for a multi-value dimension. When set to `false`, the function returns the first value instead. |`null` or `0`/`''` if `druid.generic.useDefaultValueForNull=true` (legacy mode)| |`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.|N/A| |`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.|`null`| |`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. If the aggregated array grows larger than the maximum size in bytes, the query will fail. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results will be based on the default for the element type.|`null`| diff --git a/docs/querying/sql-functions.md b/docs/querying/sql-functions.md index 8e43076518d..47b8ca90434 100644 --- a/docs/querying/sql-functions.md +++ b/docs/querying/sql-functions.md @@ -50,7 +50,7 @@ Calculates the arc cosine of a numeric expression. 
## ANY_VALUE -`ANY_VALUE(expr, [maxBytesPerValue])` +`ANY_VALUE(expr, [maxBytesPerValue, [aggregateMultipleValues]])` **Function type:** [Aggregation](sql-aggregations.md) diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java index 352b65e5646..aae267364c4 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregator.java @@ -24,19 +24,23 @@ import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.segment.BaseObjectColumnValueSelector; import org.apache.druid.segment.DimensionHandlerUtils; +import java.util.List; + public class StringAnyAggregator implements Aggregator { private final BaseObjectColumnValueSelector valueSelector; private final int maxStringBytes; private boolean isFound; private String foundValue; + private final boolean aggregateMultipleValues; - public StringAnyAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes) + public StringAnyAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes, boolean aggregateMultipleValues) { this.valueSelector = valueSelector; this.maxStringBytes = maxStringBytes; this.foundValue = null; this.isFound = false; + this.aggregateMultipleValues = aggregateMultipleValues; } @Override @@ -44,18 +48,36 @@ public class StringAnyAggregator implements Aggregator { if (!isFound) { final Object object = valueSelector.getObject(); - foundValue = DimensionHandlerUtils.convertObjectToString(object); - if (foundValue != null && foundValue.length() > maxStringBytes) { - foundValue = foundValue.substring(0, maxStringBytes); - } + foundValue = StringUtils.fastLooseChop(readValue(object), maxStringBytes); isFound = true; } } + private String readValue(final Object object) + { + if (object == null) { + return 
null; + } + if (object instanceof List) { + List objectList = (List) object; + if (objectList.size() == 0) { + return null; + } + if (objectList.size() == 1) { + return DimensionHandlerUtils.convertObjectToString(objectList.get(0)); + } + if (aggregateMultipleValues) { + return DimensionHandlerUtils.convertObjectToString(objectList); + } + return DimensionHandlerUtils.convertObjectToString(objectList.get(0)); + } + return DimensionHandlerUtils.convertObjectToString(object); + } + @Override public Object get() { - return StringUtils.chop(foundValue, maxStringBytes); + return foundValue; } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java index 307de0650c3..67682bfde9d 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactory.java @@ -48,13 +48,15 @@ public class StringAnyAggregatorFactory extends AggregatorFactory private final String fieldName; private final String name; - protected final int maxStringBytes; + private final int maxStringBytes; + private final boolean aggregateMultipleValues; @JsonCreator public StringAnyAggregatorFactory( @JsonProperty("name") String name, @JsonProperty("fieldName") final String fieldName, - @JsonProperty("maxStringBytes") Integer maxStringBytes + @JsonProperty("maxStringBytes") Integer maxStringBytes, + @JsonProperty("aggregateMultipleValues") @Nullable final Boolean aggregateMultipleValues ) { Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name"); @@ -67,18 +69,19 @@ public class StringAnyAggregatorFactory extends AggregatorFactory this.maxStringBytes = maxStringBytes == null ? 
StringFirstAggregatorFactory.DEFAULT_MAX_STRING_SIZE : maxStringBytes; + this.aggregateMultipleValues = aggregateMultipleValues == null ? true : aggregateMultipleValues; } @Override public Aggregator factorize(ColumnSelectorFactory metricFactory) { - return new StringAnyAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes); + return new StringAnyAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes, aggregateMultipleValues); } @Override public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) { - return new StringAnyBufferAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes); + return new StringAnyBufferAggregator(metricFactory.makeColumnValueSelector(fieldName), maxStringBytes, aggregateMultipleValues); } @Override @@ -90,13 +93,15 @@ public class StringAnyAggregatorFactory extends AggregatorFactory return new StringAnyVectorAggregator( null, selectorFactory.makeMultiValueDimensionSelector(DefaultDimensionSpec.of(fieldName)), - maxStringBytes + maxStringBytes, + aggregateMultipleValues ); } else { return new StringAnyVectorAggregator( selectorFactory.makeSingleValueDimensionSelector(DefaultDimensionSpec.of(fieldName)), null, - maxStringBytes + maxStringBytes, + aggregateMultipleValues ); } } @@ -122,7 +127,7 @@ public class StringAnyAggregatorFactory extends AggregatorFactory @Override public AggregatorFactory getCombiningFactory() { - return new StringAnyAggregatorFactory(name, name, maxStringBytes); + return new StringAnyAggregatorFactory(name, name, maxStringBytes, aggregateMultipleValues); } @Override @@ -155,6 +160,11 @@ public class StringAnyAggregatorFactory extends AggregatorFactory { return maxStringBytes; } + @JsonProperty + public boolean getAggregateMultipleValues() + { + return aggregateMultipleValues; + } @Override public List requiredFields() @@ -192,7 +202,7 @@ public class StringAnyAggregatorFactory extends AggregatorFactory @Override public AggregatorFactory 
withName(String newName) { - return new StringAnyAggregatorFactory(newName, getFieldName(), getMaxStringBytes()); + return new StringAnyAggregatorFactory(newName, getFieldName(), getMaxStringBytes(), getAggregateMultipleValues()); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java index 32bb3153fa2..86b8c51a469 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregator.java @@ -25,6 +25,7 @@ import org.apache.druid.segment.BaseObjectColumnValueSelector; import org.apache.druid.segment.DimensionHandlerUtils; import java.nio.ByteBuffer; +import java.util.List; public class StringAnyBufferAggregator implements BufferAggregator { @@ -34,11 +35,13 @@ public class StringAnyBufferAggregator implements BufferAggregator private final BaseObjectColumnValueSelector valueSelector; private final int maxStringBytes; + private final boolean aggregateMultipleValues; - public StringAnyBufferAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes) + public StringAnyBufferAggregator(BaseObjectColumnValueSelector valueSelector, int maxStringBytes, boolean aggregateMultipleValues) { this.valueSelector = valueSelector; this.maxStringBytes = maxStringBytes; + this.aggregateMultipleValues = aggregateMultipleValues; } @Override @@ -51,8 +54,7 @@ public class StringAnyBufferAggregator implements BufferAggregator public void aggregate(ByteBuffer buf, int position) { if (buf.getInt(position) == NOT_FOUND_FLAG_VALUE) { - final Object object = valueSelector.getObject(); - String foundValue = DimensionHandlerUtils.convertObjectToString(object); + String foundValue = readValue(valueSelector.getObject()); if (foundValue != null) { ByteBuffer mutationBuffer = 
buf.duplicate(); mutationBuffer.position(position + FOUND_VALUE_OFFSET); @@ -65,6 +67,27 @@ public class StringAnyBufferAggregator implements BufferAggregator } } + private String readValue(Object object) + { + if (object == null) { + return null; + } + if (object instanceof List) { + List objectList = (List) object; + if (objectList.size() == 0) { + return null; + } + if (objectList.size() == 1) { + return DimensionHandlerUtils.convertObjectToString(objectList.get(0)); + } + if (aggregateMultipleValues) { + return DimensionHandlerUtils.convertObjectToString(objectList); + } + return DimensionHandlerUtils.convertObjectToString(objectList.get(0)); + } + return DimensionHandlerUtils.convertObjectToString(object); + } + @Override public Object get(ByteBuffer buf, int position) { diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java index 620801bafa3..104ee7a77bf 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregator.java @@ -23,12 +23,15 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.VectorAggregator; +import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.data.IndexedInts; import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; import javax.annotation.Nullable; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; public class StringAnyVectorAggregator implements VectorAggregator { @@ -43,11 +46,13 @@ public class StringAnyVectorAggregator implements VectorAggregator 
@Nullable private final MultiValueDimensionVectorSelector multiValueSelector; private final int maxStringBytes; + private final boolean aggregateMultipleValues; public StringAnyVectorAggregator( SingleValueDimensionVectorSelector singleValueSelector, MultiValueDimensionVectorSelector multiValueSelector, - int maxStringBytes + int maxStringBytes, + final boolean aggregateMultipleValues ) { Preconditions.checkState( @@ -61,6 +66,7 @@ public class StringAnyVectorAggregator implements VectorAggregator this.multiValueSelector = multiValueSelector; this.singleValueSelector = singleValueSelector; this.maxStringBytes = maxStringBytes; + this.aggregateMultipleValues = aggregateMultipleValues; } @Override @@ -78,7 +84,7 @@ public class StringAnyVectorAggregator implements VectorAggregator if (startRow < rows.length) { IndexedInts row = rows[startRow]; @Nullable - String foundValue = row.size() == 0 ? null : multiValueSelector.lookupName(row.get(0)); + String foundValue = readValue(row); putValue(buf, position, foundValue); } } else if (singleValueSelector != null) { @@ -93,6 +99,24 @@ public class StringAnyVectorAggregator implements VectorAggregator } } + private String readValue(IndexedInts row) + { + if (row.size() == 0) { + return null; + } + if (aggregateMultipleValues) { + if (row.size() == 1) { + return multiValueSelector.lookupName(row.get(0)); + } + List arrayList = new ArrayList<>(); + row.forEach(rowIndex -> { + arrayList.add(multiValueSelector.lookupName(rowIndex)); + }); + return DimensionHandlerUtils.convertObjectToString(arrayList); + } + return multiValueSelector.lookupName(row.get(0)); + } + @Override public void aggregate( ByteBuffer buf, @@ -142,4 +166,9 @@ public class StringAnyVectorAggregator implements VectorAggregator buf.putInt(position, FOUND_AND_NULL_FLAG_VALUE); } } + + public boolean isAggregateMultipleValues() + { + return aggregateMultipleValues; + } } diff --git 
a/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java index 87d0e3dfdd8..1fea653c6f7 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/AggregatorFactoryTest.java @@ -148,7 +148,7 @@ public class AggregatorFactoryTest extends InitializedNullHandlingTest // string aggregators new StringFirstAggregatorFactory("stringFirst", "col", null, 1024), new StringLastAggregatorFactory("stringLast", "col", null, 1024), - new StringAnyAggregatorFactory("stringAny", "col", 1024), + new StringAnyAggregatorFactory("stringAny", "col", 1024, true), // sketch aggs new CardinalityAggregatorFactory("cardinality", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false), new HyperUniquesAggregatorFactory("hyperUnique", "hyperunique"), @@ -307,7 +307,8 @@ public class AggregatorFactoryTest extends InitializedNullHandlingTest // string aggregators new StringFirstAggregatorFactory("col", "col", null, 1024), new StringLastAggregatorFactory("col", "col", null, 1024), - new StringAnyAggregatorFactory("col", "col", 1024), + new StringAnyAggregatorFactory("col", "col", 1024, true), + new StringAnyAggregatorFactory("col", "col", 1024, false), // sketch aggs new CardinalityAggregatorFactory("col", ImmutableList.of(DefaultDimensionSpec.of("some-col")), false), new HyperUniquesAggregatorFactory("col", "hyperunique"), diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java index 208cfeb052d..b728049d625 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregationTest.java @@ -49,7 
+49,7 @@ public class StringAnyAggregationTest @Before public void setup() { - stringAnyAggFactory = new StringAnyAggregatorFactory("billy", "nilly", MAX_STRING_SIZE); + stringAnyAggFactory = new StringAnyAggregatorFactory("billy", "nilly", MAX_STRING_SIZE, true); combiningAggFactory = stringAnyAggFactory.getCombiningFactory(); valueSelector = new TestObjectColumnSelector<>(strings); objectSelector = new TestObjectColumnSelector<>(strings); diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java index c480b9bb2ef..88351125f55 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyAggregatorFactoryTest.java @@ -19,48 +19,44 @@ package org.apache.druid.query.aggregation.any; -import org.apache.druid.segment.ColumnInspector; +import com.google.common.collect.Lists; +import org.apache.druid.query.aggregation.Aggregator; +import org.apache.druid.query.aggregation.TestObjectColumnSelector; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.DimensionSelector; import org.apache.druid.segment.column.ColumnCapabilities; -import org.apache.druid.segment.vector.MultiValueDimensionVectorSelector; -import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector; -import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.segment.column.ColumnCapabilitiesImpl; +import org.apache.druid.segment.vector.TestVectorColumnSelectorFactory; +import org.apache.druid.segment.virtual.FallbackVirtualColumnTest; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; import 
org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; -import static org.mockito.ArgumentMatchers.any; +import java.util.List; -@RunWith(MockitoJUnitRunner.class) public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest { private static final String NAME = "NAME"; private static final String FIELD_NAME = "FIELD_NAME"; private static final int MAX_STRING_BYTES = 10; - @Mock - private ColumnInspector columnInspector; - @Mock + private TestColumnSelectorFactory columnInspector; private ColumnCapabilities capabilities; - @Mock - private VectorColumnSelectorFactory vectorSelectorFactory; - @Mock - private SingleValueDimensionVectorSelector singleValueDimensionVectorSelector; - @Mock - private MultiValueDimensionVectorSelector multiValueDimensionVectorSelector; - + private TestVectorColumnSelectorFactory vectorSelectorFactory; private StringAnyAggregatorFactory target; @Before public void setUp() { - Mockito.doReturn(capabilities).when(vectorSelectorFactory).getColumnCapabilities(FIELD_NAME); - Mockito.doReturn(ColumnCapabilities.Capable.UNKNOWN).when(capabilities).hasMultipleValues(); - target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES); + target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES, true); + columnInspector = new TestColumnSelectorFactory(); + vectorSelectorFactory = new TestVectorColumnSelectorFactory(); + capabilities = ColumnCapabilitiesImpl.createDefault().setHasMultipleValues(true); + vectorSelectorFactory.addCapabilities(FIELD_NAME, capabilities); + vectorSelectorFactory.addMVDVS(FIELD_NAME, new FallbackVirtualColumnTest.SameMultiVectorSelector()); } @Test @@ -72,10 +68,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest @Test public void factorizeVectorWithoutCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector() { - 
Mockito.doReturn(null).when(vectorSelectorFactory).getColumnCapabilities(FIELD_NAME); - Mockito.doReturn(singleValueDimensionVectorSelector) - .when(vectorSelectorFactory) - .makeSingleValueDimensionSelector(any()); StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory); Assert.assertNotNull(aggregator); } @@ -83,9 +75,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest @Test public void factorizeVectorWithUnknownCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector() { - Mockito.doReturn(multiValueDimensionVectorSelector) - .when(vectorSelectorFactory) - .makeMultiValueDimensionSelector(any()); StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory); Assert.assertNotNull(aggregator); } @@ -93,10 +82,6 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest @Test public void factorizeVectorWithMultipleValuesCapabilitiesShouldReturnAggregatorWithMultiDimensionSelector() { - Mockito.doReturn(ColumnCapabilities.Capable.TRUE).when(capabilities).hasMultipleValues(); - Mockito.doReturn(multiValueDimensionVectorSelector) - .when(vectorSelectorFactory) - .makeMultiValueDimensionSelector(any()); StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory); Assert.assertNotNull(aggregator); } @@ -104,11 +89,86 @@ public class StringAnyAggregatorFactoryTest extends InitializedNullHandlingTest @Test public void factorizeVectorWithoutMultipleValuesCapabilitiesShouldReturnAggregatorWithSingleDimensionSelector() { - Mockito.doReturn(ColumnCapabilities.Capable.FALSE).when(capabilities).hasMultipleValues(); - Mockito.doReturn(singleValueDimensionVectorSelector) - .when(vectorSelectorFactory) - .makeSingleValueDimensionSelector(any()); StringAnyVectorAggregator aggregator = target.factorizeVector(vectorSelectorFactory); Assert.assertNotNull(aggregator); } + + @Test + public void testFactorize() + { + Aggregator res = 
target.factorize(new TestColumnSelectorFactory()); + Assert.assertTrue(res instanceof StringAnyAggregator); + res.aggregate(); + Assert.assertEquals(null, res.get()); + StringAnyVectorAggregator vectorAggregator = target.factorizeVector(vectorSelectorFactory); + Assert.assertTrue(vectorAggregator.isAggregateMultipleValues()); + } + + @Test + public void testSvdStringAnyAggregator() + { + TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory(); + Aggregator res = target.factorize(columnSelectorFactory); + Assert.assertTrue(res instanceof StringAnyAggregator); + columnSelectorFactory.moveSelectorCursorToNext(); + res.aggregate(); + Assert.assertEquals("CCCC", res.get()); + } + + @Test + public void testMvdStringAnyAggregator() + { + TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory(); + Aggregator res = target.factorize(columnSelectorFactory); + Assert.assertTrue(res instanceof StringAnyAggregator); + columnSelectorFactory.moveSelectorCursorToNext(); + columnSelectorFactory.moveSelectorCursorToNext(); + res.aggregate(); + Assert.assertEquals("[AAAA, AAA", res.get()); + } + + @Test + public void testMvdStringAnyAggregatorWithAggregateMultipleToFalse() + { + StringAnyAggregatorFactory target = new StringAnyAggregatorFactory(NAME, FIELD_NAME, MAX_STRING_BYTES, false); + TestColumnSelectorFactory columnSelectorFactory = new TestColumnSelectorFactory(); + Aggregator res = target.factorize(columnSelectorFactory); + Assert.assertTrue(res instanceof StringAnyAggregator); + columnSelectorFactory.moveSelectorCursorToNext(); + columnSelectorFactory.moveSelectorCursorToNext(); + res.aggregate(); + // picks up first value in mvd list + Assert.assertEquals("AAAA", res.get()); + } + + static class TestColumnSelectorFactory implements ColumnSelectorFactory + { + List mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC"); + final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"}; + Integer maxStringBytes = 1024; + 
TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(mvds); + + @Override + public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec) + { + return null; + } + + @Override + public ColumnValueSelector makeColumnValueSelector(String columnName) + { + return objectColumnSelector; + } + + @Override + public ColumnCapabilities getColumnCapabilities(String columnName) + { + return ColumnCapabilitiesImpl.createDefault().setHasMultipleValues(true); + } + + public void moveSelectorCursorToNext() + { + objectColumnSelector.increment(); + } + } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java index 658db6f7eec..1db6cbe2b1d 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyBufferAggregatorTest.java @@ -19,15 +19,22 @@ package org.apache.druid.query.aggregation.any; +import com.google.common.collect.Lists; import org.apache.druid.query.aggregation.BufferAggregator; import org.apache.druid.query.aggregation.TestObjectColumnSelector; import org.junit.Assert; import org.junit.Test; import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; public class StringAnyBufferAggregatorTest { + StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory( + "billy", "billy", 1024, true + ); + private void aggregateBuffer( TestObjectColumnSelector valueSelector, BufferAggregator agg, @@ -44,17 +51,14 @@ public class StringAnyBufferAggregatorTest { final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"}; - Integer maxStringBytes = 1024; + int maxStringBytes = 1024; TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(strings); - StringAnyAggregatorFactory factory = new 
StringAnyAggregatorFactory( - "billy", "billy", maxStringBytes - ); - StringAnyBufferAggregator agg = new StringAnyBufferAggregator( objectColumnSelector, - maxStringBytes + maxStringBytes, + true ); ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize()); @@ -75,17 +79,15 @@ public class StringAnyBufferAggregatorTest public void testBufferAggregateWithFoldCheck() { final String[] strings = {"AAAA", "BBBB", "CCCC", "DDDD", "EEEE"}; - Integer maxStringBytes = 1024; + int maxStringBytes = 1024; TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(strings); - StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory( - "billy", "billy", maxStringBytes - ); StringAnyBufferAggregator agg = new StringAnyBufferAggregator( objectColumnSelector, - maxStringBytes + maxStringBytes, + true ); ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize()); @@ -108,17 +110,14 @@ public class StringAnyBufferAggregatorTest { final String[] strings = {"CCCC", "AAAA", "BBBB", null, "EEEE"}; - Integer maxStringBytes = 1024; + int maxStringBytes = 1024; TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(strings); - StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory( - "billy", "billy", maxStringBytes - ); - StringAnyBufferAggregator agg = new StringAnyBufferAggregator( objectColumnSelector, - maxStringBytes + maxStringBytes, + true ); ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize()); @@ -140,17 +139,13 @@ public class StringAnyBufferAggregatorTest { final String[] strings = {null, "CCCC", "AAAA", "BBBB", "EEEE"}; - Integer maxStringBytes = 1024; + int maxStringBytes = 1024; TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(strings); - StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory( - "billy", "billy", maxStringBytes - ); - StringAnyBufferAggregator agg = new StringAnyBufferAggregator( objectColumnSelector, - 
maxStringBytes + maxStringBytes, true ); ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize()); @@ -170,19 +165,15 @@ public class StringAnyBufferAggregatorTest @Test public void testNonStringValue() { - final Double[] doubles = {1.00, 2.00}; - Integer maxStringBytes = 1024; + int maxStringBytes = 1024; TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(doubles); - StringAnyAggregatorFactory factory = new StringAnyAggregatorFactory( - "billy", "billy", maxStringBytes - ); - StringAnyBufferAggregator agg = new StringAnyBufferAggregator( objectColumnSelector, - maxStringBytes + maxStringBytes, + true ); ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize()); @@ -198,4 +189,77 @@ public class StringAnyBufferAggregatorTest Assert.assertEquals("1.0", result); } + + @Test + public void testMvds() + { + List mvd = Lists.newArrayList("AAAA", "AAAAB", "AAAC"); + final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"}; + int maxStringBytes = 1024; + + TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(mvds); + + StringAnyBufferAggregator agg = new StringAnyBufferAggregator( + objectColumnSelector, + maxStringBytes, true + ); + + ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize() * 2); + int position = 0; + + int[] positions = new int[]{0, 1, 43, 100, 189}; + Arrays.stream(positions).forEach(i -> agg.init(buf, i)); + + //noinspection ForLoopReplaceableByForEach + for (int i = 0; i < mvds.length; i++) { + aggregateBuffer(objectColumnSelector, agg, buf, positions[i]); + } + String result = ((String) agg.get(buf, position)); + Assert.assertNull(result); + + for (int i = 0; i < positions.length; i++) { + if (i == 2) { + Assert.assertEquals(mvd.toString(), agg.get(buf, positions[2])); + } else { + Assert.assertEquals(mvds[i], agg.get(buf, positions[i])); + } + } + } + + @Test + public void testMvdsWithCustomAggregate() + { + List mvd = Lists.newArrayList("AAAA", "AAAAB", 
"AAAC"); + final Object[] mvds = {null, "CCCC", mvd, "BBBB", "EEEE"}; + final int maxStringBytes = 1024; + + TestObjectColumnSelector objectColumnSelector = new TestObjectColumnSelector<>(mvds); + + StringAnyBufferAggregator agg = new StringAnyBufferAggregator( + objectColumnSelector, + maxStringBytes, false + ); + + ByteBuffer buf = ByteBuffer.allocate(factory.getMaxIntermediateSize() * 2); + int position = 0; + + int[] positions = new int[]{0, 1, 43, 100, 189}; + Arrays.stream(positions).forEach(i -> agg.init(buf, i)); + + //noinspection ForLoopReplaceableByForEach + for (int i = 0; i < mvds.length; i++) { + aggregateBuffer(objectColumnSelector, agg, buf, positions[i]); + } + String result = ((String) agg.get(buf, position)); + Assert.assertNull(result); + + for (int i = 0; i < positions.length; i++) { + if (i == 2) { + // takes first in case of mvds + Assert.assertEquals(mvd.get(0), agg.get(buf, positions[2])); + } else { + Assert.assertEquals(mvds[i], agg.get(buf, positions[i])); + } + } + } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java index bb9ca74dfb4..b6555f6d2af 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/StringAnyVectorAggregatorTest.java @@ -33,6 +33,8 @@ import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.ThreadLocalRandom; import static org.apache.druid.query.aggregation.any.StringAnyVectorAggregator.NOT_FOUND_FLAG_VALUE; @@ -61,6 +63,7 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest private StringAnyVectorAggregator singleValueTarget; private StringAnyVectorAggregator multiValueTarget; + 
private StringAnyVectorAggregator customMultiValueTarget; @Before public void setUp() @@ -74,20 +77,22 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest return index >= DICTIONARY.length ? null : DICTIONARY[index]; }).when(singleValueSelector).lookupName(anyInt()); initializeRandomBuffer(); - singleValueTarget = new StringAnyVectorAggregator(singleValueSelector, null, MAX_STRING_BYTES); - multiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES); + singleValueTarget = new StringAnyVectorAggregator(singleValueSelector, null, MAX_STRING_BYTES, true); + multiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES, true); + // customMultiValueTarget aggregates to only single value in case of MVDs + customMultiValueTarget = new StringAnyVectorAggregator(null, multiValueSelector, MAX_STRING_BYTES, false); } @Test(expected = IllegalStateException.class) public void initWithBothSingleAndMultiValueSelectorShouldThrowException() { - new StringAnyVectorAggregator(singleValueSelector, multiValueSelector, MAX_STRING_BYTES); + new StringAnyVectorAggregator(singleValueSelector, multiValueSelector, MAX_STRING_BYTES, true); } @Test(expected = IllegalStateException.class) public void initWithNeitherSingleNorMultiValueSelectorShouldThrowException() { - new StringAnyVectorAggregator(null, null, MAX_STRING_BYTES); + new StringAnyVectorAggregator(null, null, MAX_STRING_BYTES, true); } @Test @@ -122,7 +127,7 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest public void aggregateMultiValuePositionNotFoundShouldPutFirstValue() { multiValueTarget.aggregate(buf, POSITION, 0, 2); - Assert.assertEquals(DICTIONARY[1], multiValueTarget.get(buf, POSITION)); + Assert.assertEquals("[One, Zero]", multiValueTarget.get(buf, POSITION)); } @Test @@ -155,9 +160,9 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest @Test public void 
aggregateBatchWithRowsShouldAggregateAllRows() { - int[] positions = new int[] {0, 43, 100}; + int[] positions = new int[]{0, 43, 100}; int positionOffset = 2; - int[] rows = new int[] {2, 1, 0}; + int[] rows = new int[]{2, 1, 0}; clearBufferForPositions(positionOffset, positions); multiValueTarget.aggregate(buf, 3, positions, rows, positionOffset); for (int i = 0; i < positions.length; i++) { @@ -166,8 +171,32 @@ public class StringAnyVectorAggregatorTest extends InitializedNullHandlingTest IndexedInts rowIndex = MULTI_VALUE_ROWS[row]; if (rowIndex.size() == 0) { Assert.assertNull(multiValueTarget.get(buf, position)); - } else { + } else if (rowIndex.size() == 1) { Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)), multiValueTarget.get(buf, position)); + } else { + List res = new ArrayList<>(); + rowIndex.forEach(index -> res.add(multiValueSelector.lookupName(index))); + Assert.assertEquals(res.toString(), multiValueTarget.get(buf, position)); + } + } + } + + @Test + public void aggregateBatchWithRowsShouldAggregateAllRowsWithAggregateMVDFalse() + { + int[] positions = new int[]{0, 43, 100}; + int positionOffset = 2; + int[] rows = new int[]{2, 1, 0}; + clearBufferForPositions(positionOffset, positions); + customMultiValueTarget.aggregate(buf, 3, positions, rows, positionOffset); + for (int i = 0; i < positions.length; i++) { + int position = positions[i] + positionOffset; + int row = rows[i]; + IndexedInts rowIndex = MULTI_VALUE_ROWS[row]; + if (rowIndex.size() == 0) { + Assert.assertNull(customMultiValueTarget.get(buf, position)); + } else { + Assert.assertEquals(multiValueSelector.lookupName(rowIndex.get(0)), customMultiValueTarget.get(buf, position)); } } } diff --git a/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java b/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java index ac6b5446144..72de5466ff6 100644 --- 
a/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java +++ b/processing/src/test/java/org/apache/druid/segment/virtual/FallbackVirtualColumnTest.java @@ -489,7 +489,7 @@ public class FallbackVirtualColumnTest } } - private static class SameMultiVectorSelector implements MultiValueDimensionVectorSelector + public static class SameMultiVectorSelector implements MultiValueDimensionVectorSelector { @Override public int getValueCardinality() diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java index efa3a9e7e32..21bcc833e04 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestAnySqlAggregator.java @@ -95,7 +95,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator String fieldName, String timeColumn, ColumnType type, - Integer maxStringBytes + Integer maxStringBytes, + Boolean aggregateMultipleValues ) { switch (type.getType()) { @@ -121,7 +122,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator String fieldName, String timeColumn, ColumnType type, - Integer maxStringBytes + Integer maxStringBytes, + Boolean aggregateMultipleValues ) { switch (type.getType()) { @@ -147,7 +149,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator String fieldName, String timeColumn, ColumnType type, - Integer maxStringBytes + Integer maxStringBytes, + Boolean aggregateMultipleValues ) { switch (type.getType()) { @@ -158,7 +161,7 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator case DOUBLE: return new DoubleAnyAggregatorFactory(name, fieldName); case STRING: - return new StringAnyAggregatorFactory(name, fieldName, maxStringBytes); + return new 
StringAnyAggregatorFactory(name, fieldName, maxStringBytes, aggregateMultipleValues); default: throw SimpleSqlAggregator.badTypeException(fieldName, "ANY", type); } @@ -170,7 +173,8 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator String fieldName, String timeColumn, ColumnType outputType, - Integer maxStringBytes + Integer maxStringBytes, + Boolean aggregateMultipleValues ); } @@ -244,37 +248,38 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator final AggregatorFactory theAggFactory; switch (args.size()) { case 1: - theAggFactory = aggregatorType.createAggregatorFactory(aggregatorName, fieldName, null, outputType, null); + theAggFactory = aggregatorType.createAggregatorFactory(aggregatorName, fieldName, null, outputType, null, true); break; case 2: - int maxStringBytes; - try { - maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); - } - catch (AssertionError ae) { - plannerContext.setPlanningError( - "The second argument '%s' to function '%s' is not a number", - rexNodes.get(1), - aggregateCall.getName() - ); - return null; - } + Integer maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); // added not null check at the function theAggFactory = aggregatorType.createAggregatorFactory( aggregatorName, fieldName, null, outputType, - maxStringBytes + maxStringBytes.intValue(), + true + ); + break; + case 3: + maxStringBytes = RexLiteral.intValue(rexNodes.get(1)); // added not null check at the function for rexNode 1,2 + boolean aggregateMultipleValues = RexLiteral.booleanValue(rexNodes.get(2)); + theAggFactory = aggregatorType.createAggregatorFactory( + aggregatorName, + fieldName, + null, + outputType, + maxStringBytes, + aggregateMultipleValues ); break; default: throw InvalidSqlInput.exception( - "Function [%s] expects 1 or 2 arguments but found [%s]", + "Function [%s] expects 1 or 2 or 3 arguments but found [%s]", aggregateCall.getName(), args.size() ); } - return Aggregation.create( 
Collections.singletonList(theAggFactory), finalizeAggregations ? new FinalizingFieldAccessPostAggregator(name, aggregatorName) : null @@ -372,10 +377,11 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator InferTypes.RETURN_TYPE, DefaultOperandTypeChecker .builder() - .operandNames("expr", "maxBytesPerString") - .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + .operandNames("expr", "maxBytesPerStringInt", "aggregateMultipleValuesBoolean") + .operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.BOOLEAN) .requiredOperandCount(1) - .literalOperands(1) + .literalOperands(1, 2) + .notNullOperands(1, 2) .build(), SqlFunctionCategory.USER_DEFINED_FUNCTION, false, @@ -402,9 +408,9 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator SqlParserPos pos = call.getParserPosition(); - if (operands.isEmpty() || operands.size() > 2) { + if (operands.isEmpty() || operands.size() > 3) { throw InvalidSqlInput.exception( - "Function [%s] expects 1 or 2 arguments but found [%s]", + "Function [%s] expects 1 or 2 or 3 arguments but found [%s]", getName(), operands.size() ); @@ -417,6 +423,9 @@ public class EarliestLatestAnySqlAggregator implements SqlAggregator if (operands.size() == 2) { newOperands.add(operands.get(1)); } + if (operands.size() == 3) { + newOperands.add(operands.get(2)); + } return replacementAggFunc.createCall(pos, newOperands); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java index 03e23503a81..fac88d853e1 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/EarliestLatestBySqlAggregator.java @@ -119,7 +119,8 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator rexNodes.get(1) ), 
outputType, - null + null, + true ); break; case 3: @@ -145,7 +146,8 @@ public class EarliestLatestBySqlAggregator implements SqlAggregator rexNodes.get(1) ), outputType, - maxStringBytes + maxStringBytes, + true ); break; default: diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java index f43fde3a935..a52ce5707c1 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/DefaultOperandTypeChecker.java @@ -188,6 +188,7 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker @Nullable private Integer requiredOperandCount; private int[] literalOperands; + private IntSet notNullOperands = new IntArraySet(); private Builder() { @@ -229,6 +230,12 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker return this; } + public Builder notNullOperands(final int... notNullOperands) + { + Arrays.stream(notNullOperands).forEach(this.notNullOperands::add); + return this; + } + public DefaultOperandTypeChecker build() { int computedRequiredOperandCount = requiredOperandCount == null ? 
operandTypes.size() : requiredOperandCount; @@ -236,16 +243,18 @@ public class DefaultOperandTypeChecker implements SqlOperandTypeChecker operandNames, operandTypes, computedRequiredOperandCount, - DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size()), + DefaultOperandTypeChecker.buildNullableOperands(computedRequiredOperandCount, operandTypes.size(), notNullOperands), literalOperands ); } } - public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount) + public static IntSet buildNullableOperands(int requiredOperandCount, int totalOperandCount, IntSet notNullOperands) { final IntSet nullableOperands = new IntArraySet(); - IntStream.range(requiredOperandCount, totalOperandCount).forEach(nullableOperands::add); + IntStream.range(requiredOperandCount, totalOperandCount) + .filter(i -> !notNullOperands.contains(i)) + .forEach(nullableOperands::add); return nullableOperands; } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java index d655e12f799..450d6208240 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java @@ -593,7 +593,7 @@ public class OperatorConversions { final IntSet nullableOperands = requiredOperandCount == null ? 
new IntArraySet() - : DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size()); + : DefaultOperandTypeChecker.buildNullableOperands(requiredOperandCount, operandTypes.size(), new IntArraySet()); if (operandTypeInference == null) { SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands); return (callBinding, returnType, types) -> { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index a8e16de2675..3ead14a05a3 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -497,7 +497,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest cannotVectorize(); testQuery( - "SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100) FROM foo WHERE (TIME_FLOOR(__time, 'PT1H'), m1) IN\n" + "SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100, true) FROM foo WHERE (TIME_FLOOR(__time, 'PT1H'), m1) IN\n" + " (\n" + " SELECT TIME_FLOOR(__time, 'PT1H') AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n" + " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n" @@ -532,7 +532,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest ) .setGranularity(Granularities.ALL) .setAggregatorSpecs(aggregators( - new StringAnyAggregatorFactory("a0", "dim3", 100) + new StringAnyAggregatorFactory("a0", "dim3", 100, true) )) .setContext(QUERY_CONTEXT_DEFAULT) .build() @@ -598,7 +598,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest ) .setGranularity(Granularities.ALL) .setAggregatorSpecs(aggregators( - new StringAnyAggregatorFactory("a0", "dim3", 100) + new StringAnyAggregatorFactory("a0", "dim3", 100, true) )) .setContext(QUERY_CONTEXT_DEFAULT) .build() @@ -609,6 +609,71 @@ public class CalciteJoinQueryTest extends 
BaseCalciteQueryTest ) ); } + @Test + public void testJoinOnGroupByInsteadOfTimeseriesWithFloorOnTimeWithNoAggregateMultipleValues() + { + // Cannot vectorize JOIN operator. + cannotVectorize(); + + testQuery( + "SELECT CAST(__time AS BIGINT), m1, ANY_VALUE(dim3, 100, false) FROM foo WHERE (CAST(TIME_FLOOR(__time, 'PT1H') AS BIGINT) + 1, m1) IN\n" + + " (\n" + + " SELECT CAST(TIME_FLOOR(__time, 'PT1H') AS BIGINT) + 1 AS t1, MIN(m1) AS t2 FROM foo WHERE dim3 = 'b'\n" + + " AND __time BETWEEN '1994-04-29 00:00:00' AND '2020-01-11 00:00:00' GROUP BY 1\n" + + " )\n" + + "GROUP BY 1, 2\n", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Intervals.of( + "1994-04-29/2020-01-11T00:00:00.001Z"))) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "(timestamp_floor(\"__time\",'PT1H',null,'UTC') + 1)", + ColumnType.LONG + ) + ) + .setDimFilter(equality("dim3", "b", ColumnType.STRING)) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions(new DefaultDimensionSpec( + "v0", + "d0", + ColumnType.LONG + ))) + .setAggregatorSpecs(aggregators( + new FloatMinAggregatorFactory("a0", "m1") + )) + .setContext(QUERY_CONTEXT_DEFAULT) + .build()), + "j0.", + "(((timestamp_floor(\"__time\",'PT1H',null,'UTC') + 1) == \"j0.d0\") && (\"m1\" == \"j0.a0\"))", + JoinType.INNER + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions( + new DefaultDimensionSpec("__time", "d0", ColumnType.LONG), + new DefaultDimensionSpec("m1", "d1", ColumnType.FLOAT) + ) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(aggregators( + new StringAnyAggregatorFactory("a0", "dim3", 100, false) + )) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{946684800000L, 1.0f, "a"}, // picks up first from [a, b] + new Object[]{946771200000L, 
2.0f, "b"} // picks up first from [b, c] + ) + ); + } @Test @Parameters(source = QueryContextForJoinProvider.class) @@ -1480,7 +1545,7 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest new SubstringDimExtractionFn(0, 1) ) ) - .setAggregatorSpecs(new StringAnyAggregatorFactory("a0", "v", 10)) + .setAggregatorSpecs(new StringAnyAggregatorFactory("a0", "v", 10, true)) .build() ), "j0.", diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 00ea933bb1e..1bca129757b 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -844,10 +844,10 @@ public class CalciteQueryTest extends BaseCalciteQueryTest new LongAnyAggregatorFactory("a0", "cnt"), new FloatAnyAggregatorFactory("a1", "m1"), new DoubleAnyAggregatorFactory("a2", "m2"), - new StringAnyAggregatorFactory("a3", "dim1", 10), + new StringAnyAggregatorFactory("a3", "dim1", 10, true), new LongAnyAggregatorFactory("a4", "v0"), new FloatAnyAggregatorFactory("a5", "v1"), - new StringAnyAggregatorFactory("a6", "v2", 10) + new StringAnyAggregatorFactory("a6", "v2", 10, true) ) ) .context(QUERY_CONTEXT_DEFAULT) @@ -1420,7 +1420,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .setAggregatorSpecs(aggregators(new StringAnyAggregatorFactory( "a0:a", "dim1", - 10 + 10, true ))) .setPostAggregatorSpecs( ImmutableList.of( @@ -1565,7 +1565,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .granularity(Granularities.ALL) .aggregators( aggregators( - new StringAnyAggregatorFactory("a0", "dim1", 32), + new StringAnyAggregatorFactory("a0", "dim1", 32, true), new LongAnyAggregatorFactory("a1", "l2"), new DoubleAnyAggregatorFactory("a2", "d2"), new FloatAnyAggregatorFactory("a3", "f2") @@ -1607,7 +1607,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .filters(filter) 
.aggregators( aggregators( - new StringAnyAggregatorFactory("a0", "dim1", 32), + new StringAnyAggregatorFactory("a0", "dim1", 32, true), new LongAnyAggregatorFactory("a1", "l2"), new DoubleAnyAggregatorFactory("a2", "d2"), new FloatAnyAggregatorFactory("a3", "f2") @@ -9422,7 +9422,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .granularity(Granularities.ALL) .aggregators( aggregators( - new StringAnyAggregatorFactory("a0", "dim1", 1024), + new StringAnyAggregatorFactory("a0", "dim1", 1024, true), new LongAnyAggregatorFactory("a1", "l1"), new StringFirstAggregatorFactory("a2", "dim1", null, 1024), new LongFirstAggregatorFactory("a3", "l1", null), @@ -9741,7 +9741,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest .setAggregatorSpecs( aggregators( new FilteredAggregatorFactory( - new StringAnyAggregatorFactory("a0", "dim1", 1024), + new StringAnyAggregatorFactory("a0", "dim1", 1024, true), equality("dim1", "nonexistent", ColumnType.STRING) ), new FilteredAggregatorFactory( @@ -13533,6 +13533,24 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testStringAnyAggArgValidation() + { + DruidException e = assertThrows(DruidException.class, () -> testBuilder() + .sql("SELECT ANY_VALUE(dim3, 1000, 'true') FROM foo") + .queryContext(ImmutableMap.of()) + .run()); + assertThat(e, invalidSqlIs( + "Cannot apply 'ANY_VALUE' to arguments of type 'ANY_VALUE(, , )'. 
Supported form(s): 'ANY_VALUE(, [, []])' (line [1], column [8])")); + DruidException e1 = assertThrows(DruidException.class, () -> testBuilder() + .sql("SELECT ANY_VALUE(dim3, 1000, null) FROM foo") + .queryContext(ImmutableMap.of()).run()); + Assert.assertEquals("Illegal use of 'NULL' (line [1], column [30])", e1.getMessage()); + DruidException e2 = assertThrows(DruidException.class, () -> testBuilder() + .sql("SELECT ANY_VALUE(dim3, null, true) FROM foo") + .queryContext(ImmutableMap.of()).run()); + Assert.assertEquals("Illegal use of 'NULL' (line [1], column [24])", e2.getMessage()); + } @Test public void testStringAggMaxBytes() { @@ -14367,7 +14385,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest new StringLastAggregatorFactory("a1", "dim1", "__time", 1024), new StringFirstAggregatorFactory("a2", "dim3", "__time", 1024), new StringFirstAggregatorFactory("a3", "dim1", "__time", 1024), - new StringAnyAggregatorFactory("a4", "dim3", 1024))) + new StringAnyAggregatorFactory("a4", "dim3", 1024, true))) .build() ), diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java b/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java index bb76488ccc4..eb8e7e67da7 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestBuilder.java @@ -323,4 +323,10 @@ public class QueryTestBuilder return build().resultsOnly(); } + public boolean isDecoupledMode() + { + String mode = (String) queryContext.getOrDefault(PlannerConfig.CTX_NATIVE_QUERY_SQL_PLANNING_MODE, ""); + return PlannerConfig.NATIVE_QUERY_SQL_PLANNING_MODE_DECOUPLED.equalsIgnoreCase(mode); + } + } diff --git a/website/.spelling b/website/.spelling index 002998c442c..14233798fef 100644 --- a/website/.spelling +++ b/website/.spelling @@ -2334,3 +2334,4 @@ LAST_VALUE markUnused markUsed segmentId +aggregateMultipleValues
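
Reviewer note: the tests in this diff pin down what the new `aggregateMultipleValues` flag does for multi-value dimensions (MVDs). The sketch below restates that behavior in plain Java. It is an illustration under stated assumptions, not the Druid implementation: the class `AnyValueMvdSketch` and the helper `pickAnyValue` are hypothetical names, and the rules are inferred only from the assertions in `testMvds`, `testMvdsWithCustomAggregate`, `testNonStringValue`, and the vector aggregator tests.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration only; not part of Druid.
final class AnyValueMvdSketch
{
  /**
   * Mirrors what the tests in this diff assert for a single row value:
   * - null or an empty multi-value row yields null
   * - a multi-value row with one element yields that element
   * - with aggregateMultipleValues = true, a longer multi-value row is rendered as a list string
   * - with aggregateMultipleValues = false, only the first element is taken
   * - any other value is converted with String.valueOf
   */
  static String pickAnyValue(Object rowValue, boolean aggregateMultipleValues)
  {
    if (rowValue == null) {
      return null;
    }
    if (rowValue instanceof List) {
      List<?> values = (List<?>) rowValue;
      if (values.isEmpty()) {
        return null;
      }
      if (values.size() == 1 || !aggregateMultipleValues) {
        Object first = values.get(0);
        return first == null ? null : String.valueOf(first);
      }
      return values.toString();
    }
    return String.valueOf(rowValue);
  }

  public static void main(String[] args)
  {
    List<String> mvd = Arrays.asList("AAAA", "AAAAB", "AAAC");
    System.out.println(pickAnyValue(mvd, true));   // prints [AAAA, AAAAB, AAAC]
    System.out.println(pickAnyValue(mvd, false));  // prints AAAA
    System.out.println(pickAnyValue(1.00, true));  // prints 1.0
  }
}
```

At the SQL layer the flag is the optional third argument, as in `ANY_VALUE(dim3, 100, false)` from the new join test. When it is omitted, the planner code in this diff passes `true`, which keeps the list-rendering behavior for MVDs.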