mirror of https://github.com/apache/druid.git
Graceful null handling and correctness in DoubleMean Aggregator (#12320)
* Adding null handling for double mean aggregator
* Updating code to handle nulls in DoubleMean aggregator
* oops last one should have checkstyle issues. fixed
* Updating some code and test cases
* Checking on object is null in case of numeric aggregator
* Adding one more test to improve coverage
* Changing one test as asked in the review
* Changing one test as asked in the review for nulls
This commit is contained in:
parent
3de1272926
commit
b5195c5095
|
@ -19,6 +19,7 @@
|
|||
|
||||
package org.apache.druid.query.aggregation.mean;
|
||||
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.java.util.common.Numbers;
|
||||
import org.apache.druid.query.aggregation.Aggregator;
|
||||
import org.apache.druid.segment.ColumnValueSelector;
|
||||
|
@ -43,6 +44,10 @@ public class DoubleMeanAggregator implements Aggregator
|
|||
{
|
||||
Object update = selector.getObject();
|
||||
|
||||
if (update == null && NullHandling.replaceWithDefault() == false) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (update instanceof DoubleMeanHolder) {
|
||||
value.update((DoubleMeanHolder) update);
|
||||
} else if (update instanceof List) {
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
package org.apache.druid.query.aggregation.mean;
|
||||
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.java.util.common.Numbers;
|
||||
import org.apache.druid.query.aggregation.BufferAggregator;
|
||||
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
|
||||
|
@ -51,6 +52,9 @@ public class DoubleMeanBufferAggregator implements BufferAggregator
|
|||
{
|
||||
Object update = selector.getObject();
|
||||
|
||||
if (update == null && NullHandling.replaceWithDefault() == false) {
|
||||
return;
|
||||
}
|
||||
if (update instanceof DoubleMeanHolder) {
|
||||
DoubleMeanHolder.update(buf, position, (DoubleMeanHolder) update);
|
||||
} else if (update instanceof List) {
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
package org.apache.druid.query.aggregation.mean;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.query.aggregation.VectorAggregator;
|
||||
import org.apache.druid.segment.vector.VectorValueSelector;
|
||||
|
||||
|
@ -45,10 +46,27 @@ public class DoubleMeanVectorAggregator implements VectorAggregator
|
|||
public void aggregate(final ByteBuffer buf, final int position, final int startRow, final int endRow)
|
||||
{
|
||||
final double[] vector = selector.getDoubleVector();
|
||||
final boolean[] nulls = selector.getNullVector();
|
||||
|
||||
if (nulls != null) {
|
||||
if (NullHandling.replaceWithDefault()) {
|
||||
for (int i = startRow; i < endRow; i++) {
|
||||
DoubleMeanHolder.update(buf, position, vector[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = startRow; i < endRow; i++) {
|
||||
if (!nulls[i]) {
|
||||
DoubleMeanHolder.update(buf, position, vector[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = startRow; i < endRow; i++) {
|
||||
DoubleMeanHolder.update(buf, position, vector[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void aggregate(
|
||||
|
@ -60,11 +78,28 @@ public class DoubleMeanVectorAggregator implements VectorAggregator
|
|||
)
|
||||
{
  // Folds 'numRows' values from the selector's current double vector into per-slot mean states.
  // Slot i is stored in 'buf' at positions[i] + positionOffset.
  final double[] vector = selector.getDoubleVector();
  final boolean[] nulls = selector.getNullVector();

  // A row is skipped only when the selector supplies a null vector AND nulls are not being
  // replaced with defaults. In every other case each selected row is aggregated as-is.
  final boolean skipNulls = nulls != null && !NullHandling.replaceWithDefault();

  for (int i = 0; i < numRows; i++) {
    // When 'rows' is non-null it maps slot i to the row to read; otherwise slot i reads row i.
    // The null flag must be looked up with the same row index as the value. Indexing the null
    // vector by the slot number would test the wrong row whenever 'rows' is supplied.
    final int row = rows != null ? rows[i] : i;
    if (skipNulls && nulls[row]) {
      continue;
    }
    DoubleMeanHolder.update(buf, positions[i] + positionOffset, vector[row]);
  }
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -26,6 +26,7 @@ import com.google.common.collect.Iterables;
|
|||
import com.google.common.collect.Lists;
|
||||
import junitparams.JUnitParamsRunner;
|
||||
import junitparams.Parameters;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.data.input.Row;
|
||||
import org.apache.druid.java.util.common.granularity.Granularities;
|
||||
import org.apache.druid.java.util.common.guava.Sequence;
|
||||
|
@ -46,6 +47,7 @@ import org.apache.druid.segment.ColumnSelectorFactory;
|
|||
import org.apache.druid.segment.IncrementalIndexSegment;
|
||||
import org.apache.druid.segment.QueryableIndexSegment;
|
||||
import org.apache.druid.segment.Segment;
|
||||
import org.apache.druid.segment.TestIndex;
|
||||
import org.apache.druid.timeline.SegmentId;
|
||||
import org.easymock.EasyMock;
|
||||
import org.junit.Assert;
|
||||
|
@ -72,6 +74,7 @@ public class DoubleMeanAggregationTest
|
|||
private final AggregationTestHelper timeseriesQueryTestHelper;
|
||||
|
||||
private final List<Segment> segments;
|
||||
private final List<Segment> biggerSegments;
|
||||
|
||||
public DoubleMeanAggregationTest()
|
||||
{
|
||||
|
@ -91,6 +94,11 @@ public class DoubleMeanAggregationTest
|
|||
new IncrementalIndexSegment(SimpleTestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
|
||||
new QueryableIndexSegment(SimpleTestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
|
||||
);
|
||||
|
||||
biggerSegments = ImmutableList.of(
|
||||
new IncrementalIndexSegment(TestIndex.getIncrementalTestIndex(), SegmentId.dummy("test1")),
|
||||
new QueryableIndexSegment(TestIndex.getMMappedTestIndex(), SegmentId.dummy("test2"))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -145,6 +153,33 @@ public class DoubleMeanAggregationTest
|
|||
Assert.assertEquals(6.2d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Parameters(method = "doVectorize")
|
||||
public void testVectorAggretatorUsingGroupByQueryOnDoubleColumnOnBiggerSegments(boolean doVectorize) throws Exception
|
||||
{
|
||||
GroupByQuery query = new GroupByQuery.Builder()
|
||||
.setDataSource("blah")
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setInterval("1970/2050")
|
||||
.setAggregatorSpecs(
|
||||
new DoubleMeanAggregatorFactory("meanOnDouble", TestIndex.COLUMNS[9])
|
||||
)
|
||||
.setContext(Collections.singletonMap(QueryContexts.VECTORIZE_KEY, doVectorize))
|
||||
.build();
|
||||
|
||||
// do json serialization and deserialization of query to ensure there are no serde issues
|
||||
ObjectMapper jsonMapper = groupByQueryTestHelper.getObjectMapper();
|
||||
query = (GroupByQuery) jsonMapper.readValue(jsonMapper.writeValueAsString(query), Query.class);
|
||||
|
||||
Sequence<ResultRow> seq = groupByQueryTestHelper.runQueryOnSegmentsObjs(biggerSegments, query);
|
||||
Row result = Iterables.getOnlyElement(seq.toList()).toMapBasedRow(query);
|
||||
if (NullHandling.replaceWithDefault()) {
|
||||
Assert.assertEquals(39.2307d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
||||
} else {
|
||||
Assert.assertEquals(51.0d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Parameters(method = "doVectorize")
|
||||
public void testAggretatorUsingTimeseriesQuery(boolean doVectorize) throws Exception
|
||||
|
|
|
@ -33,6 +33,7 @@ import com.google.common.collect.Sets;
|
|||
import org.apache.druid.collections.CloseableDefaultBlockingPool;
|
||||
import org.apache.druid.collections.CloseableStupidPool;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.data.input.Row;
|
||||
import org.apache.druid.data.input.Rows;
|
||||
import org.apache.druid.java.util.common.DateTimes;
|
||||
import org.apache.druid.java.util.common.IAE;
|
||||
|
@ -82,6 +83,7 @@ import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
|
|||
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
|
||||
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.mean.DoubleMeanAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
|
||||
import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
|
||||
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
|
||||
|
@ -5881,6 +5883,33 @@ public class GroupByQueryRunnerTest extends InitializedNullHandlingTest
|
|||
TestHelper.assertExpectedObjects(expectedResults, results, "subquery-different-intervals");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleMeanQuery()
|
||||
{
|
||||
GroupByQuery query = new GroupByQuery.Builder()
|
||||
.setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
|
||||
.setAggregatorSpecs(
|
||||
new DoubleMeanAggregatorFactory("meanOnDouble", "doubleNumericNull")
|
||||
)
|
||||
.build();
|
||||
|
||||
if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) {
|
||||
expectedException.expect(ISE.class);
|
||||
expectedException.expectMessage("Unable to handle complex type");
|
||||
GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
|
||||
} else {
|
||||
Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
|
||||
Row result = Iterables.getOnlyElement(results).toMapBasedRow(query);
|
||||
if (NullHandling.replaceWithDefault()) {
|
||||
Assert.assertEquals(39.2307d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
||||
} else {
|
||||
Assert.assertEquals(51.0d, result.getMetric("meanOnDouble").doubleValue(), 0.0001d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGroupByTimeExtractionNamedUnderUnderTime()
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue