diff --git a/extensions-core/stats/pom.xml b/extensions-core/stats/pom.xml
index 13e6214a4d0..f2b605a7f9e 100644
--- a/extensions-core/stats/pom.xml
+++ b/extensions-core/stats/pom.xml
@@ -136,6 +136,16 @@
junit
test
+
+ pl.pragmatists
+ JUnitParams
+ test
+
+
+ org.mockito
+ mockito-core
+ test
+
org.easymock
easymock
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java
index 3c18e2c737d..c39ac8b66b7 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java
@@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;
+import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
@@ -31,6 +32,7 @@ import org.apache.druid.query.aggregation.post.PostAggregatorIds;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.column.ValueType;
+import javax.annotation.Nullable;
import java.util.Comparator;
import java.util.Map;
import java.util.Objects;
@@ -73,9 +75,11 @@ public class StandardDeviationPostAggregator implements PostAggregator
}
@Override
- public Object compute(Map combinedAggregators)
+ @Nullable
+ public Double compute(Map combinedAggregators)
{
- return Math.sqrt(((VarianceAggregatorCollector) combinedAggregators.get(fieldName)).getVariance(isVariancePop));
+ Double variance = ((VarianceAggregatorCollector) combinedAggregators.get(fieldName)).getVariance(isVariancePop);
+ return variance == null ? NullHandling.defaultDoubleValue() : (Double) Math.sqrt(variance);
}
@Override
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java
index 5dbcea763b4..ce0edb04f41 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollector.java
@@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Longs;
+import org.apache.druid.common.config.NullHandling;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@@ -60,11 +61,11 @@ public class VarianceAggregatorCollector
}
public static final Comparator COMPARATOR = (o1, o2) -> {
- int compare = Longs.compare(o1.count, o2.count);
+ int compare = Doubles.compare(o1.nvariance, o2.nvariance);
if (compare == 0) {
- compare = Doubles.compare(o1.sum, o2.sum);
+ compare = Longs.compare(o1.count, o2.count);
if (compare == 0) {
- compare = Doubles.compare(o1.nvariance, o2.nvariance);
+ compare = Doubles.compare(o1.sum, o2.sum);
}
}
return compare;
@@ -156,11 +157,11 @@ public class VarianceAggregatorCollector
return this;
}
- public double getVariance(boolean variancePop)
+ @Nullable
+ public Double getVariance(boolean variancePop)
{
if (count == 0) {
- // in SQL standard, we should return null for zero elements. But druid there should not be such a case
- throw new IllegalStateException("should not be empty holder");
+ return NullHandling.defaultDoubleValue();
} else if (count == 1) {
return 0d;
} else {
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
index 4a415061b1d..e9b59b48cd6 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactory.java
@@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
+import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregateCombiner;
@@ -49,6 +50,7 @@ import java.util.List;
import java.util.Objects;
/**
+ *
*/
@JsonTypeName("variance")
public class VarianceAggregatorFactory extends AggregatorFactory
@@ -239,7 +241,9 @@ public class VarianceAggregatorFactory extends AggregatorFactory
@Override
public Object finalizeComputation(@Nullable Object object)
{
- return object == null ? null : ((VarianceAggregatorCollector) object).getVariance(isVariancePop);
+ return object == null
+ ? NullHandling.defaultDoubleValue()
+ : ((VarianceAggregatorCollector) object).getVariance(isVariancePop);
}
@Override
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregatorTest.java
index c0768bb88ed..d9a83f4770e 100644
--- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregatorTest.java
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregatorTest.java
@@ -19,14 +19,44 @@
package org.apache.druid.query.aggregation.variance;
+import com.google.common.collect.ImmutableMap;
import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.common.config.NullHandling;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.query.aggregation.PostAggregator;
+import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
+import org.junit.Before;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
-public class StandardDeviationPostAggregatorTest
+import java.util.Map;
+
+@RunWith(MockitoJUnitRunner.class)
+public class StandardDeviationPostAggregatorTest extends InitializedNullHandlingTest
{
+ private static final String NAME = "NAME";
+ private static final String FIELD_NAME = "FIELD_NAME";
+ private static final String POPULATION = "population";
+ private static final double VARIANCE = 12.56;
+
+ private Map combinedAggregators;
+ @Mock
+ private VarianceAggregatorCollector collector;
+
+ private StandardDeviationPostAggregator target;
+
+ @Before
+ public void setUp()
+ {
+ Mockito.doReturn(VARIANCE).when(collector).getVariance(true);
+ combinedAggregators = ImmutableMap.of(FIELD_NAME, collector);
+ target = new StandardDeviationPostAggregator(NAME, FIELD_NAME, POPULATION);
+ }
+
@Test
public void testSerde() throws Exception
{
@@ -64,4 +94,17 @@ public class StandardDeviationPostAggregatorTest
.usingGetClass()
.verify();
}
+
+ @Test
+ public void testComputeForNullVarianceShouldReturnDefaultDoubleValue()
+ {
+ Mockito.when(collector.getVariance(true)).thenReturn(null);
+ Assert.assertEquals(NullHandling.defaultDoubleValue(), target.compute(combinedAggregators));
+ }
+
+ @Test
+ public void testComputeForVarianceShouldReturnSqrtOfVariance()
+ {
+ Assert.assertEquals(Math.sqrt(VARIANCE), target.compute(combinedAggregators), 1e-15);
+ }
}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollectorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollectorTest.java
index 57ab6ea86d4..d8261dd5b19 100644
--- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollectorTest.java
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorCollectorTest.java
@@ -144,6 +144,35 @@ public class VarianceAggregatorCollectorTest extends InitializedNullHandlingTest
}
}
+ @Test
+ public void testVarianceComparatorShouldCompareVarianceFirst()
+ {
+ VarianceAggregatorCollector v1 = new VarianceAggregatorCollector(4, 2d, 1.0);
+ VarianceAggregatorCollector v2 = new VarianceAggregatorCollector(3, 5d, 2.0);
+ Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v1, v2) < 0);
+ Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v2, v1) > 0);
+ }
+
+ @Test
+ public void testVarianceComparatorShouldCompareCountsIfVarianceIsEqual()
+ {
+ VarianceAggregatorCollector v1 = new VarianceAggregatorCollector(4, 2d, 1.0);
+ VarianceAggregatorCollector v2 = new VarianceAggregatorCollector(3, 5d, 1.0);
+ Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v1, v2) > 0);
+ Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v2, v1) < 0);
+ }
+
+ @Test
+ public void testVarianceComparatorShouldCompareSumIfVarianceAndCountsAreEqual()
+ {
+ VarianceAggregatorCollector v1 = new VarianceAggregatorCollector(4, 2d, 1.0);
+ VarianceAggregatorCollector v2 = new VarianceAggregatorCollector(4, 5d, 1.0);
+ Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v1, v2) < 0);
+ Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v2, v1) > 0);
+ v2.sum = v1.sum;
+ Assert.assertEquals(0, VarianceAggregatorCollector.COMPARATOR.compare(v1, v2));
+ }
+
private static class FloatHandOver extends TestFloatColumnSelector
{
float v;
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryTest.java
index 2e51b294940..7f0bbaa1da5 100644
--- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryTest.java
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorFactoryTest.java
@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.variance;
+import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
@@ -28,10 +29,11 @@ import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
-public class VarianceAggregatorFactoryTest
+public class VarianceAggregatorFactoryTest extends InitializedNullHandlingTest
{
@Test
public void testResultArraySignature()
@@ -68,4 +70,19 @@ public class VarianceAggregatorFactoryTest
new TimeseriesQueryQueryToolChest().resultArraySignature(query)
);
}
+
+ @Test
+ public void testFinalizeComputationWithZeroCountShouldReturnNull()
+ {
+ VarianceAggregatorFactory target = new VarianceAggregatorFactory("test", "test", null, null);
+ VarianceAggregatorCollector v1 = new VarianceAggregatorCollector();
+ Assert.assertEquals(NullHandling.defaultDoubleValue(), target.finalizeComputation(v1));
+ }
+
+ @Test
+ public void testFinalizeComputationWithNullShouldReturnNull()
+ {
+ VarianceAggregatorFactory target = new VarianceAggregatorFactory("test", "test", null, null);
+ Assert.assertEquals(NullHandling.defaultDoubleValue(), target.finalizeComputation(null));
+ }
}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorTest.java
index 6d7c4314890..f3b9453de70 100644
--- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorTest.java
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/VarianceAggregatorTest.java
@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.variance;
+import org.apache.druid.common.config.NullHandling;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.query.aggregation.TestFloatColumnSelector;
import org.apache.druid.segment.ColumnSelectorFactory;
@@ -93,16 +94,10 @@ public class VarianceAggregatorTest extends InitializedNullHandlingTest
Assert.assertEquals(sum, holder.sum, 0.0001);
Assert.assertEquals(nvariance, holder.nvariance, 0.0001);
if (count == 0) {
- try {
- holder.getVariance(false);
- Assert.fail("Should throw ISE");
- }
- catch (IllegalStateException e) {
- Assert.assertTrue(e.getMessage().contains("should not be empty holder"));
- }
+ Assert.assertEquals(NullHandling.defaultDoubleValue(), holder.getVariance(false));
} else {
- Assert.assertEquals(holder.getVariance(true), variances_pop[(int) count - 1], 0.0001);
- Assert.assertEquals(holder.getVariance(false), variances_samp[(int) count - 1], 0.0001);
+ Assert.assertEquals(variances_pop[(int) count - 1], holder.getVariance(true), 0.0001);
+ Assert.assertEquals(variances_samp[(int) count - 1], holder.getVariance(false), 0.0001);
}
}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
index 8299034d804..cfb945b99cc 100644
--- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
@@ -22,6 +22,8 @@ package org.apache.druid.query.aggregation.variance.sql;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
+import junitparams.JUnitParamsRunner;
+import junitparams.Parameters;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.InputRow;
@@ -34,18 +36,34 @@ import org.apache.druid.data.input.impl.LongDimensionSchema;
import org.apache.druid.data.input.impl.MapInputRowParser;
import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
import org.apache.druid.data.input.impl.TimestampSpec;
+import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.query.Druids;
+import org.apache.druid.query.QueryPlus;
+import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
+import org.apache.druid.query.QueryRunnerTestHelper;
+import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
+import org.apache.druid.query.dimension.DefaultDimensionSpec;
+import org.apache.druid.query.groupby.GroupByQuery;
+import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
+import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
+import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
+import org.apache.druid.query.timeseries.TimeseriesQuery;
+import org.apache.druid.query.timeseries.TimeseriesQueryEngine;
+import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
+import org.apache.druid.query.timeseries.TimeseriesQueryRunnerFactory;
+import org.apache.druid.query.timeseries.TimeseriesResultValue;
import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.QueryableIndex;
+import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
@@ -74,10 +92,15 @@ import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
+@RunWith(JUnitParamsRunner.class)
public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
{
private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
@@ -86,6 +109,8 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
private static QueryRunnerFactoryConglomerate conglomerate;
private static Closer resourceCloser;
+ private SqlLifecycle sqlLifecycle;
+
@BeforeClass
public static void setUpClass()
{
@@ -181,6 +206,8 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
CalciteTests.DRUID_SCHEMA_NAME
)
);
+ queryLogHook.clearRecordedQueries();
+ sqlLifecycle = sqlLifecycleFactory.factorize();
}
@After
@@ -221,7 +248,6 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
@Test
public void testVarPop() throws Exception
{
- SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "VAR_POP(d1),\n"
+ "VAR_POP(f1),\n"
@@ -251,14 +277,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
final List