mirror of https://github.com/apache/druid.git
Graceful null handling and correctness in DoubleMean Aggregator (#12320)
* Adding null handling for double mean aggregator
* Updating code to handle nulls in DoubleMean aggregator
* oops last one should have checkstyle issues. fixed
* Updating some code and test cases
* Checking on object is null in case of numeric aggregator
* Adding one more test to improve coverage
* Changing one test as asked in the review
* Changing one test as asked in the review for nulls
This commit is contained in:
parent
3de1272926
commit
b5195c5095
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
package org.apache.druid.query.aggregation.mean;
|
package org.apache.druid.query.aggregation.mean;
|
||||||
|
|
||||||
|
import org.apache.druid.common.config.NullHandling;
|
||||||
import org.apache.druid.java.util.common.Numbers;
|
import org.apache.druid.java.util.common.Numbers;
|
||||||
import org.apache.druid.query.aggregation.Aggregator;
|
import org.apache.druid.query.aggregation.Aggregator;
|
||||||
import org.apache.druid.segment.ColumnValueSelector;
|
import org.apache.druid.segment.ColumnValueSelector;
|
||||||
|
@ -43,6 +44,10 @@ public class DoubleMeanAggregator implements Aggregator
|
||||||
{
|
{
|
||||||
Object update = selector.getObject();
|
Object update = selector.getObject();
|
||||||
|
|
||||||
|
if (update == null && NullHandling.replaceWithDefault() == false) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (update instanceof DoubleMeanHolder) {
|
if (update instanceof DoubleMeanHolder) {
|
||||||
value.update((DoubleMeanHolder) update);
|
value.update((DoubleMeanHolder) update);
|
||||||
} else if (update instanceof List) {
|
} else if (update instanceof List) {
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
package org.apache.druid.query.aggregation.mean;
|
package org.apache.druid.query.aggregation.mean;
|
||||||
|
|
||||||
|
import org.apache.druid.common.config.NullHandling;
|
||||||
import org.apache.druid.java.util.common.Numbers;
|
import org.apache.druid.java.util.common.Numbers;
|
||||||
import org.apache.druid.query.aggregation.BufferAggregator;
|
import org.apache.druid.query.aggregation.BufferAggregator;
|
||||||
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
|
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
|
||||||
|
@ -51,6 +52,9 @@ public class DoubleMeanBufferAggregator implements BufferAggregator
|
||||||
{
|
{
|
||||||
Object update = selector.getObject();
|
Object update = selector.getObject();
|
||||||
|
|
||||||
|
if (update == null && NullHandling.replaceWithDefault() == false) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (update instanceof DoubleMeanHolder) {
|
if (update instanceof DoubleMeanHolder) {
|
||||||
DoubleMeanHolder.update(buf, position, (DoubleMeanHolder) update);
|
DoubleMeanHolder.update(buf, position, (DoubleMeanHolder) update);
|
||||||
} else if (update instanceof List) {
|
} else if (update instanceof List) {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
package org.apache.druid.query.aggregation.mean;
|
package org.apache.druid.query.aggregation.mean;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
|
import org.apache.druid.common.config.NullHandling;
|
||||||
import org.apache.druid.query.aggregation.VectorAggregator;
|
import org.apache.druid.query.aggregation.VectorAggregator;
|
||||||
import org.apache.druid.segment.vector.VectorValueSelector;
|
import org.apache.druid.segment.vector.VectorValueSelector;
|
||||||
|
|
||||||
|
@ -45,10 +46,27 @@ public class DoubleMeanVectorAggregator implements VectorAggregator
|
||||||
public void aggregate(final ByteBuffer buf, final int position, final int startRow, final int endRow)
|
public void aggregate(final ByteBuffer buf, final int position, final int startRow, final int endRow)
|
||||||
{
|
{
|
||||||
final double[] vector = selector.getDoubleVector();
|
final double[] vector = selector.getDoubleVector();
|
||||||
|
final boolean[] nulls = selector.getNullVector();
|
||||||
|
|
||||||
|
if (nulls != null) {
|
||||||
|
if (NullHandling.replaceWithDefault()) {
|
||||||
|
for (int i = startRow; i < endRow; i++) {
|
||||||
|
DoubleMeanHolder.update(buf, position, vector[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = startRow; i < endRow; i++) {
|
||||||
|
if (!nulls[i]) {
|
||||||
|
DoubleMeanHolder.update(buf, position, vector[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
for (int i = startRow; i < endRow; i++) {
|
for (int i = startRow; i < endRow; i++) {
|
||||||
DoubleMeanHolder.update(buf, position, vector[i]);
|
DoubleMeanHolder.update(buf, position, vector[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void aggregate(
|
public void aggregate(
|
||||||
|
@ -60,11 +78,28 @@ public class DoubleMeanVectorAggregator implements VectorAggregator
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
final double[] vector = selector.getDoubleVector();
|
final double[] vector = selector.getDoubleVector();
|
||||||
|
final boolean[] nulls = selector.getNullVector();
|
||||||
|
|
||||||
|
if (nulls != null) {
|
||||||
|
if (NullHandling.replaceWithDefault()) {
|
||||||
for (int i = 0; i < numRows; i++) {
|
for (int i = 0; i < numRows; i++) {
|
||||||
final double val = vector[rows != null ? rows[i] : i];
|
final double val = vector[rows != null ? rows[i] : i];
|
||||||
DoubleMeanHolder.update(buf, positions[i] + positionOffset, val);
|
DoubleMeanHolder.update(buf, positions[i] + positionOffset, val);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
for (int j = 0; j < numRows; j++) {
|
||||||
|
if (!nulls[j]) {
|
||||||
|
final double val = vector[rows != null ? rows[j] : j];
|
||||||
|
DoubleMeanHolder.update(buf, positions[j] + positionOffset, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < numRows; i++) {
|
||||||
|
final double val = vector[rows != null ? rows[i] : i];
|
||||||
|
DoubleMeanHolder.update(buf, positions[i] + positionOffset, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -26,6 +26,7 @@ import com.google.common.collect.Iterables;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import junitparams.JUnitParamsRunner;
|
import junitparams.JUnitParamsRunner;
|
||||||
import junitparams.Parameters;
|
import junitparams.Parameters;
|
||||||
|
import org.apache.druid.common.config.NullHandling;
|
||||||
import org.apache.druid.data.input.Row;
|
import org.apache.druid.data.input.Row;
|
||||||
import org.apache.druid.java.util.common.granularity.Granularities;
|
import org.apache.druid.java.util.common.granularity.Granularities;
|
||||||
import org.apache.druid.java.util.common.guava.Sequence;
|
import org.apache.druid.java.util.common.guava.Sequence;
|
||||||
|
@ -46,6 +47,7 @@ import org.apache.druid.segment.ColumnSelectorFactory;
|
||||||
import org.apache.druid.segment.IncrementalIndexSegment;
|
import org.apache.druid.segment.IncrementalIndexSegment;
|
||||||
import org.apache.druid.segment.QueryableIndexSegment;
|
import org.apache.druid.segment.QueryableIndexSegment;
|
||||||
import org.apache.druid.segment.Segment;
|
import org.apache.druid.segment.Segment;
|
||||||
|
import org.apache.druid.segment.TestIndex;
|
||||||
import org.apache.druid.timeline.SegmentId;
|
import org.apache.druid.timeline.SegmentId;
|
||||||
import org.easymock.EasyMock;
|
import org.easymock.EasyMock;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
@ -72,6 +74,7 @@ public class DoubleMeanAggregationTest
|
||||||
private final AggregationTestHelper timeseriesQueryTestHelper;
|
private final AggregationTestHelper timeseriesQueryTestHelper;
|
||||||
|
|
||||||
private final List<Segment> segments;
|
private final List<Segment> segments;
|
||||||
|
private final List<Segment> biggerSegments;
|
||||||
|
|
||||||
public DoubleMeanAggregationTest()
|
public DoubleMeanAggregationTest()
|
||||||
{
|
{
|
||||||
|
@ -91,6 +94,11 @@ public class DoubleMeanAggregationTest
|
||||||
new IncrementalIndexSegment(SimpleTestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
|
new IncrementalIndexSegment(SimpleTestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
|
||||||
new QueryableIndexSegment(SimpleTestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
|
new QueryableIndexSegment(SimpleTestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
biggerSegments = ImmutableList.of(
|
||||||
|
new IncrementalIndexSegment(TestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
|
||||||
|
new QueryableIndexSegment(TestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -145,6 +153,33 @@ public class DoubleMeanAggregationTest
|
||||||
Assert.assertEquals(6.2d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
Assert.assertEquals(6.2d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Parameters(method = "doVectorize")
|
||||||
|
public void testVectorAggretatorUsingGroupByQueryOnDoubleColumnOnBiggerSegments(boolean doVectorize) throws Exception
|
||||||
|
{
|
||||||
|
GroupByQuery query = new GroupByQuery.Builder()
|
||||||
|
.setDataSource("blah")
|
||||||
|
.setGranularity(Granularities.ALL)
|
||||||
|
.setInterval("1970/2050")
|
||||||
|
.setAggregatorSpecs(
|
||||||
|
new DoubleMeanAggregatorFactory("meanOnDouble", TestIndex.COLUMNS[9])
|
||||||
|
)
|
||||||
|
.setContext(Collections.singletonMap(QueryContexts.VECTORIZE_KEY, doVectorize))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// do json serialization and deserialization of query to ensure there are no serde issues
|
||||||
|
ObjectMapper jsonMapper = groupByQueryTestHelper.getObjectMapper();
|
||||||
|
query = (GroupByQuery) jsonMapper.readValue(jsonMapper.writeValueAsString(query), Query.class);
|
||||||
|
|
||||||
|
Sequence<ResultRow> seq = groupByQueryTestHelper.runQueryOnSegmentsObjs(biggerSegments, query);
|
||||||
|
Row result = Iterables.getOnlyElement(seq.toList()).toMapBasedRow(query);
|
||||||
|
if (NullHandling.replaceWithDefault()) {
|
||||||
|
Assert.assertEquals(39.2307d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
||||||
|
} else {
|
||||||
|
Assert.assertEquals(51.0d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Parameters(method = "doVectorize")
|
@Parameters(method = "doVectorize")
|
||||||
public void testAggretatorUsingTimeseriesQuery(boolean doVectorize) throws Exception
|
public void testAggretatorUsingTimeseriesQuery(boolean doVectorize) throws Exception
|
||||||
|
|
|
@ -33,6 +33,7 @@ import com.google.common.collect.Sets;
|
||||||
import org.apache.druid.collections.CloseableDefaultBlockingPool;
|
import org.apache.druid.collections.CloseableDefaultBlockingPool;
|
||||||
import org.apache.druid.collections.CloseableStupidPool;
|
import org.apache.druid.collections.CloseableStupidPool;
|
||||||
import org.apache.druid.common.config.NullHandling;
|
import org.apache.druid.common.config.NullHandling;
|
||||||
|
import org.apache.druid.data.input.Row;
|
||||||
import org.apache.druid.data.input.Rows;
|
import org.apache.druid.data.input.Rows;
|
||||||
import org.apache.druid.java.util.common.DateTimes;
|
import org.apache.druid.java.util.common.DateTimes;
|
||||||
import org.apache.druid.java.util.common.IAE;
|
import org.apache.druid.java.util.common.IAE;
|
||||||
|
@ -82,6 +83,7 @@ import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
|
||||||
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
|
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
|
||||||
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
|
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
|
||||||
import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
|
import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
|
||||||
|
import org.apache.druid.query.aggregation.mean.DoubleMeanAggregatorFactory;
|
||||||
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
|
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
|
||||||
import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
|
import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
|
||||||
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
|
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
|
||||||
|
@ -5881,6 +5883,33 @@ public class GroupByQueryRunnerTest extends InitializedNullHandlingTest
|
||||||
TestHelper.assertExpectedObjects(expectedResults, results, "subquery-different-intervals");
|
TestHelper.assertExpectedObjects(expectedResults, results, "subquery-different-intervals");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDoubleMeanQuery()
|
||||||
|
{
|
||||||
|
GroupByQuery query = new GroupByQuery.Builder()
|
||||||
|
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
|
||||||
|
.setGranularity(Granularities.ALL)
|
||||||
|
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
|
||||||
|
.setAggregatorSpecs(
|
||||||
|
new DoubleMeanAggregatorFactory("meanOnDouble", "doubleNumericNull")
|
||||||
|
)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) {
|
||||||
|
expectedException.expect(ISE.class);
|
||||||
|
expectedException.expectMessage("Unable to handle complex type");
|
||||||
|
GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
|
||||||
|
} else {
|
||||||
|
Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
|
||||||
|
Row result = Iterables.getOnlyElement(results).toMapBasedRow(query);
|
||||||
|
if (NullHandling.replaceWithDefault()) {
|
||||||
|
Assert.assertEquals(39.2307d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
||||||
|
} else {
|
||||||
|
Assert.assertEquals(51.0d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGroupByTimeExtractionNamedUnderUnderTime()
|
public void testGroupByTimeExtractionNamedUnderUnderTime()
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue