Graceful null handling and correctness in DoubleMean Aggregator (#12320)

* Adding null handling for double mean aggregator

* Updating code to handle nulls in DoubleMean aggregator

* Fixing checkstyle issues from the previous commit

* Updating some code and test cases

* Checking whether the object is null in the numeric aggregator case

* Adding one more test to improve coverage

* Changing one test as requested in the review

* Changing one more test for null handling, as requested in the review
somu-imply 2022-03-14 16:52:47 -07:00 committed by GitHub
parent 3de1272926
commit b5195c5095
5 changed files with 113 additions and 5 deletions
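For context: NullHandling.replaceWithDefault() reflects Druid's druid.generic.useDefaultValueForNull setting. When true (default-value mode), a null numeric input is coerced to the default 0; when false (SQL-compatible mode), a null is genuinely missing and should be skipped. For a mean this matters twice over, because a coerced null adds nothing to the sum but still increments the row count. A minimal standalone sketch of the difference (illustrative Java, not Druid code; the method name is hypothetical):

import java.util.Arrays;
import java.util.List;

public class MeanModesSketch
{
  // replaceWithDefault plays the role of NullHandling.replaceWithDefault().
  static double mean(List<Double> values, boolean replaceWithDefault)
  {
    double sum = 0.0;
    long count = 0;
    for (Double v : values) {
      if (v == null) {
        if (replaceWithDefault) {
          count++; // null treated as 0.0: nothing added to the sum, but the row still counts
        }
        continue;  // in SQL-compatible mode the null row is skipped entirely
      }
      sum += v;
      count++;
    }
    return count == 0 ? 0.0 : sum / count;
  }

  public static void main(String[] args)
  {
    List<Double> values = Arrays.asList(1.0, null, 2.0);
    System.out.println(mean(values, true));  // 1.0 = (1 + 0 + 2) / 3
    System.out.println(mean(values, false)); // 1.5 = (1 + 2) / 2
  }
}

The same split shows up in the expected values of the tests further down.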

DoubleMeanAggregator.java

@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.mean;

import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.segment.ColumnValueSelector;
@@ -43,6 +44,10 @@ public class DoubleMeanAggregator implements Aggregator
  {
    Object update = selector.getObject();
    if (update == null && NullHandling.replaceWithDefault() == false) {
      return;
    }
    if (update instanceof DoubleMeanHolder) {
      value.update((DoubleMeanHolder) update);
    } else if (update instanceof List) {
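With this guard in place, a null short-circuits only in SQL-compatible mode. In default-value mode it falls through to the existing branches; judging from the Numbers import above, the method's final else branch (outside this hunk) presumably coerces the value via Numbers.tryParseDouble(update, 0), so a null contributes the default 0 to the mean.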

DoubleMeanBufferAggregator.java

@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.mean;

import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
@@ -51,6 +52,9 @@ public class DoubleMeanBufferAggregator implements BufferAggregator
  {
    Object update = selector.getObject();
    if (update == null && NullHandling.replaceWithDefault() == false) {
      return;
    }
    if (update instanceof DoubleMeanHolder) {
      DoubleMeanHolder.update(buf, position, (DoubleMeanHolder) update);
    } else if (update instanceof List) {
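Both the scalar and buffer variants funnel per-row updates through DoubleMeanHolder, whose buffer layout is not part of this diff. As a rough mental model, here is a minimal standalone sketch of a buffer-backed mean state, assuming a long count followed by a double sum at the aggregator's offset (the class name, offsets, and methods are illustrative, not Druid's):

import java.nio.ByteBuffer;

public class MeanBufferSketch
{
  static final int COUNT_OFFSET = 0;         // long: number of rows aggregated
  static final int SUM_OFFSET = Long.BYTES;  // double: running sum

  static void init(ByteBuffer buf, int position)
  {
    buf.putLong(position + COUNT_OFFSET, 0L);
    buf.putDouble(position + SUM_OFFSET, 0.0);
  }

  static void update(ByteBuffer buf, int position, double value)
  {
    buf.putLong(position + COUNT_OFFSET, buf.getLong(position + COUNT_OFFSET) + 1);
    buf.putDouble(position + SUM_OFFSET, buf.getDouble(position + SUM_OFFSET) + value);
  }

  static double mean(ByteBuffer buf, int position)
  {
    long count = buf.getLong(position + COUNT_OFFSET);
    return count == 0 ? 0.0 : buf.getDouble(position + SUM_OFFSET) / count;
  }

  public static void main(String[] args)
  {
    ByteBuffer buf = ByteBuffer.allocate(Long.BYTES + Double.BYTES);
    init(buf, 0);
    update(buf, 0, 1.0);
    update(buf, 0, 2.0);
    System.out.println(mean(buf, 0)); // 1.5
  }
}

Under a layout like this, skipping a null row (the SQL-compatible path above) leaves both count and sum untouched, which is exactly what an ignore-nulls mean requires.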

DoubleMeanVectorAggregator.java

@@ -20,6 +20,7 @@
package org.apache.druid.query.aggregation.mean;

import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.segment.vector.VectorValueSelector;
@@ -45,10 +46,27 @@ public class DoubleMeanVectorAggregator implements VectorAggregator
  public void aggregate(final ByteBuffer buf, final int position, final int startRow, final int endRow)
  {
    final double[] vector = selector.getDoubleVector();
    final boolean[] nulls = selector.getNullVector();

    if (nulls != null) {
      if (NullHandling.replaceWithDefault()) {
        for (int i = startRow; i < endRow; i++) {
          DoubleMeanHolder.update(buf, position, vector[i]);
        }
      } else {
        for (int i = startRow; i < endRow; i++) {
          if (!nulls[i]) {
            DoubleMeanHolder.update(buf, position, vector[i]);
          }
        }
      }
    } else {
      for (int i = startRow; i < endRow; i++) {
        DoubleMeanHolder.update(buf, position, vector[i]);
      }
    }
  }

  @Override
  public void aggregate(

@@ -60,11 +78,28 @@ public class DoubleMeanVectorAggregator implements VectorAggregator
  )
  {
    final double[] vector = selector.getDoubleVector();
    final boolean[] nulls = selector.getNullVector();

    if (nulls != null) {
      if (NullHandling.replaceWithDefault()) {
        for (int i = 0; i < numRows; i++) {
          final double val = vector[rows != null ? rows[i] : i];
          DoubleMeanHolder.update(buf, positions[i] + positionOffset, val);
        }
      } else {
        for (int j = 0; j < numRows; j++) {
          if (!nulls[j]) {
            final double val = vector[rows != null ? rows[j] : j];
            DoubleMeanHolder.update(buf, positions[j] + positionOffset, val);
          }
        }
      }
    } else {
      for (int i = 0; i < numRows; i++) {
        final double val = vector[rows != null ? rows[i] : i];
        DoubleMeanHolder.update(buf, positions[i] + positionOffset, val);
      }
    }
  }

  @Override
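A note on the shape of these loops: the null-vector presence check and the null-handling mode check are both hoisted out of the row loops, so each hot loop stays branch-light. Also note that the replaceWithDefault branch reads vector[i] unconditionally even though a null vector is present; this presumably relies on the vector selector filling null positions with the default value (0.0), which is exactly the value default-value mode wants counted.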

DoubleMeanAggregationTest.java

@@ -26,6 +26,7 @@ import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.Row;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
@@ -46,6 +47,7 @@ import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.IncrementalIndexSegment;
import org.apache.druid.segment.QueryableIndexSegment;
import org.apache.druid.segment.Segment;
import org.apache.druid.segment.TestIndex;
import org.apache.druid.timeline.SegmentId;
import org.easymock.EasyMock;
import org.junit.Assert;
@@ -72,6 +74,7 @@ public class DoubleMeanAggregationTest
  private final AggregationTestHelper timeseriesQueryTestHelper;
  private final List<Segment> segments;
  private final List<Segment> biggerSegments;

  public DoubleMeanAggregationTest()
  {
@@ -91,6 +94,11 @@ public class DoubleMeanAggregationTest
        new IncrementalIndexSegment(SimpleTestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
        new QueryableIndexSegment(SimpleTestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
    );
    biggerSegments = ImmutableList.of(
        new IncrementalIndexSegment(TestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
        new QueryableIndexSegment(TestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
    );
  }

  @Test
@@ -145,6 +153,33 @@ public class DoubleMeanAggregationTest
    Assert.assertEquals(6.2d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
  }

  @Test
  @Parameters(method = "doVectorize")
  public void testVectorAggretatorUsingGroupByQueryOnDoubleColumnOnBiggerSegments(boolean doVectorize) throws Exception
  {
    GroupByQuery query = new GroupByQuery.Builder()
        .setDataSource("blah")
        .setGranularity(Granularities.ALL)
        .setInterval("1970/2050")
        .setAggregatorSpecs(
            new DoubleMeanAggregatorFactory("meanOnDouble", TestIndex.COLUMNS[9])
        )
        .setContext(Collections.singletonMap(QueryContexts.VECTORIZE_KEY, doVectorize))
        .build();

    // do json serialization and deserialization of query to ensure there are no serde issues
    ObjectMapper jsonMapper = groupByQueryTestHelper.getObjectMapper();
    query = (GroupByQuery) jsonMapper.readValue(jsonMapper.writeValueAsString(query), Query.class);

    Sequence<ResultRow> seq = groupByQueryTestHelper.runQueryOnSegmentsObjs(biggerSegments, query);
    Row result = Iterables.getOnlyElement(seq.toList()).toMapBasedRow(query);

    if (NullHandling.replaceWithDefault()) {
      Assert.assertEquals(39.2307d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
    } else {
      Assert.assertEquals(51.0d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
    }
  }
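The two expected values are mutually consistent: they match a column whose non-null values sum to 510 across 13 total rows, 10 of them non-null (an inference from the asserted numbers, not something the test states). Counting nulls as 0 gives 510 / 13 ≈ 39.2308, while skipping them gives 510 / 10 = 51.0.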
  @Test
  @Parameters(method = "doVectorize")
  public void testAggretatorUsingTimeseriesQuery(boolean doVectorize) throws Exception

GroupByQueryRunnerTest.java

@@ -33,6 +33,7 @@ import com.google.common.collect.Sets;
import org.apache.druid.collections.CloseableDefaultBlockingPool;
import org.apache.druid.collections.CloseableStupidPool;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.Row;
import org.apache.druid.data.input.Rows;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.IAE;
@@ -82,6 +83,7 @@ import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
import org.apache.druid.query.aggregation.mean.DoubleMeanAggregatorFactory;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
@@ -5881,6 +5883,33 @@ public class GroupByQueryRunnerTest extends InitializedNullHandlingTest
    TestHelper.assertExpectedObjects(expectedResults, results, "subquery-different-intervals");
  }

  @Test
  public void testDoubleMeanQuery()
  {
    GroupByQuery query = new GroupByQuery.Builder()
        .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
        .setGranularity(Granularities.ALL)
        .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
        .setAggregatorSpecs(
            new DoubleMeanAggregatorFactory("meanOnDouble", "doubleNumericNull")
        )
        .build();

    if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) {
      expectedException.expect(ISE.class);
      expectedException.expectMessage("Unable to handle complex type");
      GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
    } else {
      Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
      Row result = Iterables.getOnlyElement(results).toMapBasedRow(query);
      if (NullHandling.replaceWithDefault()) {
        Assert.assertEquals(39.2307d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
      } else {
        Assert.assertEquals(51.0d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
      }
    }
  }
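The v1 branch is expected to throw because the mean aggregator's intermediate state is a DoubleMeanHolder, a complex (non-numeric) type; groupBy v1 materializes intermediate rows and presumably cannot represent that state, hence the "Unable to handle complex type" ISE, while v2 carries the holder through and finalizes it to a double.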
  @Test
  public void testGroupByTimeExtractionNamedUnderUnderTime()
  {