mirror of https://github.com/apache/druid.git
Fix VARIANCE aggregator comparator (#10340)
* Fix VARIANCE aggregator comparator The comparator for the variance aggregator used to compare values using the count. This is now fixed to compare values using the variance. If the variance is equal, the count and sum are used as tie breakers. * fix tests + sql compatible mode * code review * more tests * fix last test
This commit is contained in:
parent
3fc8bc0701
commit
a5cd5f1e84
|
@ -136,6 +136,16 @@
|
|||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>pl.pragmatists</groupId>
|
||||
<artifactId>JUnitParams</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.easymock</groupId>
|
||||
<artifactId>easymock</artifactId>
|
||||
|
|
|
@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
|
|||
import com.fasterxml.jackson.annotation.JsonTypeName;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Sets;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.PostAggregator;
|
||||
import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
|
||||
|
@ -31,6 +32,7 @@ import org.apache.druid.query.aggregation.post.PostAggregatorIds;
|
|||
import org.apache.druid.query.cache.CacheKeyBuilder;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.Comparator;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
@ -73,9 +75,11 @@ public class StandardDeviationPostAggregator implements PostAggregator
|
|||
}
|
||||
|
||||
@Override
|
||||
public Object compute(Map<String, Object> combinedAggregators)
|
||||
@Nullable
|
||||
public Double compute(Map<String, Object> combinedAggregators)
|
||||
{
|
||||
return Math.sqrt(((VarianceAggregatorCollector) combinedAggregators.get(fieldName)).getVariance(isVariancePop));
|
||||
Double variance = ((VarianceAggregatorCollector) combinedAggregators.get(fieldName)).getVariance(isVariancePop);
|
||||
return variance == null ? NullHandling.defaultDoubleValue() : (Double) Math.sqrt(variance);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonValue;
|
|||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.primitives.Doubles;
|
||||
import com.google.common.primitives.Longs;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -60,11 +61,11 @@ public class VarianceAggregatorCollector
|
|||
}
|
||||
|
||||
public static final Comparator<VarianceAggregatorCollector> COMPARATOR = (o1, o2) -> {
|
||||
int compare = Longs.compare(o1.count, o2.count);
|
||||
int compare = Doubles.compare(o1.nvariance, o2.nvariance);
|
||||
if (compare == 0) {
|
||||
compare = Doubles.compare(o1.sum, o2.sum);
|
||||
compare = Longs.compare(o1.count, o2.count);
|
||||
if (compare == 0) {
|
||||
compare = Doubles.compare(o1.nvariance, o2.nvariance);
|
||||
compare = Doubles.compare(o1.sum, o2.sum);
|
||||
}
|
||||
}
|
||||
return compare;
|
||||
|
@ -156,11 +157,11 @@ public class VarianceAggregatorCollector
|
|||
return this;
|
||||
}
|
||||
|
||||
public double getVariance(boolean variancePop)
|
||||
@Nullable
|
||||
public Double getVariance(boolean variancePop)
|
||||
{
|
||||
if (count == 0) {
|
||||
// in SQL standard, we should return null for zero elements. But druid there should not be such a case
|
||||
throw new IllegalStateException("should not be empty holder");
|
||||
return NullHandling.defaultDoubleValue();
|
||||
} else if (count == 1) {
|
||||
return 0d;
|
||||
} else {
|
||||
|
|
|
@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
|
|||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.annotation.JsonTypeName;
|
||||
import com.google.common.base.Preconditions;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.java.util.common.IAE;
|
||||
import org.apache.druid.java.util.common.StringUtils;
|
||||
import org.apache.druid.query.aggregation.AggregateCombiner;
|
||||
|
@ -49,6 +50,7 @@ import java.util.List;
|
|||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
@JsonTypeName("variance")
|
||||
public class VarianceAggregatorFactory extends AggregatorFactory
|
||||
|
@ -239,7 +241,9 @@ public class VarianceAggregatorFactory extends AggregatorFactory
|
|||
@Override
|
||||
public Object finalizeComputation(@Nullable Object object)
|
||||
{
|
||||
return object == null ? null : ((VarianceAggregatorCollector) object).getVariance(isVariancePop);
|
||||
return object == null
|
||||
? NullHandling.defaultDoubleValue()
|
||||
: ((VarianceAggregatorCollector) object).getVariance(isVariancePop);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -19,14 +19,44 @@
|
|||
|
||||
package org.apache.druid.query.aggregation.variance;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import nl.jqno.equalsverifier.EqualsVerifier;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.jackson.DefaultObjectMapper;
|
||||
import org.apache.druid.query.aggregation.PostAggregator;
|
||||
import org.apache.druid.testing.InitializedNullHandlingTest;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.Mockito;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
|
||||
public class StandardDeviationPostAggregatorTest
|
||||
import java.util.Map;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class StandardDeviationPostAggregatorTest extends InitializedNullHandlingTest
|
||||
{
|
||||
private static final String NAME = "NAME";
|
||||
private static final String FIELD_NAME = "FIELD_NAME";
|
||||
private static final String POPULATION = "population";
|
||||
private static final double VARIANCE = 12.56;
|
||||
|
||||
private Map<String, Object> combinedAggregators;
|
||||
@Mock
|
||||
private VarianceAggregatorCollector collector;
|
||||
|
||||
private StandardDeviationPostAggregator target;
|
||||
|
||||
@Before
|
||||
public void setUp()
|
||||
{
|
||||
Mockito.doReturn(VARIANCE).when(collector).getVariance(true);
|
||||
combinedAggregators = ImmutableMap.of(FIELD_NAME, collector);
|
||||
target = new StandardDeviationPostAggregator(NAME, FIELD_NAME, POPULATION);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerde() throws Exception
|
||||
{
|
||||
|
@ -64,4 +94,17 @@ public class StandardDeviationPostAggregatorTest
|
|||
.usingGetClass()
|
||||
.verify();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputeForNullVarianceShouldReturnDefaultDoubleValue()
|
||||
{
|
||||
Mockito.when(collector.getVariance(true)).thenReturn(null);
|
||||
Assert.assertEquals(NullHandling.defaultDoubleValue(), target.compute(combinedAggregators));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputeForVarianceShouldReturnSqrtOfVariance()
|
||||
{
|
||||
Assert.assertEquals(Math.sqrt(VARIANCE), target.compute(combinedAggregators), 1e-15);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -144,6 +144,35 @@ public class VarianceAggregatorCollectorTest extends InitializedNullHandlingTest
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVarianceComparatorShouldCompareVarianceFirst()
|
||||
{
|
||||
VarianceAggregatorCollector v1 = new VarianceAggregatorCollector(4, 2d, 1.0);
|
||||
VarianceAggregatorCollector v2 = new VarianceAggregatorCollector(3, 5d, 2.0);
|
||||
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v1, v2) < 0);
|
||||
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v2, v1) > 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVarianceComparatorShouldCompareCountsIfVarianceIsEqual()
|
||||
{
|
||||
VarianceAggregatorCollector v1 = new VarianceAggregatorCollector(4, 2d, 1.0);
|
||||
VarianceAggregatorCollector v2 = new VarianceAggregatorCollector(3, 5d, 1.0);
|
||||
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v1, v2) > 0);
|
||||
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v2, v1) < 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVarianceComparatorShouldCompareSumIfVarianceAndCountsAreEqual()
|
||||
{
|
||||
VarianceAggregatorCollector v1 = new VarianceAggregatorCollector(4, 2d, 1.0);
|
||||
VarianceAggregatorCollector v2 = new VarianceAggregatorCollector(4, 5d, 1.0);
|
||||
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v1, v2) < 0);
|
||||
Assert.assertTrue(VarianceAggregatorCollector.COMPARATOR.compare(v2, v1) > 0);
|
||||
v2.sum = v1.sum;
|
||||
Assert.assertEquals(0, VarianceAggregatorCollector.COMPARATOR.compare(v1, v2));
|
||||
}
|
||||
|
||||
private static class FloatHandOver extends TestFloatColumnSelector
|
||||
{
|
||||
float v;
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
package org.apache.druid.query.aggregation.variance;
|
||||
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.java.util.common.granularity.Granularities;
|
||||
import org.apache.druid.query.Druids;
|
||||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||
|
@ -28,10 +29,11 @@ import org.apache.druid.query.timeseries.TimeseriesQuery;
|
|||
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
|
||||
import org.apache.druid.segment.column.RowSignature;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
import org.apache.druid.testing.InitializedNullHandlingTest;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class VarianceAggregatorFactoryTest
|
||||
public class VarianceAggregatorFactoryTest extends InitializedNullHandlingTest
|
||||
{
|
||||
@Test
|
||||
public void testResultArraySignature()
|
||||
|
@ -68,4 +70,19 @@ public class VarianceAggregatorFactoryTest
|
|||
new TimeseriesQueryQueryToolChest().resultArraySignature(query)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFinalizeComputationWithZeroCountShouldReturnNull()
|
||||
{
|
||||
VarianceAggregatorFactory target = new VarianceAggregatorFactory("test", "test", null, null);
|
||||
VarianceAggregatorCollector v1 = new VarianceAggregatorCollector();
|
||||
Assert.assertEquals(NullHandling.defaultDoubleValue(), target.finalizeComputation(v1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFinalizeComputationWithNullShouldReturnNull()
|
||||
{
|
||||
VarianceAggregatorFactory target = new VarianceAggregatorFactory("test", "test", null, null);
|
||||
Assert.assertEquals(NullHandling.defaultDoubleValue(), target.finalizeComputation(null));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
package org.apache.druid.query.aggregation.variance;
|
||||
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.jackson.DefaultObjectMapper;
|
||||
import org.apache.druid.query.aggregation.TestFloatColumnSelector;
|
||||
import org.apache.druid.segment.ColumnSelectorFactory;
|
||||
|
@ -93,16 +94,10 @@ public class VarianceAggregatorTest extends InitializedNullHandlingTest
|
|||
Assert.assertEquals(sum, holder.sum, 0.0001);
|
||||
Assert.assertEquals(nvariance, holder.nvariance, 0.0001);
|
||||
if (count == 0) {
|
||||
try {
|
||||
holder.getVariance(false);
|
||||
Assert.fail("Should throw ISE");
|
||||
}
|
||||
catch (IllegalStateException e) {
|
||||
Assert.assertTrue(e.getMessage().contains("should not be empty holder"));
|
||||
}
|
||||
Assert.assertEquals(NullHandling.defaultDoubleValue(), holder.getVariance(false));
|
||||
} else {
|
||||
Assert.assertEquals(holder.getVariance(true), variances_pop[(int) count - 1], 0.0001);
|
||||
Assert.assertEquals(holder.getVariance(false), variances_samp[(int) count - 1], 0.0001);
|
||||
Assert.assertEquals(variances_pop[(int) count - 1], holder.getVariance(true), 0.0001);
|
||||
Assert.assertEquals(variances_samp[(int) count - 1], holder.getVariance(false), 0.0001);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,8 @@ package org.apache.druid.query.aggregation.variance.sql;
|
|||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Iterables;
|
||||
import junitparams.JUnitParamsRunner;
|
||||
import junitparams.Parameters;
|
||||
import org.apache.calcite.schema.SchemaPlus;
|
||||
import org.apache.druid.common.config.NullHandling;
|
||||
import org.apache.druid.data.input.InputRow;
|
||||
|
@ -34,18 +36,34 @@ import org.apache.druid.data.input.impl.LongDimensionSchema;
|
|||
import org.apache.druid.data.input.impl.MapInputRowParser;
|
||||
import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
|
||||
import org.apache.druid.data.input.impl.TimestampSpec;
|
||||
import org.apache.druid.java.util.common.DateTimes;
|
||||
import org.apache.druid.java.util.common.granularity.Granularities;
|
||||
import org.apache.druid.java.util.common.io.Closer;
|
||||
import org.apache.druid.query.Druids;
|
||||
import org.apache.druid.query.QueryPlus;
|
||||
import org.apache.druid.query.QueryRunner;
|
||||
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
|
||||
import org.apache.druid.query.QueryRunnerTestHelper;
|
||||
import org.apache.druid.query.Result;
|
||||
import org.apache.druid.query.aggregation.CountAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
|
||||
import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
|
||||
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
|
||||
import org.apache.druid.query.dimension.DefaultDimensionSpec;
|
||||
import org.apache.druid.query.groupby.GroupByQuery;
|
||||
import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
|
||||
import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
|
||||
import org.apache.druid.query.ordering.StringComparators;
|
||||
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
|
||||
import org.apache.druid.query.timeseries.TimeseriesQuery;
|
||||
import org.apache.druid.query.timeseries.TimeseriesQueryEngine;
|
||||
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
|
||||
import org.apache.druid.query.timeseries.TimeseriesQueryRunnerFactory;
|
||||
import org.apache.druid.query.timeseries.TimeseriesResultValue;
|
||||
import org.apache.druid.segment.IndexBuilder;
|
||||
import org.apache.druid.segment.QueryableIndex;
|
||||
import org.apache.druid.segment.TestHelper;
|
||||
import org.apache.druid.segment.column.ValueType;
|
||||
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
|
||||
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
|
||||
|
@ -74,10 +92,15 @@ import org.junit.BeforeClass;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@RunWith(JUnitParamsRunner.class)
|
||||
public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
||||
{
|
||||
private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
|
||||
|
@ -86,6 +109,8 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
private static QueryRunnerFactoryConglomerate conglomerate;
|
||||
private static Closer resourceCloser;
|
||||
|
||||
private SqlLifecycle sqlLifecycle;
|
||||
|
||||
@BeforeClass
|
||||
public static void setUpClass()
|
||||
{
|
||||
|
@ -181,6 +206,8 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
CalciteTests.DRUID_SCHEMA_NAME
|
||||
)
|
||||
);
|
||||
queryLogHook.clearRecordedQueries();
|
||||
sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
}
|
||||
|
||||
@After
|
||||
|
@ -221,7 +248,6 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
@Test
|
||||
public void testVarPop() throws Exception
|
||||
{
|
||||
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
final String sql = "SELECT\n"
|
||||
+ "VAR_POP(d1),\n"
|
||||
+ "VAR_POP(f1),\n"
|
||||
|
@ -251,14 +277,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
holder1.getVariance(true),
|
||||
(float) holder2.getVariance(true),
|
||||
(long) holder3.getVariance(true),
|
||||
holder2.getVariance(true).floatValue(),
|
||||
holder3.getVariance(true).longValue()
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
assertResultsEquals(expectedResults, results);
|
||||
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
|
@ -281,7 +304,6 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
@Test
|
||||
public void testVarSamp() throws Exception
|
||||
{
|
||||
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
final String sql = "SELECT\n"
|
||||
+ "VAR_SAMP(d1),\n"
|
||||
+ "VAR_SAMP(f1),\n"
|
||||
|
@ -311,14 +333,11 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
final List<Object[]> expectedResults = ImmutableList.of(
|
||||
new Object[]{
|
||||
holder1.getVariance(false),
|
||||
(float) holder2.getVariance(false),
|
||||
(long) holder3.getVariance(false),
|
||||
holder2.getVariance(false).floatValue(),
|
||||
holder3.getVariance(false).longValue(),
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
assertResultsEquals(expectedResults, results);
|
||||
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
|
@ -341,7 +360,6 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
@Test
|
||||
public void testStdDevPop() throws Exception
|
||||
{
|
||||
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
final String sql = "SELECT\n"
|
||||
+ "STDDEV_POP(d1),\n"
|
||||
+ "STDDEV_POP(f1),\n"
|
||||
|
@ -375,10 +393,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
(long) Math.sqrt(holder3.getVariance(true)),
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
assertResultsEquals(expectedResults, results);
|
||||
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
|
@ -407,8 +422,6 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
@Test
|
||||
public void testStdDevSamp() throws Exception
|
||||
{
|
||||
queryLogHook.clearRecordedQueries();
|
||||
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
final String sql = "SELECT\n"
|
||||
+ "STDDEV_SAMP(d1),\n"
|
||||
+ "STDDEV_SAMP(f1),\n"
|
||||
|
@ -442,10 +455,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
(long) Math.sqrt(holder3.getVariance(false)),
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
assertResultsEquals(expectedResults, results);
|
||||
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
|
@ -469,12 +479,10 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testStdDevWithVirtualColumns() throws Exception
|
||||
{
|
||||
queryLogHook.clearRecordedQueries();
|
||||
SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
|
||||
final String sql = "SELECT\n"
|
||||
+ "STDDEV(d1*7),\n"
|
||||
+ "STDDEV(f1*7),\n"
|
||||
|
@ -508,10 +516,7 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
(long) Math.sqrt(holder3.getVariance(false)),
|
||||
}
|
||||
);
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
assertResultsEquals(expectedResults, results);
|
||||
|
||||
Assert.assertEquals(
|
||||
Druids.newTimeseriesQueryBuilder()
|
||||
|
@ -540,4 +545,112 @@ public class VarianceSqlAggregatorTest extends InitializedNullHandlingTest
|
|||
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testVarianceOrderBy() throws Exception
|
||||
{
|
||||
queryLogHook.clearRecordedQueries();
|
||||
final String sql = "select dim2, VARIANCE(f1) from druid.numfoo group by 1 order by 2 desc";
|
||||
final List<Object[]> results =
|
||||
sqlLifecycle.runSimple(
|
||||
sql,
|
||||
BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT,
|
||||
CalciteTestBase.DEFAULT_PARAMETERS,
|
||||
authenticationResult
|
||||
).toList();
|
||||
List<Object[]> expectedResults = NullHandling.sqlCompatible()
|
||||
? ImmutableList.of(
|
||||
new Object[] {"a", 0f},
|
||||
new Object[] {null, 0f},
|
||||
new Object[] {"", 0f},
|
||||
new Object[] {"abc", null}
|
||||
) : ImmutableList.of(
|
||||
new Object[] {"a", 0.5f},
|
||||
new Object[] {"", 0.0033333334f},
|
||||
new Object[] {"abc", 0f}
|
||||
);
|
||||
assertResultsEquals(expectedResults, results);
|
||||
|
||||
Assert.assertEquals(
|
||||
GroupByQuery.builder()
|
||||
.setDataSource(CalciteTests.DATASOURCE3)
|
||||
.setInterval(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setDimensions(new DefaultDimensionSpec("dim2", "_d0"))
|
||||
.setAggregatorSpecs(
|
||||
new VarianceAggregatorFactory("a0:agg", "f1", "sample", "float")
|
||||
)
|
||||
.setLimitSpec(
|
||||
DefaultLimitSpec
|
||||
.builder()
|
||||
.orderBy(
|
||||
new OrderByColumnSpec(
|
||||
"a0:agg",
|
||||
OrderByColumnSpec.Direction.DESCENDING,
|
||||
StringComparators.NUMERIC
|
||||
)
|
||||
)
|
||||
.build()
|
||||
)
|
||||
.setContext(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
|
||||
.build(),
|
||||
Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
|
||||
);
|
||||
}
|
||||
|
||||
public Object[] timeseriesQueryRunners()
|
||||
{
|
||||
return QueryRunnerTestHelper.makeQueryRunners(
|
||||
new TimeseriesQueryRunnerFactory(
|
||||
new TimeseriesQueryQueryToolChest(),
|
||||
new TimeseriesQueryEngine(),
|
||||
QueryRunnerTestHelper.NOOP_QUERYWATCHER
|
||||
)
|
||||
).toArray();
|
||||
}
|
||||
|
||||
@Test
|
||||
@Parameters(method = "timeseriesQueryRunners")
|
||||
public void testEmptyTimeseries(QueryRunner<Result<TimeseriesResultValue>> runner)
|
||||
{
|
||||
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
|
||||
.dataSource(QueryRunnerTestHelper.DATA_SOURCE)
|
||||
.granularity(QueryRunnerTestHelper.ALL_GRAN)
|
||||
.intervals(QueryRunnerTestHelper.EMPTY_INTERVAL)
|
||||
.aggregators(
|
||||
Arrays.asList(
|
||||
QueryRunnerTestHelper.ROWS_COUNT,
|
||||
QueryRunnerTestHelper.INDEX_DOUBLE_SUM,
|
||||
new VarianceAggregatorFactory("variance", "index")
|
||||
)
|
||||
)
|
||||
.descending(true)
|
||||
.context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
|
||||
.build();
|
||||
Map<String, Object> resultMap = new HashMap<>();
|
||||
resultMap.put("rows", 0L);
|
||||
resultMap.put("index", NullHandling.defaultDoubleValue());
|
||||
resultMap.put("variance", NullHandling.defaultDoubleValue());
|
||||
List<Result<TimeseriesResultValue>> expectedResults = ImmutableList.of(
|
||||
new Result<>(
|
||||
DateTimes.of("2020-04-02"),
|
||||
new TimeseriesResultValue(
|
||||
resultMap
|
||||
)
|
||||
)
|
||||
);
|
||||
Iterable<Result<TimeseriesResultValue>> actualResults = runner.run(QueryPlus.wrap(query)).toList();
|
||||
TestHelper.assertExpectedResults(expectedResults, actualResults);
|
||||
}
|
||||
|
||||
private static void assertResultsEquals(List<Object[]> expectedResults, List<Object[]> results)
|
||||
{
|
||||
Assert.assertEquals(expectedResults.size(), results.size());
|
||||
for (int i = 0; i < expectedResults.size(); i++) {
|
||||
Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue