fix null handling for arithmetic post aggregator comparator (#9159)

* fix null handling for arithmetic postagg comparator, add test for comparator for min/max/quantile postaggs in histogram ext

* fix
This commit is contained in:
Clint Wylie 2020-01-10 13:49:19 -08:00 committed by Jonathan Wei
parent 8c53818fa9
commit 85219ece13
11 changed files with 224 additions and 53 deletions

View File

@ -35,14 +35,8 @@ import java.util.Set;
@JsonTypeName("max")
public class MaxPostAggregator extends ApproximateHistogramPostAggregator
{
static final Comparator COMPARATOR = new Comparator()
{
@Override
public int compare(Object o, Object o1)
{
return Double.compare(((Number) o).doubleValue(), ((Number) o1).doubleValue());
}
};
// this doesn't need to handle nulls because the values come from ApproximateHistogram
static final Comparator COMPARATOR = Comparator.comparingDouble(o -> ((Number) o).doubleValue());
@JsonCreator
public MaxPostAggregator(

View File

@ -36,14 +36,8 @@ import java.util.Set;
@JsonTypeName("min")
public class MinPostAggregator extends ApproximateHistogramPostAggregator
{
static final Comparator COMPARATOR = new Comparator()
{
@Override
public int compare(Object o, Object o1)
{
return Double.compare(((Number) o).doubleValue(), ((Number) o1).doubleValue());
}
};
// this doesn't need to handle nulls because the values come from ApproximateHistogram
static final Comparator COMPARATOR = Comparator.comparingDouble(o -> ((Number) o).doubleValue());
@JsonCreator
public MinPostAggregator(

View File

@ -36,14 +36,8 @@ import java.util.Set;
@JsonTypeName("quantile")
public class QuantilePostAggregator extends ApproximateHistogramPostAggregator
{
static final Comparator COMPARATOR = new Comparator()
{
@Override
public int compare(Object o, Object o1)
{
return Double.compare(((Number) o).doubleValue(), ((Number) o1).doubleValue());
}
};
// this doesn't need to handle nulls because the values come from ApproximateHistogram
static final Comparator COMPARATOR = Comparator.comparingDouble(o -> ((Number) o).doubleValue());
private final float probability;
private final String fieldName;

View File

@ -41,7 +41,7 @@ public class ApproximateHistogramPostAggregatorTest extends InitializedNullHandl
}
@Test
public void testCompute()
public void testApproxHistogramCompute()
{
ApproximateHistogram ah = buildHistogram(10, VALUES);
final TestFloatColumnSelector selector = new TestFloatColumnSelector(VALUES);
@ -63,5 +63,4 @@ public class ApproximateHistogramPostAggregatorTest extends InitializedNullHandl
);
Assert.assertEquals(ah.toHistogram(5), approximateHistogramPostAggregator.compute(metricValues));
}
}

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.histogram;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
public class MaxPostAggregatorTest extends InitializedNullHandlingTest
{
@Test
public void testComparator()
{
final String aggName = "doubleWithNulls";
Map<String, Object> metricValues = new HashMap<>();
MaxPostAggregator max = new MaxPostAggregator("max", aggName);
Comparator comp = max.getComparator();
ApproximateHistogram histo1 = new ApproximateHistogram();
metricValues.put(aggName, histo1);
Object before = max.compute(metricValues);
ApproximateHistogram histo2 = new ApproximateHistogram();
histo2.offer(1.0f);
metricValues.put(aggName, histo2);
Object after = max.compute(metricValues);
Assert.assertEquals(-1, comp.compare(before, after));
Assert.assertEquals(0, comp.compare(before, before));
Assert.assertEquals(0, comp.compare(after, after));
Assert.assertEquals(1, comp.compare(after, before));
}
}

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.histogram;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
public class MinPostAggregatorTest extends InitializedNullHandlingTest
{
@Test
public void testComparator()
{
final String aggName = "doubleWithNulls";
Map<String, Object> metricValues = new HashMap<>();
MinPostAggregator min = new MinPostAggregator("min", aggName);
Comparator comp = min.getComparator();
ApproximateHistogram histo1 = new ApproximateHistogram();
metricValues.put(aggName, histo1);
Object before = min.compute(metricValues);
ApproximateHistogram histo2 = new ApproximateHistogram();
histo2.offer(1.0f);
metricValues.put(aggName, histo2);
Object after = min.compute(metricValues);
Assert.assertEquals(1, comp.compare(before, after));
Assert.assertEquals(0, comp.compare(before, before));
Assert.assertEquals(0, comp.compare(after, after));
Assert.assertEquals(-1, comp.compare(after, before));
}
}

View File

@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query.aggregation.histogram;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
public class QuantilePostAggregatorTest extends InitializedNullHandlingTest
{
@Test
public void testComparator()
{
final String aggName = "doubleWithNulls";
Map<String, Object> metricValues = new HashMap<>();
QuantilePostAggregator quantile = new QuantilePostAggregator("quantile", aggName, 0.9f);
Comparator comp = quantile.getComparator();
ApproximateHistogram histo1 = new ApproximateHistogram();
histo1.offer(10.0f);
metricValues.put(aggName, histo1);
Object before = quantile.compute(metricValues);
ApproximateHistogram histo2 = new ApproximateHistogram();
histo2.offer(100.0f);
metricValues.put(aggName, histo2);
Object after = quantile.compute(metricValues);
Assert.assertEquals(-1, comp.compare(before, after));
Assert.assertEquals(0, comp.compare(before, before));
Assert.assertEquals(0, comp.compare(after, after));
Assert.assertEquals(1, comp.compare(after, before));
}
}

View File

@ -29,7 +29,7 @@ import java.util.Comparator;
*/
public class DoubleSumAggregator implements Aggregator
{
static final Comparator COMPARATOR = new Ordering()
public static final Comparator COMPARATOR = new Ordering()
{
@Override
public int compare(Object o, Object o1)

View File

@ -26,6 +26,7 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.query.Queries;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregator;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
@ -43,14 +44,7 @@ import java.util.Set;
*/
public class ArithmeticPostAggregator implements PostAggregator
{
public static final Comparator DEFAULT_COMPARATOR = new Comparator()
{
@Override
public int compare(Object o, Object o1)
{
return ((Double) o).compareTo((Double) o1);
}
};
public static final Comparator DEFAULT_COMPARATOR = DoubleSumAggregator.COMPARATOR;
private final String name;
private final String fnName;

View File

@ -26,6 +26,7 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;
import org.apache.druid.js.JavaScriptConfig;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregator;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.checkerframework.checker.nullness.qual.EnsuresNonNull;
@ -42,15 +43,6 @@ import java.util.Set;
public class JavaScriptPostAggregator implements PostAggregator
{
private static final Comparator COMPARATOR = new Comparator()
{
@Override
public int compare(Object o, Object o1)
{
return ((Double) o).compareTo((Double) o1);
}
};
private interface Function
{
double apply(Object[] args);
@ -127,7 +119,7 @@ public class JavaScriptPostAggregator implements PostAggregator
@Override
public Comparator getComparator()
{
return COMPARATOR;
return DoubleSumAggregator.COMPARATOR;
}
@Override

View File

@ -22,9 +22,11 @@ package org.apache.druid.query.aggregation.post;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.CountAggregator;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
@ -33,9 +35,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
*/
public class ArithmeticPostAggregatorTest
public class ArithmeticPostAggregatorTest extends InitializedNullHandlingTest
{
@Test
public void testCompute()
@ -53,10 +53,12 @@ public class ArithmeticPostAggregatorTest
List<PostAggregator> postAggregatorList =
Lists.newArrayList(
new ConstantPostAggregator(
"roku", 6D
"roku",
6D
),
new FieldAccessPostAggregator(
"rows", "rows"
"rows",
"rows"
)
);
@ -91,16 +93,18 @@ public class ArithmeticPostAggregatorTest
final String aggName = "rows";
ArithmeticPostAggregator arithmeticPostAggregator;
CountAggregator agg = new CountAggregator();
Map<String, Object> metricValues = new HashMap<String, Object>();
Map<String, Object> metricValues = new HashMap<>();
metricValues.put(aggName, agg.get());
List<PostAggregator> postAggregatorList =
Lists.newArrayList(
new ConstantPostAggregator(
"roku", 6D
"roku",
6D
),
new FieldAccessPostAggregator(
"rows", "rows"
"rows",
"rows"
)
);
@ -119,6 +123,39 @@ public class ArithmeticPostAggregatorTest
Assert.assertEquals(1, comp.compare(after, before));
}
@Test
public void testComparatorNulls()
{
final String aggName = "doubleWithNulls";
ArithmeticPostAggregator arithmeticPostAggregator;
Map<String, Object> metricValues = new HashMap<>();
List<PostAggregator> postAggregatorList =
Lists.newArrayList(
new ConstantPostAggregator(
"roku",
6D
),
new FieldAccessPostAggregator(
aggName,
aggName
)
);
arithmeticPostAggregator = new ArithmeticPostAggregator("add", "+", postAggregatorList);
Comparator comp = arithmeticPostAggregator.getComparator();
metricValues.put(aggName, NullHandling.replaceWithDefault() ? NullHandling.defaultDoubleValue() : null);
Object before = arithmeticPostAggregator.compute(metricValues);
metricValues.put(aggName, 1.0);
Object after = arithmeticPostAggregator.compute(metricValues);
Assert.assertEquals(-1, comp.compare(before, after));
Assert.assertEquals(0, comp.compare(before, before));
Assert.assertEquals(0, comp.compare(after, after));
Assert.assertEquals(1, comp.compare(after, before));
}
@Test
public void testQuotient()
{