Fix VARIANCE aggregator comparator (#10340)

* Fix VARIANCE aggregator comparator

The comparator for the variance aggregator previously compared values by their
count. It is now fixed to compare values by their variance; if the variances
are equal, the count and then the sum are used as tie breakers.
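
For illustration, here is a minimal standalone sketch of the new ordering, using plain JDK comparators instead of the Guava helpers in the actual change; Holder and its fields are hypothetical stand-ins for the collector's count, sum, and nvariance state:

import java.util.Comparator;

// Sketch of the intended ordering: variance-related state first, then count,
// then sum as the final tie breaker.
class Holder
{
  long count;       // number of values seen
  double sum;       // running sum of the values
  double nvariance; // accumulated sum of squared deviations from the mean

  static final Comparator<Holder> COMPARATOR = (o1, o2) -> {
    int compare = Double.compare(o1.nvariance, o2.nvariance);
    if (compare == 0) {
      compare = Long.compare(o1.count, o2.count);
      if (compare == 0) {
        compare = Double.compare(o1.sum, o2.sum);
      }
    }
    return compare;
  };
}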

* fix tests + sql compatible mode

* code review

* more tests

* fix last test
Suneet Saldanha 2020-09-03 17:38:37 -07:00 committed by GitHub
parent 3fc8bc0701
commit a5cd5f1e84
9 changed files with 268 additions and 52 deletions

View File

@@ -136,6 +136,16 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>pl.pragmatists</groupId>
<artifactId>JUnitParams</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
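
The new JUnitParams and mockito-core test dependencies back the parameterized and mock-based tests added further down. For readers unfamiliar with the runner, a minimal hypothetical example of the JUnitParams pattern those tests rely on (class, method, and values are invented for illustration):

import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;

// Hypothetical example: the method named in @Parameters supplies one argument
// per test invocation, mirroring how timeseriesQueryRunners() is used below.
@RunWith(JUnitParamsRunner.class)
public class JUnitParamsExampleTest
{
  public Object[] variances()
  {
    return new Object[]{0.0d, 12.56d};
  }

  @Test
  @Parameters(method = "variances")
  public void testSqrtIsNonNegative(double variance)
  {
    Assert.assertTrue(Math.sqrt(variance) >= 0.0d);
  }
}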

View File

@@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
@@ -31,6 +32,7 @@ import org.apache.druid.query.aggregation.post.PostAggregatorIds;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.column.ValueType;
import javax.annotation.Nullable;
import java.util.Comparator;
import java.util.Map;
import java.util.Objects;
@@ -73,9 +75,11 @@ public class StandardDeviationPostAggregator implements PostAggregator
}
@Override
public Object compute(Map<String, Object> combinedAggregators)
@Nullable
public Double compute(Map<String, Object> combinedAggregators)
{
return Math.sqrt(((VarianceAggregatorCollector) combinedAggregators.get(fieldName)).getVariance(isVariancePop));
Double variance = ((VarianceAggregatorCollector) combinedAggregators.get(fieldName)).getVariance(isVariancePop);
return variance == null ? NullHandling.defaultDoubleValue() : (Double) Math.sqrt(variance);
}
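
Whether that fallback yields null or 0.0 depends on Druid's null-handling mode. A small sketch of the assumed behavior, with a hypothetical helper name; NullHandling.defaultDoubleValue() is assumed to return 0.0 in default mode and null in SQL-compatible mode:

import javax.annotation.Nullable;
import org.apache.druid.common.config.NullHandling;

// Sketch of the null-safe square root above. The explicit (Double) cast keeps
// the ternary typed as Double; without it the expression would auto-unbox and
// throw a NullPointerException whenever the default value is null.
class StdDevSketch
{
  @Nullable
  static Double stddevOrDefault(@Nullable Double variance)
  {
    return variance == null ? NullHandling.defaultDoubleValue() : (Double) Math.sqrt(variance);
  }
}
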
@Override

View File

@@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Longs;
import org.apache.druid.common.config.NullHandling;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@@ -60,11 +61,11 @@ public class VarianceAggregatorCollector
}
public static final Comparator<VarianceAggregatorCollector> COMPARATOR = (o1, o2) -> {
int compare = Longs.compare(o1.count, o2.count);
int compare = Doubles.compare(o1.nvariance, o2.nvariance);
if (compare == 0) {
compare = Doubles.compare(o1.sum, o2.sum);
compare = Longs.compare(o1.count, o2.count);
if (compare == 0) {
compare = Doubles.compare(o1.nvariance, o2.nvariance);
compare = Doubles.compare(o1.sum, o2.sum);
}
}
return compare;
@@ -156,11 +157,11 @@ public class VarianceAggregatorCollector
return this;
}
public double getVariance(boolean variancePop)
@Nullable
public Double getVariance(boolean variancePop)
{
if (count == 0) {
// in SQL standard, we should return null for zero elements. But druid there should not be such a case
throw new IllegalStateException("should not be empty holder");
return NullHandling.defaultDoubleValue();
} else if (count == 1) {
return 0d;
} else {
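
For reference, a sketch of how getVariance() presumably derives its result from the accumulated state, assuming the usual formulation (population variance = nvariance / count, sample variance = nvariance / (count - 1)); this is illustrative, not the verbatim Druid source:

import javax.annotation.Nullable;
import org.apache.druid.common.config.NullHandling;

// Sketch of getVariance() semantics over the collector's (count, sum, nvariance) state.
class VarianceSketch
{
  long count;       // number of values seen
  double sum;       // running sum of the values
  double nvariance; // accumulated sum of squared deviations from the running mean

  @Nullable
  Double getVariance(boolean variancePop)
  {
    if (count == 0) {
      // Previously threw IllegalStateException; now falls back to the
      // mode-dependent default (assumed: 0.0 in default mode, null in
      // SQL-compatible mode).
      return NullHandling.defaultDoubleValue();
    } else if (count == 1) {
      return 0d;
    } else {
      return variancePop ? nvariance / count : nvariance / (count - 1);
    }
  }
}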

View File

@@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregateCombiner;
@@ -49,6 +50,7 @@ import java.util.List;
import java.util.Objects;
/**
*
*/
@JsonTypeName("variance")
public class VarianceAggregatorFactory extends AggregatorFactory
@@ -239,7 +241,9 @@ public class VarianceAggregatorFactory extends AggregatorFactory
@Override
public Object finalizeComputation(@Nullable Object object)
{
return object == null ? null : ((VarianceAggregatorCollector) object).getVariance(isVariancePop);
return object == null
? NullHandling.defaultDoubleValue()
: ((VarianceAggregatorCollector) object).getVariance(isVariancePop);
}
@Override

View File

@@ -19,14 +19,44 @@
package org.apache.druid.query.aggregation.variance;
import com.google.common.collect.ImmutableMap;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
public class StandardDeviationPostAggregatorTest
import java.util.Map;
@RunWith(MockitoJUnitRunner.class)
public class StandardDeviationPostAggregatorTest extends InitializedNullHandlingTest
{
private static final String NAME = "NAME";
private static final String FIELD_NAME = "FIELD_NAME";
private static final String POPULATION = "population";
private static final double VARIANCE = 12.56;
private Map<String, Object> combinedAggregators;
@Mock
private VarianceAggregatorCollector collector;
private StandardDeviationPostAggregator target;
@Before
public void setUp()
{
Mockito.doReturn(VARIANCE).when(collector).getVariance(true);
combinedAggregators = ImmutableMap.of(FIELD_NAME, collector);
target = new StandardDeviationPostAggregator(NAME, FIELD_NAME, POPULATION);
}
@Test
public void testSerde() throws Exception
{
@@ -64,4 +94,17 @@ public class StandardDeviationPostAggregatorTest
.usingGetClass()
.verify();
}
@Test
public void testComputeForNullVarianceShouldReturnDefaultDoubleValue()
{
Mockito.when(collector.getVariance(true)).thenReturn(null);
Assert.assertEquals(NullHandling.defaultDoubleValue(), target.compute(combinedAggregators));
}
@Test
public void testComputeForVarianceShouldReturnSqrtOfVariance()
{
Assert.assertEquals(Math.sqrt(VARIANCE), target.compute(combinedAggregators), 1e-15);
}
}

View File

@@ -144,6 +144,35 @@ public class VarianceAggregatorCollectorTest extends InitializedNullHandlingTest
}
}
@Test
public void testVarianceComparatorShouldCompareVarianceFirst()
{
VarianceAggregatorCollector v1 = new VarianceAggregatorCollector(4, 2d, 1.0);
VarianceAggregatorCollector v2 = new VarianceAggregatorCollector(3, 5d, 2.0);
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v1, v2) < 0);
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v2, v1) > 0);
}
@Test
public void testVarianceComparatorShouldCompareCountsIfVarianceIsEqual()
{
VarianceAggregatorCollector v1 = new VarianceAggregatorCollector(4, 2d, 1.0);
VarianceAggregatorCollector v2 = new VarianceAggregatorCollector(3, 5d, 1.0);
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v1, v2) > 0);
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v2, v1) < 0);
}
@Test
public void testVarianceComparatorShouldCompareSumIfVarianceAndCountsAreEqual()
{
VarianceAggregatorCollector v1 = new VarianceAggregatorCollector(4, 2d, 1.0);
VarianceAggregatorCollector v2 = new VarianceAggregatorCollector(4, 5d, 1.0);
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v1, v2) < 0);
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v2, v1) > 0);
v2.sum = v1.sum;
Assert.assertEquals(0, VarianceAggregatorCollector.COMPARATOR.compare(v1, v2));
}
private static class FloatHandOver extends TestFloatColumnSelector
{
float v;

View File

@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
@@ -28,10 +29,11 @@ import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
public class VarianceAggregatorFactoryTest
public class VarianceAggregatorFactoryTest extends InitializedNullHandlingTest
{
@Test
public void testResultArraySignature()
@@ -68,4 +70,19 @@ public class VarianceAggregatorFactoryTest
new TimeseriesQueryQueryToolChest().resultArraySignature(query)
);
}
@Test
public void testFinalizeComputationWithZeroCountShouldReturnNull()
{
VarianceAggregatorFactory target = new VarianceAggregatorFactory("test", "test", null, null);
VarianceAggregatorCollector v1 = new VarianceAggregatorCollector();
Assert.assertEquals(NullHandling.defaultDoubleValue(), target.finalizeComputation(v1));
}
@Test
public void testFinalizeComputationWithNullShouldReturnNull()
{
VarianceAggregatorFactory target = new VarianceAggregatorFactory("test", "test", null, null);
Assert.assertEquals(NullHandling.defaultDoubleValue(), target.finalizeComputation(null));
}
}

View File

@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.variance;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.query.aggregation.TestFloatColumnSelector;
import org.apache.druid.segment.ColumnSelectorFactory;
@@ -93,16 +94,10 @@ public class VarianceAggregatorTest extends InitializedNullHandlingTest
Assert.assertEquals(sum, holder.sum, 0.0001);
Assert.assertEquals(nvariance, holder.nvariance, 0.0001);
if (count == 0) {
try {
holder.getVariance(false);
Assert.fail("Should throw ISE");
}
catch (IllegalStateException e) {
Assert.assertTrue(e.getMessage().contains("should not be empty holder"));
}
Assert.assertEquals(NullHandling.defaultDoubleValue(), holder.getVariance(false));
} else {
Assert.assertEquals(holder.getVariance(true), variances_pop[(int) count - 1], 0.0001);
Assert.assertEquals(holder.getVariance(false), variances_samp[(int) count - 1], 0.0001);
Assert.assertEquals(variances_pop[(int) count - 1], holder.getVariance(true), 0.0001);
Assert.assertEquals(variances_samp[(int) count - 1], holder.getVariance(false), 0.0001);
}
}

View File

@@ -22,6 +22,8 @@ package org.apache.druid.query.aggregation.variance.sql;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.InputRow;
@@ -34,18 +36,34 @@ import org.apache.druid.data.input.impl.LongDimensionSchema;
import org.apache.druid.data.input.impl.MapInputRowParser;
import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
import org.apache.druid.data.input.impl.TimestampSpec;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.query.Druids;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.QueryRunnerTestHelper;
import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesQueryEngine;
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import org.apache.druid.query.timeseries.TimeseriesQueryRunnerFactory;
import org.apache.druid.query.timeseries.TimeseriesResultValue;
import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
@@ -74,10 +92,15 @@ import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RunWith(JUnitParamsRunner.class)
public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
{
private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
@@ -86,6 +109,8 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
private static QueryRunnerFactoryConglomerate conglomerate;
private static Closer resourceCloser;
private SqlLifecycle sqlLifecycle;
@BeforeClass
public static void setUpClass()
{
@@ -181,6 +206,8 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
CalciteTests.DRUID_SCHEMA_NAME
)
);
queryLogHook.clearRecordedQueries();
sqlLifecycle = sqlLifecycleFactory.factorize();
}
@After
@@ -221,7 +248,6 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
@Test
public void testVarPop() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "VAR_POP(d1),\n"
+ "VAR_POP(f1),\n"
@@ -251,14 +277,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
holder1.getVariance(true),
(float) holder2.getVariance(true),
(long) holder3.getVariance(true),
holder2.getVariance(true).floatValue(),
holder3.getVariance(true).longValue()
}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
@@ -281,7 +304,6 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
@Test
public void testVarSamp() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "VAR_SAMP(d1),\n"
+ "VAR_SAMP(f1),\n"
@@ -311,14 +333,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
holder1.getVariance(false),
(float) holder2.getVariance(false),
(long) holder3.getVariance(false),
holder2.getVariance(false).floatValue(),
holder3.getVariance(false).longValue(),
}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
@@ -341,7 +360,6 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
@Test
public void testStdDevPop() throws Exception
{
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "STDDEV_POP(d1),\n"
+ "STDDEV_POP(f1),\n"
@@ -375,10 +393,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
(long) Math.sqrt(holder3.getVariance(true)),
}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
@@ -407,8 +422,6 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
@Test
public void testStdDevSamp() throws Exception
{
queryLogHook.clearRecordedQueries();
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "STDDEV_SAMP(d1),\n"
+ "STDDEV_SAMP(f1),\n"
@@ -442,10 +455,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
(long) Math.sqrt(holder3.getVariance(false)),
}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
@@ -469,12 +479,10 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@Test
public void testStdDevWithVirtualColumns() throws Exception
{
queryLogHook.clearRecordedQueries();
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
final String sql = "SELECT\n"
+ "STDDEV(d1*7),\n"
+ "STDDEV(f1*7),\n"
@@ -508,10 +516,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
(long) Math.sqrt(holder3.getVariance(false)),
}
);
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
Druids.newTimeseriesQueryBuilder()
@@ -540,4 +545,112 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
@Test
public void testVarianceOrderBy() throws Exception
{
queryLogHook.clearRecordedQueries();
final String sql = "select dim2, VARIANCE(f1) from druid.numfoo group by 1 order by 2 desc";
final List<Object[]> results =
sqlLifecycle.runSimple(
sql,
BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT,
CalciteTestBase.DEFAULT_PARAMETERS,
authenticationResult
).toList();
List<Object[]> expectedResults = NullHandling.sqlCompatible()
? ImmutableList.of(
new Object[] {"a", 0f},
new Object[] {null, 0f},
new Object[] {"", 0f},
new Object[] {"abc", null}
) : ImmutableList.of(
new Object[] {"a", 0.5f},
new Object[] {"", 0.0033333334f},
new Object[] {"abc", 0f}
);
assertResultsEquals(expectedResults, results);
Assert.assertEquals(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE3)
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.setGranularity(Granularities.ALL)
.setDimensions(new DefaultDimensionSpec("dim2", "_d0"))
.setAggregatorSpecs(
new VarianceAggregatorFactory("a0:agg", "f1", "sample", "float")
)
.setLimitSpec(
DefaultLimitSpec
.builder()
.orderBy(
new OrderByColumnSpec(
"a0:agg",
OrderByColumnSpec.Direction.DESCENDING,
StringComparators.NUMERIC
)
)
.build()
)
.setContext(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
.build(),
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
);
}
public Object[] timeseriesQueryRunners()
{
return QueryRunnerTestHelper.makeQueryRunners(
new TimeseriesQueryRunnerFactory(
new TimeseriesQueryQueryToolChest(),
new TimeseriesQueryEngine(),
QueryRunnerTestHelper.NOOP_QUERYWATCHER
)
).toArray();
}
@Test
@Parameters(method = "timeseriesQueryRunners")
public void testEmptyTimeseries(QueryRunner<Result<TimeseriesResultValue>> runner)
{
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(QueryRunnerTestHelper.DATA_SOURCE)
.granularity(QueryRunnerTestHelper.ALL_GRAN)
.intervals(QueryRunnerTestHelper.EMPTY_INTERVAL)
.aggregators(
Arrays.asList(
QueryRunnerTestHelper.ROWS_COUNT,
QueryRunnerTestHelper.INDEX_DOUBLE_SUM,
new VarianceAggregatorFactory("variance", "index")
)
)
.descending(true)
.context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
.build();
Map<String, Object> resultMap = new HashMap<>();
resultMap.put("rows", 0L);
resultMap.put("index", NullHandling.defaultDoubleValue());
resultMap.put("variance", NullHandling.defaultDoubleValue());
List<Result<TimeseriesResultValue>> expectedResults = ImmutableList.of(
new Result<>(
DateTimes.of("2020-04-02"),
new TimeseriesResultValue(
resultMap
)
)
);
Iterable<Result<TimeseriesResultValue>> actualResults = runner.run(QueryPlus.wrap(query)).toList();
TestHelper.assertExpectedResults(expectedResults, actualResults);
}
private static void assertResultsEquals(List<Object[]> expectedResults, List<Object[]> results)
{
Assert.assertEquals(expectedResults.size(), results.size());
for (int i = 0; i < expectedResults.size(); i++) {
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
}
}
}