Sql Single Value Aggregator for scalar queries (#15700)

Executing single value correlated queries will throw an exception today since single_value function is not available in druid.
With these added classes, this provides druid, the capability to plan and run such queries.
This commit is contained in:
Sree Charan Manamala 2024-02-08 19:20:30 +05:30 committed by GitHub
parent f3996b96ff
commit 57e12df352
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1152 additions and 10 deletions

View File

@ -43,6 +43,7 @@ import org.apache.druid.query.aggregation.SerializablePairLongDoubleComplexMetri
import org.apache.druid.query.aggregation.SerializablePairLongFloatComplexMetricSerde;
import org.apache.druid.query.aggregation.SerializablePairLongLongComplexMetricSerde;
import org.apache.druid.query.aggregation.SerializablePairLongStringComplexMetricSerde;
import org.apache.druid.query.aggregation.SingleValueAggregatorFactory;
import org.apache.druid.query.aggregation.any.DoubleAnyAggregatorFactory;
import org.apache.druid.query.aggregation.any.FloatAnyAggregatorFactory;
import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory;
@ -84,11 +85,23 @@ public class AggregatorsModule extends SimpleModule
ComplexMetrics.registerSerde(HyperUniquesSerde.TYPE_NAME, new HyperUniquesSerde());
ComplexMetrics.registerSerde(PreComputedHyperUniquesSerde.TYPE_NAME, new PreComputedHyperUniquesSerde());
ComplexMetrics.registerSerde(SerializablePairLongStringComplexMetricSerde.TYPE_NAME, new SerializablePairLongStringComplexMetricSerde());
ComplexMetrics.registerSerde(
SerializablePairLongStringComplexMetricSerde.TYPE_NAME,
new SerializablePairLongStringComplexMetricSerde()
);
ComplexMetrics.registerSerde(SerializablePairLongFloatComplexMetricSerde.TYPE_NAME, new SerializablePairLongFloatComplexMetricSerde());
ComplexMetrics.registerSerde(SerializablePairLongDoubleComplexMetricSerde.TYPE_NAME, new SerializablePairLongDoubleComplexMetricSerde());
ComplexMetrics.registerSerde(SerializablePairLongLongComplexMetricSerde.TYPE_NAME, new SerializablePairLongLongComplexMetricSerde());
ComplexMetrics.registerSerde(
SerializablePairLongFloatComplexMetricSerde.TYPE_NAME,
new SerializablePairLongFloatComplexMetricSerde()
);
ComplexMetrics.registerSerde(
SerializablePairLongDoubleComplexMetricSerde.TYPE_NAME,
new SerializablePairLongDoubleComplexMetricSerde()
);
ComplexMetrics.registerSerde(
SerializablePairLongLongComplexMetricSerde.TYPE_NAME,
new SerializablePairLongLongComplexMetricSerde()
);
setMixInAnnotation(AggregatorFactory.class, AggregatorFactoryMixin.class);
setMixInAnnotation(PostAggregator.class, PostAggregatorMixin.class);
@ -129,7 +142,8 @@ public class AggregatorsModule extends SimpleModule
@JsonSubTypes.Type(name = "doubleAny", value = DoubleAnyAggregatorFactory.class),
@JsonSubTypes.Type(name = "stringAny", value = StringAnyAggregatorFactory.class),
@JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class),
@JsonSubTypes.Type(name = "expression", value = ExpressionLambdaAggregatorFactory.class)
@JsonSubTypes.Type(name = "expression", value = ExpressionLambdaAggregatorFactory.class),
@JsonSubTypes.Type(name = "singleValue", value = SingleValueAggregatorFactory.class)
})
public interface AggregatorFactoryMixin
{

View File

@ -158,6 +158,7 @@ public class AggregatorUtil
public static final byte ARRAY_OF_DOUBLES_SKETCH_TO_BASE64_STRING_CACHE_TYPE_ID = 0x4C;
public static final byte ARRAY_OF_DOUBLES_SKETCH_CONSTANT_SKETCH_CACHE_TYPE_ID = 0x4D;
public static final byte ARRAY_OF_DOUBLES_SKETCH_TO_METRICS_SUM_ESTIMATE_CACHE_TYPE_ID = 0x4E;
public static final byte SINGLE_VALUE_CACHE_TYPE_ID = 0x4F;
// DDSketch aggregator
public static final byte DDSKETCH_CACHE_TYPE_ID = 0x50;
@ -165,15 +166,15 @@ public class AggregatorUtil
/**
* Given a list of PostAggregators and the name of an output column, returns the minimal list of PostAggregators
* required to compute the output column.
*
* <p>
* If the outputColumn does not exist in the list of PostAggregators, the return list will be empty (under the
* assumption that the outputColumn comes from a project, aggregation or really anything other than a
* PostAggregator).
*
* <p>
* If the outputColumn <strong>does</strong> exist in the list of PostAggregators, then the return list will have at
* least one element. If the PostAggregator with outputName depends on any other PostAggregators, then the returned
* list will contain all PostAggregators required to compute the outputColumn.
*
* <p>
* Note that PostAggregators are processed in list-order, meaning that for a PostAggregator to depend on another
* PostAggregator, the "depender" must exist *after* the "dependee" in the list. That is, if PostAggregator A
* depends on PostAggregator B, then the list should be [B, A], such that A is computed after B.
@ -181,8 +182,7 @@ public class AggregatorUtil
* @param postAggregatorList List of postAggregator, there is a restriction that the list should be in an order such
* that all the dependencies of any given aggregator should occur before that aggregator.
* See AggregatorUtilTest.testOutOfOrderPruneDependentPostAgg for example.
* @param outputName name of the postAgg on which dependency is to be calculated
*
* @param outputName name of the postAgg on which dependency is to be calculated
* @return the list of dependent postAggregators
*/
public static List<PostAggregator> pruneDependentPostAgg(List<PostAggregator> postAggregatorList, String outputName)

View File

@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.error.InvalidInput;
import org.apache.druid.segment.ColumnValueSelector;
import javax.annotation.Nullable;
public class SingleValueAggregator implements Aggregator
{
final ColumnValueSelector selector;
@Nullable
private Object value;
private boolean isAggregateInvoked = false;
public SingleValueAggregator(ColumnValueSelector selector)
{
this.selector = selector;
}
@Override
public void aggregate()
{
if (isAggregateInvoked) {
throw InvalidInput.exception("Subquery expression returned more than one row");
}
value = selector.getObject();
isAggregateInvoked = true;
}
@Override
public Object get()
{
return value;
}
@Override
public float getFloat()
{
assert validObjectValue();
return (value == null) ? NullHandling.ZERO_FLOAT : ((Number) value).floatValue();
}
@Override
public long getLong()
{
assert validObjectValue();
return (value == null) ? NullHandling.ZERO_LONG : ((Number) value).longValue();
}
@Override
public double getDouble()
{
assert validObjectValue();
return (value == null) ? NullHandling.ZERO_DOUBLE : ((Number) value).doubleValue();
}
@Override
public boolean isNull()
{
return NullHandling.sqlCompatible() && value == null;
}
private boolean validObjectValue()
{
return NullHandling.replaceWithDefault() || !isNull();
}
@Override
public void close()
{
// no resources to cleanup
}
@Override
public String toString()
{
return "SingleValueAggregator{" +
"selector=" + selector +
", value=" + value +
", isAggregateInvoked=" + isAggregateInvoked +
'}';
}
}

View File

@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import org.apache.druid.error.DruidException;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnType;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
/**
* This AggregatorFactory is meant to wrap the subquery used as an expression into a single value
* and is expected to throw an exception when the subquery results in more than one row
*
* <p>
* This consumes columnType as well along with name and fieldName to pass it on to underlying
* {@link SingleValueBufferAggregator} to work with different ColumnTypes
*/
@JsonTypeName("singleValue")
public class SingleValueAggregatorFactory extends AggregatorFactory
{
@JsonProperty
private final String name;
@JsonProperty
private final String fieldName;
@JsonProperty
private final ColumnType columnType;
public static final int DEFAULT_MAX_VALUE_SIZE = 1024;
@JsonCreator
public SingleValueAggregatorFactory(
@JsonProperty("name") String name,
@JsonProperty("fieldName") final String fieldName,
@JsonProperty("columnType") final ColumnType columnType
)
{
this.name = Preconditions.checkNotNull(name, "name");
this.fieldName = Preconditions.checkNotNull(fieldName, "fieldName");
this.columnType = Preconditions.checkNotNull(columnType, "columnType");
}
@Override
public Aggregator factorize(ColumnSelectorFactory metricFactory)
{
ColumnValueSelector selector = metricFactory.makeColumnValueSelector(fieldName);
return new SingleValueAggregator(selector);
}
@Override
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
{
ColumnValueSelector selector = metricFactory.makeColumnValueSelector(fieldName);
ColumnCapabilities columnCapabilities = metricFactory.getColumnCapabilities(fieldName);
if (columnCapabilities == null) {
throw DruidException.forPersona(DruidException.Persona.DEVELOPER)
.ofCategory(DruidException.Category.DEFENSIVE)
.build("Unable to get the capabilities of field [%s]", fieldName);
}
ColumnType columnType = new ColumnType(columnCapabilities.getType(), null, null);
return new SingleValueBufferAggregator(selector, columnType);
}
@Override
public Comparator getComparator()
{
throw DruidException.defensive("Single Value Aggregator would not have more than one row to compare");
}
/**
* Combine method would never be invoked as the broker sends the subquery to multiple segments
* and gather the results to a single value on which the single value aggregator is applied.
* Though getCombiningFactory would be invoked for understanding the fieldname.
*/
@Override
@Nullable
public Object combine(@Nullable Object lhs, @Nullable Object rhs)
{
throw DruidException.defensive("Single Value Aggregator would not have more than one row to combine");
}
@Override
public AggregatorFactory getCombiningFactory()
{
return new SingleValueAggregatorFactory(name, name, columnType);
}
@Override
public Object deserialize(Object object)
{
return object;
}
@Nullable
@Override
public Object finalizeComputation(@Nullable Object object)
{
return object;
}
@Override
public ColumnType getIntermediateType()
{
return columnType;
}
@Override
public ColumnType getResultType()
{
return columnType;
}
@Override
@JsonProperty
public String getName()
{
return name;
}
@JsonProperty
public String getFieldName()
{
return fieldName;
}
@Override
public List<String> requiredFields()
{
return Collections.singletonList(fieldName);
}
@Override
public int getMaxIntermediateSize()
{
// keeping 8 bytes for all numerics to make code look simple. This would store only a single value.
return Byte.BYTES + (columnType.isNumeric() ? Double.BYTES : DEFAULT_MAX_VALUE_SIZE);
}
@Override
public byte[] getCacheKey()
{
return new byte[]{AggregatorUtil.SINGLE_VALUE_CACHE_TYPE_ID};
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SingleValueAggregatorFactory that = (SingleValueAggregatorFactory) o;
return Objects.equals(name, that.name)
&& Objects.equals(fieldName, that.fieldName)
&& Objects.equals(columnType, that.columnType);
}
@Override
public int hashCode()
{
return Objects.hash(name, fieldName, columnType);
}
@Override
public String toString()
{
return "SingleValueAggregatorFactory{" +
"name='" + name + '\'' +
", fieldName='" + fieldName + '\'' +
", columnType=" + columnType +
'}';
}
}

View File

@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.error.DruidException;
import org.apache.druid.error.InvalidInput;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.NullableTypeStrategy;
import org.apache.druid.segment.column.TypeStrategies;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
public class SingleValueBufferAggregator implements BufferAggregator
{
private final ColumnValueSelector selector;
private final ColumnType columnType;
private final NullableTypeStrategy typeStrategy;
private boolean isAggregateInvoked = false;
public SingleValueBufferAggregator(ColumnValueSelector selector, ColumnType columnType)
{
this.selector = selector;
this.columnType = columnType;
this.typeStrategy = columnType.getNullableStrategy();
}
@Override
public void init(ByteBuffer buf, int position)
{
buf.put(position, NullHandling.IS_NULL_BYTE);
}
@Override
public void aggregate(ByteBuffer buf, int position)
{
if (isAggregateInvoked) {
throw InvalidInput.exception("Subquery expression returned more than one row");
}
int maxbufferSixe = Byte.BYTES + (columnType.isNumeric()
? Double.BYTES
: SingleValueAggregatorFactory.DEFAULT_MAX_VALUE_SIZE);
int written = typeStrategy.write(
buf,
position,
getSelectorObject(),
maxbufferSixe
);
if (written < 0) {
throw DruidException.forPersona(DruidException.Persona.ADMIN)
.ofCategory(DruidException.Category.RUNTIME_FAILURE)
.build("Subquery result exceeds the buffer limit [%s]", maxbufferSixe);
}
isAggregateInvoked = true;
}
@Nullable
private Object getSelectorObject()
{
if (columnType.isNumeric() && selector.isNull()) {
return null;
}
switch (columnType.getType()) {
case LONG:
return selector.getLong();
case FLOAT:
return selector.getFloat();
case DOUBLE:
return selector.getDouble();
default:
return selector.getObject();
}
}
@Nullable
@Override
public Object get(ByteBuffer buf, int position)
{
return typeStrategy.read(buf, position);
}
@Override
public float getFloat(ByteBuffer buf, int position)
{
return TypeStrategies.isNullableNull(buf, position)
? NullHandling.ZERO_FLOAT
: TypeStrategies.readNotNullNullableFloat(buf, position);
}
@Override
public double getDouble(ByteBuffer buf, int position)
{
return TypeStrategies.isNullableNull(buf, position)
? NullHandling.ZERO_DOUBLE
: TypeStrategies.readNotNullNullableDouble(buf, position);
}
@Override
public long getLong(ByteBuffer buf, int position)
{
return TypeStrategies.isNullableNull(buf, position)
? NullHandling.ZERO_LONG
: TypeStrategies.readNotNullNullableLong(buf, position);
}
@Override
public void close()
{
// no resources to cleanup
}
}

View File

@ -0,0 +1,307 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.error.DruidException;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.TestColumnSelectorFactory;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.nio.ByteBuffer;
public class SingleValueAggregationTest extends InitializedNullHandlingTest
{
private SingleValueAggregatorFactory longAggFactory;
private ColumnSelectorFactory colSelectorFactoryLong;
private ColumnCapabilities columnCapabilitiesLong;
private TestLongColumnSelector selectorLong;
private SingleValueAggregatorFactory doubleAggFactory;
private ColumnSelectorFactory colSelectorFactoryDouble;
private ColumnCapabilities columnCapabilitiesDouble;
private TestDoubleColumnSelectorImpl selectorDouble;
private SingleValueAggregatorFactory floatAggFactory;
private ColumnSelectorFactory colSelectorFactoryFloat;
private ColumnCapabilities columnCapabilitiesFloat;
private TestFloatColumnSelector selectorFloat;
private SingleValueAggregatorFactory stringAggFactory;
private ColumnSelectorFactory colSelectorFactoryString;
private ColumnCapabilities columnCapabilitiesString;
private TestObjectColumnSelector selectorString;
private final long[] longValues = {9223372036854775802L, 9223372036854775803L};
private final double[] doubleValues = {5.2d, 2.8976552d};
private final float[] floatValues = {5.2f, 2.89f};
private final String[] strValues = {"str1", "str2"};
public SingleValueAggregationTest() throws Exception
{
String longAggSpecJson = "{\"type\": \"singleValue\", \"name\": \"lng\", \"fieldName\": \"lngFld\", \"columnType\": \"LONG\"}";
longAggFactory = TestHelper.makeJsonMapper().readValue(longAggSpecJson, SingleValueAggregatorFactory.class);
String doubleAggSpecJson = "{\"type\": \"singleValue\", \"name\": \"dbl\", \"fieldName\": \"dblFld\", \"columnType\": \"DOUBLE\"}";
doubleAggFactory = TestHelper.makeJsonMapper().readValue(doubleAggSpecJson, SingleValueAggregatorFactory.class);
String floatAggSpecJson = "{\"type\": \"singleValue\", \"name\": \"dbl\", \"fieldName\": \"fltFld\", \"columnType\": \"FLOAT\"}";
floatAggFactory = TestHelper.makeJsonMapper().readValue(floatAggSpecJson, SingleValueAggregatorFactory.class);
String strAggSpecJson = "{\"type\": \"singleValue\", \"name\": \"str\", \"fieldName\": \"strFld\", \"columnType\": \"STRING\"}";
stringAggFactory = TestHelper.makeJsonMapper().readValue(strAggSpecJson, SingleValueAggregatorFactory.class);
}
@Before
public void setup()
{
selectorLong = new TestLongColumnSelector(longValues);
columnCapabilitiesLong = ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG);
colSelectorFactoryLong = new TestColumnSelectorFactory()
.addCapabilities("lngFld", columnCapabilitiesLong)
.addColumnSelector("lngFld", selectorLong);
selectorDouble = new TestDoubleColumnSelectorImpl(doubleValues);
columnCapabilitiesDouble = ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.DOUBLE);
colSelectorFactoryDouble = new TestColumnSelectorFactory()
.addCapabilities("dblFld", columnCapabilitiesDouble)
.addColumnSelector("dblFld", selectorDouble);
selectorFloat = new TestFloatColumnSelector(floatValues);
columnCapabilitiesFloat = ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.FLOAT);
colSelectorFactoryFloat = new TestColumnSelectorFactory()
.addCapabilities("fltFld", columnCapabilitiesFloat)
.addColumnSelector("fltFld", selectorFloat);
selectorString = new TestObjectColumnSelector(strValues);
columnCapabilitiesString = ColumnCapabilitiesImpl.createSimpleSingleValueStringColumnCapabilities();
colSelectorFactoryString = new TestColumnSelectorFactory()
.addCapabilities("strFld", columnCapabilitiesString)
.addColumnSelector("strFld", selectorString);
}
@Test
public void testLongAggregator()
{
Assert.assertEquals(ColumnType.LONG, longAggFactory.getIntermediateType());
Assert.assertEquals(ColumnType.LONG, longAggFactory.getResultType());
Assert.assertEquals("lng", longAggFactory.getName());
Assert.assertEquals("lngFld", longAggFactory.getFieldName());
Assert.assertThrows(DruidException.class, () -> longAggFactory.getComparator());
Aggregator agg = longAggFactory.factorize(colSelectorFactoryLong);
if (NullHandling.replaceWithDefault()) {
Assert.assertFalse(agg.isNull());
Assert.assertEquals(0L, agg.getLong());
} else {
Assert.assertTrue(agg.isNull());
Assert.assertThrows(AssertionError.class, () -> agg.getLong());
}
aggregate(selectorLong, agg);
Assert.assertEquals(longValues[0], ((Long) agg.get()).longValue());
Assert.assertEquals(longValues[0], agg.getLong());
Assert.assertThrows(DruidException.class, () -> aggregate(selectorLong, agg));
}
@Test
public void testLongBufferAggregator()
{
BufferAggregator agg = longAggFactory.factorizeBuffered(colSelectorFactoryLong);
ByteBuffer buffer = ByteBuffer.wrap(new byte[Double.BYTES + Byte.BYTES]);
agg.init(buffer, 0);
Assert.assertEquals(0L, agg.getLong(buffer, 0));
aggregate(selectorLong, agg, buffer, 0);
Assert.assertEquals(longValues[0], ((Long) agg.get(buffer, 0)).longValue());
Assert.assertEquals(longValues[0], agg.getLong(buffer, 0));
Assert.assertThrows(DruidException.class, () -> aggregate(selectorLong, agg, buffer, 0));
}
@Test
public void testCombine()
{
Assert.assertThrows(DruidException.class, () -> longAggFactory.combine(9223372036854775800L, 9223372036854775803L));
}
@Test
public void testDoubleAggregator()
{
Aggregator agg = doubleAggFactory.factorize(colSelectorFactoryDouble);
if (NullHandling.replaceWithDefault()) {
Assert.assertEquals(0.0d, agg.getDouble(), 0.000001);
} else {
Assert.assertThrows(AssertionError.class, () -> agg.getDouble());
}
aggregate(selectorDouble, agg);
Assert.assertEquals(doubleValues[0], ((Double) agg.get()).doubleValue(), 0.000001);
Assert.assertEquals(doubleValues[0], agg.getDouble(), 0.000001);
Assert.assertThrows(DruidException.class, () -> aggregate(selectorDouble, agg));
}
@Test
public void testDoubleBufferAggregator()
{
BufferAggregator agg = doubleAggFactory.factorizeBuffered(colSelectorFactoryDouble);
ByteBuffer buffer = ByteBuffer.wrap(new byte[SingleValueAggregatorFactory.DEFAULT_MAX_VALUE_SIZE + Byte.BYTES]);
agg.init(buffer, 0);
Assert.assertEquals(0.0d, agg.getDouble(buffer, 0), 0.000001);
aggregate(selectorDouble, agg, buffer, 0);
Assert.assertEquals(doubleValues[0], ((Double) agg.get(buffer, 0)).doubleValue(), 0.000001);
Assert.assertEquals(doubleValues[0], agg.getDouble(buffer, 0), 0.000001);
Assert.assertThrows(DruidException.class, () -> aggregate(selectorDouble, agg, buffer, 0));
}
@Test
public void testFloatAggregator()
{
Aggregator agg = floatAggFactory.factorize(colSelectorFactoryFloat);
if (NullHandling.replaceWithDefault()) {
Assert.assertEquals(0.0f, agg.getFloat(), 0.000001);
} else {
Assert.assertThrows(AssertionError.class, () -> agg.getFloat());
}
aggregate(selectorFloat, agg);
Assert.assertEquals(floatValues[0], ((Float) agg.get()).floatValue(), 0.000001);
Assert.assertEquals(floatValues[0], agg.getFloat(), 0.000001);
Assert.assertThrows(DruidException.class, () -> aggregate(selectorFloat, agg));
}
@Test
public void testFloatBufferAggregator()
{
BufferAggregator agg = floatAggFactory.factorizeBuffered(colSelectorFactoryFloat);
ByteBuffer buffer = ByteBuffer.wrap(new byte[Double.BYTES + Byte.BYTES]);
agg.init(buffer, 0);
Assert.assertEquals(0.0f, agg.getFloat(buffer, 0), 0.000001);
aggregate(selectorFloat, agg, buffer, 0);
Assert.assertEquals(floatValues[0], ((Float) agg.get(buffer, 0)).floatValue(), 0.000001);
Assert.assertEquals(floatValues[0], agg.getFloat(buffer, 0), 0.000001);
Assert.assertThrows(DruidException.class, () -> aggregate(selectorFloat, agg, buffer, 0));
}
@Test
public void testStringAggregator()
{
Aggregator agg = stringAggFactory.factorize(colSelectorFactoryString);
Assert.assertEquals(null, agg.get());
aggregate(selectorString, agg);
Assert.assertEquals(strValues[0], agg.get());
Assert.assertThrows(DruidException.class, () -> aggregate(selectorString, agg));
}
@Test
public void testStringBufferAggregator()
{
BufferAggregator agg = stringAggFactory.factorizeBuffered(colSelectorFactoryString);
ByteBuffer buffer = ByteBuffer.wrap(new byte[SingleValueAggregatorFactory.DEFAULT_MAX_VALUE_SIZE + Byte.BYTES]);
agg.init(buffer, 0);
aggregate(selectorString, agg, buffer, 0);
Assert.assertEquals(strValues[0], agg.get(buffer, 0));
Assert.assertThrows(DruidException.class, () -> aggregate(selectorString, agg, buffer, 0));
}
@Test
public void testEqualsAndHashCode()
{
SingleValueAggregatorFactory one = new SingleValueAggregatorFactory("name1", "fieldName1", ColumnType.LONG);
SingleValueAggregatorFactory oneMore = new SingleValueAggregatorFactory("name1", "fieldName1", ColumnType.LONG);
SingleValueAggregatorFactory two = new SingleValueAggregatorFactory("name2", "fieldName2", ColumnType.LONG);
Assert.assertEquals(one.hashCode(), oneMore.hashCode());
Assert.assertTrue(one.equals(oneMore));
Assert.assertFalse(one.equals(two));
}
private void aggregate(TestLongColumnSelector selector, Aggregator agg)
{
agg.aggregate();
selector.increment();
}
private void aggregate(TestLongColumnSelector selector, BufferAggregator agg, ByteBuffer buff, int position)
{
agg.aggregate(buff, position);
selector.increment();
}
private void aggregate(TestFloatColumnSelector selector, Aggregator agg)
{
agg.aggregate();
selector.increment();
}
private void aggregate(TestFloatColumnSelector selector, BufferAggregator agg, ByteBuffer buff, int position)
{
agg.aggregate(buff, position);
selector.increment();
}
private void aggregate(TestDoubleColumnSelectorImpl selector, Aggregator agg)
{
agg.aggregate();
selector.increment();
}
private void aggregate(TestDoubleColumnSelectorImpl selector, BufferAggregator agg, ByteBuffer buff, int position)
{
agg.aggregate(buff, position);
selector.increment();
}
private void aggregate(TestObjectColumnSelector selector, Aggregator agg)
{
agg.aggregate();
selector.increment();
}
private void aggregate(TestObjectColumnSelector selector, BufferAggregator agg, ByteBuffer buff, int position)
{
agg.aggregate(buff, position);
selector.increment();
}
}

