Graceful null handling and correctness in DoubleMean Aggregator (#12320)

* Adding null handling for double mean aggregator

* Updating code to handle nulls in DoubleMean aggregator

* Oops, the last commit had checkstyle issues. Fixed

* Updating some code and test cases

* Checking whether the object is null in the case of a numeric aggregator

* Adding one more test to improve coverage

* Changing one test as asked in the review

* Changing one test as asked in the review for nulls
Author: somu-imply
Committed: 2022-03-14 16:52:47 -07:00 (via GitHub)
Parent: 3de1272926
Commit: b5195c5095
5 changed files with 113 additions and 5 deletions
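
The behavioral difference driving this change: in default-value mode a null input is replaced with 0.0 and still counts as a row, while in SQL-compatible mode it is dropped from both the sum and the row count. That is why the new tests below assert a different expected mean per mode (39.2307 vs. 51.0). A toy illustration of the arithmetic (illustrative values only, not Druid code):

    // mean of {1.0, null, 2.0} under the two null-handling modes
    double defaultModeMean = (1.0 + 0.0 + 2.0) / 3;  // 1.0 -- null replaced with 0.0, still counted
    double sqlModeMean = (1.0 + 2.0) / 2;            // 1.5 -- null skipped entirely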

File: DoubleMeanAggregator.java

@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.mean;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.segment.ColumnValueSelector;
@@ -43,6 +44,10 @@ public class DoubleMeanAggregator implements Aggregator
  {
    Object update = selector.getObject();
    // In SQL-compatible mode, a null input contributes nothing to the mean.
    if (update == null && !NullHandling.replaceWithDefault()) {
      return;
    }
    if (update instanceof DoubleMeanHolder) {
      value.update((DoubleMeanHolder) update);
    } else if (update instanceof List) {

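The hunk above ends mid-method. For orientation, a minimal sketch of how the full aggregate() plausibly reads after this change; the List and fall-through branches are inferred from the instanceof checks and the Numbers import rather than shown in the diff:

    @Override
    public void aggregate()
    {
      Object update = selector.getObject();
      // SQL-compatible mode: a null input contributes nothing to the mean.
      if (update == null && !NullHandling.replaceWithDefault()) {
        return;
      }
      if (update instanceof DoubleMeanHolder) {
        // Merge a partially computed mean (sum and count) from another aggregator.
        value.update((DoubleMeanHolder) update);
      } else if (update instanceof List) {
        // Multi-valued input: fold each element into the running mean.
        for (Object o : (List) update) {
          value.update(Numbers.tryParseDouble(o, 0));
        }
      } else {
        // Plain numeric (or parseable) input.
        value.update(Numbers.tryParseDouble(update, 0));
      }
    }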
File: DoubleMeanBufferAggregator.java

@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.mean;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
@@ -51,6 +52,9 @@ public class DoubleMeanBufferAggregator implements BufferAggregator
  {
    Object update = selector.getObject();
    // Same rule as the on-heap aggregator: in SQL-compatible mode, skip null inputs.
    if (update == null && !NullHandling.replaceWithDefault()) {
      return;
    }
    if (update instanceof DoubleMeanHolder) {
      DoubleMeanHolder.update(buf, position, (DoubleMeanHolder) update);
    } else if (update instanceof List) {

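The buffer variant keeps its running state off-heap instead of in a field. A rough sketch of what the static DoubleMeanHolder.update presumably does with that state; the sum-then-count layout is an assumption based on the holder's sum/count pair, not something shown in this diff:

    // Hypothetical sketch: bump the running sum and row count stored at `position`.
    static void update(ByteBuffer buf, int position, double value)
    {
      buf.putDouble(position, buf.getDouble(position) + value);  // running sum
      int countOffset = position + Double.BYTES;                 // assumed offset of the count
      buf.putLong(countOffset, buf.getLong(countOffset) + 1);    // row count
    }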
File: DoubleMeanVectorAggregator.java

@@ -20,6 +20,7 @@
package org.apache.druid.query.aggregation.mean;
import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;
@@ -45,11 +46,28 @@ public class DoubleMeanVectorAggregator implements VectorAggregator
  public void aggregate(final ByteBuffer buf, final int position, final int startRow, final int endRow)
  {
    final double[] vector = selector.getDoubleVector();
    final boolean[] nulls = selector.getNullVector();
    if (nulls != null) {
      if (NullHandling.replaceWithDefault()) {
        // Default-value mode: null slots already read as 0.0 in the double vector,
        // so every row is folded in as-is.
        for (int i = startRow; i < endRow; i++) {
          DoubleMeanHolder.update(buf, position, vector[i]);
        }
      } else {
        // SQL-compatible mode: rows flagged as null are excluded from sum and count.
        for (int i = startRow; i < endRow; i++) {
          if (!nulls[i]) {
            DoubleMeanHolder.update(buf, position, vector[i]);
          }
        }
      }
    } else {
      // No nulls in this batch.
      for (int i = startRow; i < endRow; i++) {
        DoubleMeanHolder.update(buf, position, vector[i]);
      }
    }
  }
  @Override
  public void aggregate(
      final ByteBuffer buf,
@@ -60,10 +78,27 @@ public class DoubleMeanVectorAggregator implements VectorAggregator
  )
  {
    final double[] vector = selector.getDoubleVector();
    final boolean[] nulls = selector.getNullVector();
    if (nulls != null) {
      if (NullHandling.replaceWithDefault()) {
        // Default-value mode: aggregate every selected row; null slots read as 0.0.
        for (int i = 0; i < numRows; i++) {
          final double val = vector[rows != null ? rows[i] : i];
          DoubleMeanHolder.update(buf, positions[i] + positionOffset, val);
        }
      } else {
        // SQL-compatible mode: skip rows whose null flag is set. The mask is aligned
        // with the value vector, so it must be indexed with the same (possibly remapped) row.
        for (int i = 0; i < numRows; i++) {
          final int row = rows != null ? rows[i] : i;
          if (!nulls[row]) {
            DoubleMeanHolder.update(buf, positions[i] + positionOffset, vector[row]);
          }
        }
      }
    } else {
      // No nulls in this batch.
      for (int i = 0; i < numRows; i++) {
        final double val = vector[rows != null ? rows[i] : i];
        DoubleMeanHolder.update(buf, positions[i] + positionOffset, val);
      }
    }
  }

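Both vectorized paths lean on the same selector contract: getNullVector() returns null when the batch has no nulls, and when a mask is present, the primitive array from getDoubleVector() holds the default value (0.0) in the null slots. That is why the replaceWithDefault() branches can ignore the mask. A condensed sketch of the shared pattern (equivalent to the first method above, not a drop-in replacement):

    for (int i = startRow; i < endRow; i++) {
      if (nulls != null && !NullHandling.replaceWithDefault() && nulls[i]) {
        continue;  // SQL-compatible mode: drop the row from both sum and count
      }
      DoubleMeanHolder.update(buf, position, vector[i]);  // reads 0.0 at null slots in default mode
    }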
File: DoubleMeanAggregationTest.java

@@ -26,6 +26,7 @@ import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.Row;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
@@ -46,6 +47,7 @@ import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.IncrementalIndexSegment;
import org.apache.druid.segment.QueryableIndexSegment;
import org.apache.druid.segment.Segment;
import org.apache.druid.segment.TestIndex;
import org.apache.druid.timeline.SegmentId;
import org.easymock.EasyMock;
import org.junit.Assert;
@@ -72,6 +74,7 @@ public class DoubleMeanAggregationTest
  private final AggregationTestHelper timeseriesQueryTestHelper;
  private final List<Segment> segments;
  private final List<Segment> biggerSegments;

  public DoubleMeanAggregationTest()
  {
@@ -91,6 +94,11 @@ public class DoubleMeanAggregationTest
        new IncrementalIndexSegment(SimpleTestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
        new QueryableIndexSegment(SimpleTestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
    );
    biggerSegments = ImmutableList.of(
        new IncrementalIndexSegment(TestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
        new QueryableIndexSegment(TestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
    );
  }

  @Test
@@ -145,6 +153,33 @@ public class DoubleMeanAggregationTest
    Assert.assertEquals(6.2d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
  }

  @Test
  @Parameters(method = "doVectorize")
  public void testVectorAggregatorUsingGroupByQueryOnDoubleColumnOnBiggerSegments(boolean doVectorize) throws Exception
  {
    GroupByQuery query = new GroupByQuery.Builder()
        .setDataSource("blah")
        .setGranularity(Granularities.ALL)
        .setInterval("1970/2050")
        .setAggregatorSpecs(
            new DoubleMeanAggregatorFactory("meanOnDouble", TestIndex.COLUMNS[9])
        )
        .setContext(Collections.singletonMap(QueryContexts.VECTORIZE_KEY, doVectorize))
        .build();

    // Serialize and deserialize the query to ensure there are no serde issues.
    ObjectMapper jsonMapper = groupByQueryTestHelper.getObjectMapper();
    query = (GroupByQuery) jsonMapper.readValue(jsonMapper.writeValueAsString(query), Query.class);

    Sequence<ResultRow> seq = groupByQueryTestHelper.runQueryOnSegmentsObjs(biggerSegments, query);
    Row result = Iterables.getOnlyElement(seq.toList()).toMapBasedRow(query);

    // Expected means differ by mode: nulls count as 0.0 rows in default-value mode
    // but are dropped entirely in SQL-compatible mode.
    if (NullHandling.replaceWithDefault()) {
      Assert.assertEquals(39.2307d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
    } else {
      Assert.assertEquals(51.0d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
    }
  }

  @Test
  @Parameters(method = "doVectorize")
  public void testAggregatorUsingTimeseriesQuery(boolean doVectorize) throws Exception

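The @Parameters(method = "doVectorize") annotation runs each of these tests once with vectorization on and once with it off. The provider itself is not shown in the hunks; a plausible JUnitParams shape for it would be:

    // Hypothetical provider, assumed from the @Parameters usage above.
    private Object[] doVectorize()
    {
      return new Object[]{true, false};
    }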
File: GroupByQueryRunnerTest.java

@@ -33,6 +33,7 @@ import com.google.common.collect.Sets;
import org.apache.druid.collections.CloseableDefaultBlockingPool;
import org.apache.druid.collections.CloseableStupidPool;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.Row;
import org.apache.druid.data.input.Rows;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.IAE;
@@ -82,6 +83,7 @@ import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
import org.apache.druid.query.aggregation.mean.DoubleMeanAggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
@@ -5881,6 +5883,33 @@ public class GroupByQueryRunnerTest extends InitializedNullHandlingTest
    TestHelper.assertExpectedObjects(expectedResults, results, "subquery-different-intervals");
  }

  @Test
  public void testDoubleMeanQuery()
  {
    GroupByQuery query = new GroupByQuery.Builder()
        .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
        .setGranularity(Granularities.ALL)
        .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
        .setAggregatorSpecs(
            new DoubleMeanAggregatorFactory("meanOnDouble", "doubleNumericNull")
        )
        .build();

    if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) {
      // Group-by v1 cannot process the aggregator's complex intermediate type.
      expectedException.expect(ISE.class);
      expectedException.expectMessage("Unable to handle complex type");
      GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
    } else {
      Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
      Row result = Iterables.getOnlyElement(results).toMapBasedRow(query);
      if (NullHandling.replaceWithDefault()) {
        Assert.assertEquals(39.2307d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
      } else {
        Assert.assertEquals(51.0d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
      }
    }
  }

  @Test
  public void testGroupByTimeExtractionNamedUnderUnderTime()
  {