fix null handling for arithmetic post aggregator comparator (#9159)

* fix null handling for arithmetic postagg comparator, add test for comparator for min/max/quantile postaggs in histogram ext

* fix
This commit is contained in:
Clint Wylie 2020-01-10 13:49:19 -08:00 committed by Jonathan Wei
parent 8c53818fa9
commit 85219ece13
11 changed files with 224 additions and 53 deletions

View File

@ -35,14 +35,8 @@ import java.util.Set;
@JsonTypeName("max") @JsonTypeName("max")
public class MaxPostAggregator extends ApproximateHistogramPostAggregator public class MaxPostAggregator extends ApproximateHistogramPostAggregator
{ {
static final Comparator COMPARATOR = new Comparator() // this doesn't need to handle nulls because the values come from ApproximateHistogram
{ static final Comparator COMPARATOR = Comparator.comparingDouble(o -> ((Number) o).doubleValue());
@Override
public int compare(Object o, Object o1)
{
return Double.compare(((Number) o).doubleValue(), ((Number) o1).doubleValue());
}
};
@JsonCreator @JsonCreator
public MaxPostAggregator( public MaxPostAggregator(

View File

@ -36,14 +36,8 @@ import java.util.Set;
@JsonTypeName("min") @JsonTypeName("min")
public class MinPostAggregator extends ApproximateHistogramPostAggregator public class MinPostAggregator extends ApproximateHistogramPostAggregator
{ {
static final Comparator COMPARATOR = new Comparator() // this doesn't need to handle nulls because the values come from ApproximateHistogram
{ static final Comparator COMPARATOR = Comparator.comparingDouble(o -> ((Number) o).doubleValue());
@Override
public int compare(Object o, Object o1)
{
return Double.compare(((Number) o).doubleValue(), ((Number) o1).doubleValue());
}
};
@JsonCreator @JsonCreator
public MinPostAggregator( public MinPostAggregator(

View File

@ -36,14 +36,8 @@ import java.util.Set;
@JsonTypeName("quantile") @JsonTypeName("quantile")
public class QuantilePostAggregator extends ApproximateHistogramPostAggregator public class QuantilePostAggregator extends ApproximateHistogramPostAggregator
{ {
static final Comparator COMPARATOR = new Comparator() // this doesn't need to handle nulls because the values come from ApproximateHistogram
{ static final Comparator COMPARATOR = Comparator.comparingDouble(o -> ((Number) o).doubleValue());
@Override
public int compare(Object o, Object o1)
{
return Double.compare(((Number) o).doubleValue(), ((Number) o1).doubleValue());
}
};
private final float probability; private final float probability;
private final String fieldName; private final String fieldName;

View File

@ -41,7 +41,7 @@ public class ApproximateHistogramPostAggregatorTest extends InitializedNullHandl
} }
@Test @Test
public void testCompute() public void testApproxHistogramCompute()
{ {
ApproximateHistogram ah = buildHistogram(10, VALUES); ApproximateHistogram ah = buildHistogram(10, VALUES);
final TestFloatColumnSelector selector = new TestFloatColumnSelector(VALUES); final TestFloatColumnSelector selector = new TestFloatColumnSelector(VALUES);
@ -63,5 +63,4 @@ public class ApproximateHistogramPostAggregatorTest extends InitializedNullHandl
); );
Assert.assertEquals(ah.toHistogram(5), approximateHistogramPostAggregator.compute(metricValues)); Assert.assertEquals(ah.toHistogram(5), approximateHistogramPostAggregator.compute(metricValues));
} }
} }

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.histogram;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
public class MaxPostAggregatorTest extends InitializedNullHandlingTest
{
@Test
public void testComparator()
{
final String aggName = "doubleWithNulls";
Map<String, Object> metricValues = new HashMap<>();
MaxPostAggregator max = new MaxPostAggregator("max", aggName);
Comparator comp = max.getComparator();
ApproximateHistogram histo1 = new ApproximateHistogram();
metricValues.put(aggName, histo1);
Object before = max.compute(metricValues);
ApproximateHistogram histo2 = new ApproximateHistogram();
histo2.offer(1.0f);
metricValues.put(aggName, histo2);
Object after = max.compute(metricValues);
Assert.assertEquals(-1, comp.compare(before, after));
Assert.assertEquals(0, comp.compare(before, before));
Assert.assertEquals(0, comp.compare(after, after));
Assert.assertEquals(1, comp.compare(after, before));
}
}

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.histogram;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
public class MinPostAggregatorTest extends InitializedNullHandlingTest
{
@Test
public void testComparator()
{
final String aggName = "doubleWithNulls";
Map<String, Object> metricValues = new HashMap<>();
MinPostAggregator min = new MinPostAggregator("min", aggName);
Comparator comp = min.getComparator();
ApproximateHistogram histo1 = new ApproximateHistogram();
metricValues.put(aggName, histo1);
Object before = min.compute(metricValues);
ApproximateHistogram histo2 = new ApproximateHistogram();
histo2.offer(1.0f);
metricValues.put(aggName, histo2);
Object after = min.compute(metricValues);
Assert.assertEquals(1, comp.compare(before, after));
Assert.assertEquals(0, comp.compare(before, before));
Assert.assertEquals(0, comp.compare(after, after));
Assert.assertEquals(-1, comp.compare(after, before));
}
}

View File

@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.histogram;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
public class QuantilePostAggregatorTest extends InitializedNullHandlingTest
{
@Test
public void testComparator()
{
final String aggName = "doubleWithNulls";
Map<String, Object> metricValues = new HashMap<>();
QuantilePostAggregator quantile = new QuantilePostAggregator("quantile", aggName, 0.9f);
Comparator comp = quantile.getComparator();
ApproximateHistogram histo1 = new ApproximateHistogram();
histo1.offer(10.0f);
metricValues.put(aggName, histo1);
Object before = quantile.compute(metricValues);
ApproximateHistogram histo2 = new ApproximateHistogram();
histo2.offer(100.0f);
metricValues.put(aggName, histo2);
Object after = quantile.compute(metricValues);
Assert.assertEquals(-1, comp.compare(before, after));
Assert.assertEquals(0, comp.compare(before, before));
Assert.assertEquals(0, comp.compare(after, after));
Assert.assertEquals(1, comp.compare(after, before));
}
}

View File

@ -29,7 +29,7 @@ import java.util.Comparator;
*/ */
public class DoubleSumAggregator implements Aggregator public class DoubleSumAggregator implements Aggregator
{ {
static final Comparator COMPARATOR = new Ordering() public static final Comparator COMPARATOR = new Ordering()
{ {
@Override @Override
public int compare(Object o, Object o1) public int compare(Object o, Object o1)

View File

@ -26,6 +26,7 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.query.Queries; import org.apache.druid.query.Queries;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregator;
import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
@ -43,14 +44,7 @@ import java.util.Set;
*/ */
public class ArithmeticPostAggregator implements PostAggregator public class ArithmeticPostAggregator implements PostAggregator
{ {
public static final Comparator DEFAULT_COMPARATOR = new Comparator() public static final Comparator DEFAULT_COMPARATOR = DoubleSumAggregator.COMPARATOR;
{
@Override
public int compare(Object o, Object o1)
{
return ((Double) o).compareTo((Double) o1);
}
};
private final String name; private final String name;
private final String fnName; private final String fnName;

View File

@ -26,6 +26,7 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import org.apache.druid.js.JavaScriptConfig; import org.apache.druid.js.JavaScriptConfig;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregator;
import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.checkerframework.checker.nullness.qual.EnsuresNonNull; import org.checkerframework.checker.nullness.qual.EnsuresNonNull;
@ -42,15 +43,6 @@ import java.util.Set;
public class JavaScriptPostAggregator implements PostAggregator public class JavaScriptPostAggregator implements PostAggregator
{ {
private static final Comparator COMPARATOR = new Comparator()
{
@Override
public int compare(Object o, Object o1)
{
return ((Double) o).compareTo((Double) o1);
}
};
private interface Function private interface Function
{ {
double apply(Object[] args); double apply(Object[] args);
@ -127,7 +119,7 @@ public class JavaScriptPostAggregator implements PostAggregator
@Override @Override
public Comparator getComparator() public Comparator getComparator()
{ {
return COMPARATOR; return DoubleSumAggregator.COMPARATOR;
} }
@Override @Override

View File

@ -22,9 +22,11 @@ package org.apache.druid.query.aggregation.post;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.CountAggregator; import org.apache.druid.query.aggregation.CountAggregator;
import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
@ -33,9 +35,7 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
/** public class ArithmeticPostAggregatorTest extends InitializedNullHandlingTest
*/
public class ArithmeticPostAggregatorTest
{ {
@Test @Test
public void testCompute() public void testCompute()
@ -53,10 +53,12 @@ public class ArithmeticPostAggregatorTest
List<PostAggregator> postAggregatorList = List<PostAggregator> postAggregatorList =
Lists.newArrayList( Lists.newArrayList(
new ConstantPostAggregator( new ConstantPostAggregator(
"roku", 6D "roku",
6D
), ),
new FieldAccessPostAggregator( new FieldAccessPostAggregator(
"rows", "rows" "rows",
"rows"
) )
); );
@ -91,16 +93,18 @@ public class ArithmeticPostAggregatorTest
final String aggName = "rows"; final String aggName = "rows";
ArithmeticPostAggregator arithmeticPostAggregator; ArithmeticPostAggregator arithmeticPostAggregator;
CountAggregator agg = new CountAggregator(); CountAggregator agg = new CountAggregator();
Map<String, Object> metricValues = new HashMap<String, Object>(); Map<String, Object> metricValues = new HashMap<>();
metricValues.put(aggName, agg.get()); metricValues.put(aggName, agg.get());
List<PostAggregator> postAggregatorList = List<PostAggregator> postAggregatorList =
Lists.newArrayList( Lists.newArrayList(
new ConstantPostAggregator( new ConstantPostAggregator(
"roku", 6D "roku",
6D
), ),
new FieldAccessPostAggregator( new FieldAccessPostAggregator(
"rows", "rows" "rows",
"rows"
) )
); );
@ -119,6 +123,39 @@ public class ArithmeticPostAggregatorTest
Assert.assertEquals(1, comp.compare(after, before)); Assert.assertEquals(1, comp.compare(after, before));
} }
@Test
public void testComparatorNulls()
{
final String aggName = "doubleWithNulls";
ArithmeticPostAggregator arithmeticPostAggregator;
Map<String, Object> metricValues = new HashMap<>();
List<PostAggregator> postAggregatorList =
Lists.newArrayList(
new ConstantPostAggregator(
"roku",
6D
),
new FieldAccessPostAggregator(
aggName,
aggName
)
);
arithmeticPostAggregator = new ArithmeticPostAggregator("add", "+", postAggregatorList);
Comparator comp = arithmeticPostAggregator.getComparator();
metricValues.put(aggName, NullHandling.replaceWithDefault() ? NullHandling.defaultDoubleValue() : null);
Object before = arithmeticPostAggregator.compute(metricValues);
metricValues.put(aggName, 1.0);
Object after = arithmeticPostAggregator.compute(metricValues);
Assert.assertEquals(-1, comp.compare(before, after));
Assert.assertEquals(0, comp.compare(before, before));
Assert.assertEquals(0, comp.compare(after, after));
Assert.assertEquals(1, comp.compare(after, before));
}
@Test @Test
public void testQuotient() public void testQuotient()
{ {