View File

@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.SingleValueAggregatorFactory;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.planner.Calcites;
import javax.annotation.Nullable;
/**
* This class serves as binding for Single Value Aggregator.
* Returns a single value in cases of subqueries used in expressions
*/
public class SingleValueSqlAggregator extends SimpleSqlAggregator
{
@Override
public SqlAggFunction calciteFunction()
{
return SqlStdOperatorTable.SINGLE_VALUE;
}
@Override
@Nullable
Aggregation getAggregation(
final String name,
final AggregateCall aggregateCall,
final ExprMacroTable macroTable,
final String fieldName
)
{
final ColumnType valueType = Calcites.getColumnTypeForRelDataType(aggregateCall.getType());
if (valueType == null) {
return null;
}
return Aggregation.create(createSingleValueAggregatorFactory(
valueType,
name,
fieldName
));
}
static AggregatorFactory createSingleValueAggregatorFactory(
final ColumnType aggregationType,
final String name,
final String fieldName
)
{
return new SingleValueAggregatorFactory(name, fieldName, aggregationType);
}
}

View File

@ -46,6 +46,7 @@ import org.apache.druid.sql.calcite.aggregation.builtin.GroupingSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.LiteralSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.MaxSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.MinSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.SingleValueSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.StringSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.SumSqlAggregator;
import org.apache.druid.sql.calcite.aggregation.builtin.SumZeroSqlAggregator;
@ -172,6 +173,7 @@ public class DruidOperatorTable implements SqlOperatorTable
.add(new BitwiseSqlAggregator(BitwiseSqlAggregator.Op.AND))
.add(new BitwiseSqlAggregator(BitwiseSqlAggregator.Op.OR))
.add(new BitwiseSqlAggregator(BitwiseSqlAggregator.Op.XOR))
.add(new SingleValueSqlAggregator())
.build();
// STRLEN has so many aliases.

View File

@ -36,16 +36,21 @@ import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory;
import org.apache.druid.query.aggregation.FloatMinAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongMinAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.SingleValueAggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.ExtractionDimensionSpec;
import org.apache.druid.query.extraction.SubstringDimExtractionFn;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
import org.apache.druid.query.groupby.orderby.NoopLimitSpec;
import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.query.scan.ScanQuery;
@ -64,8 +69,10 @@ import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@ -1042,4 +1049,306 @@ public class CalciteSubqueryTest extends BaseCalciteQueryTest
results
);
}
@Test
public void testSingleValueFloatAgg()
{
skipVectorize();
cannotVectorize();
testQuery(
"SELECT count(*) FROM foo where m1 <= (select min(m1) + 4 from foo)",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(join(
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(GroupByQuery.builder()
.setDataSource(new QueryDataSource(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(new FloatMinAggregatorFactory("a0", "m1"))
.build()
))
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setVirtualColumns(expressionVirtualColumn(
"v0",
"(\"a0\" + 4)",
ColumnType.FLOAT
)
)
.setAggregatorSpecs(
aggregators(
new SingleValueAggregatorFactory(
"_a0",
"v0",
ColumnType.FLOAT
)
)
)
.setLimitSpec(NoopLimitSpec.instance())
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
"j0.",
"1",
NullHandling.replaceWithDefault() ? JoinType.LEFT : JoinType.INNER
))
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.filters(expressionFilter("(\"m1\" <= \"j0._a0\")"))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{5L}
)
);
}
@Test
public void testSingleValueDoubleAgg()
{
skipVectorize();
cannotVectorize();
testQuery(
"SELECT count(*) FROM foo where m1 >= (select max(m1) - 3.5 from foo)",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(join(
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(GroupByQuery.builder()
.setDataSource(new QueryDataSource(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(new FloatMaxAggregatorFactory("a0", "m1"))
.build()
))
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setVirtualColumns(expressionVirtualColumn(
"v0",
"(\"a0\" - 3.5)",
ColumnType.DOUBLE
)
)
.setAggregatorSpecs(
aggregators(
new SingleValueAggregatorFactory(
"_a0",
"v0",
ColumnType.DOUBLE
)
)
)
.setLimitSpec(NoopLimitSpec.instance())
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
"j0.",
"1",
NullHandling.replaceWithDefault() ? JoinType.LEFT : JoinType.INNER
))
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.filters(expressionFilter("(\"m1\" >= \"j0._a0\")"))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{4L}
)
);
}
@Test
public void testSingleValueLongAgg()
{
skipVectorize();
cannotVectorize();
testQuery(
"SELECT count(*) FROM wikipedia where __time >= (select max(__time) - INTERVAL '10' MINUTE from wikipedia)",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(join(
new TableDataSource(CalciteTests.WIKIPEDIA),
new QueryDataSource(GroupByQuery.builder()
.setDataSource(new QueryDataSource(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.WIKIPEDIA)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(new LongMaxAggregatorFactory(
"a0",
"__time"
))
.build()
))
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setVirtualColumns(expressionVirtualColumn(
"v0",
"(\"a0\" - 600000)",
ColumnType.LONG
)
)
.setAggregatorSpecs(
aggregators(
new SingleValueAggregatorFactory(
"_a0",
"v0",
ColumnType.LONG
)
)
)
.setLimitSpec(NoopLimitSpec.instance())
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
"j0.",
"1",
NullHandling.replaceWithDefault() ? JoinType.LEFT : JoinType.INNER
))
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.filters(expressionFilter("(\"__time\" >= \"j0._a0\")"))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{220L}
)
);
}
@Test
public void testSingleValueStringAgg()
{
skipVectorize();
cannotVectorize();
testQuery(
"SELECT count(*) FROM wikipedia where channel = (select channel from wikipedia order by __time desc LIMIT 1 OFFSET 6)",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(join(
new TableDataSource(CalciteTests.WIKIPEDIA),
new QueryDataSource(GroupByQuery.builder()
.setDataSource(new QueryDataSource(
Druids.newScanQueryBuilder()
.dataSource(CalciteTests.WIKIPEDIA)
.intervals(querySegmentSpec(Filtration.eternity()))
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.offset(6L)
.limit(1L)
.order(ScanQuery.Order.DESCENDING)
.columns("__time", "channel")
.legacy(false)
.context(QUERY_CONTEXT_DEFAULT)
.build()
))
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setVirtualColumns(expressionVirtualColumn(
"v0",
"\"channel\"",
ColumnType.STRING
)
)
.setAggregatorSpecs(
aggregators(
new SingleValueAggregatorFactory(
"a0",
"v0",
ColumnType.STRING
)
)
)
.setLimitSpec(NoopLimitSpec.instance())
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
"j0.",
"(\"channel\" == \"j0.a0\")",
JoinType.INNER
))
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1256L}
)
);
}
@Test
public void testSingleValueStringMultipleRowsAgg()
{
skipVectorize();
cannotVectorize();
testQueryThrows(
"SELECT count(*) FROM wikipedia where channel = (select channel from wikipedia order by __time desc LIMIT 2 OFFSET 6)",
exception -> exception.expectMessage("Subquery expression returned more than one row")
);
}
@Test
public void testSingleValueEmptyInnerAgg()
{
skipVectorize();
cannotVectorize();
testQuery(
"SELECT distinct countryName FROM wikipedia where countryName = ( select countryName from wikipedia where channel in ('abc', 'xyz'))",
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(join(
new TableDataSource(CalciteTests.WIKIPEDIA),
new QueryDataSource(Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.WIKIPEDIA)
.intervals(querySegmentSpec(Filtration.eternity()))
.granularity(Granularities.ALL)
.virtualColumns(expressionVirtualColumn(
"v0",
"\"countryName\"",
ColumnType.STRING
)
)
.aggregators(
new SingleValueAggregatorFactory(
"a0",
"v0",
ColumnType.STRING
)
)
.filters(new InDimFilter(
"channel",
new HashSet<>(Arrays.asList(
"abc",
"xyz"
))
))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
"j0.",
"(\"countryName\" == \"j0.a0\")",
JoinType.INNER
))
.addDimension(new DefaultDimensionSpec("countryName", "d0", ColumnType.STRING))
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of()
);
}
}