mirror of https://github.com/apache/druid.git
Use typecasting comparator for numeric "any" aggregations. (#16494)
This brings them in line with the behavior of other numeric aggregations. It is important because otherwise ClassCastExceptions can arise if comparing different numeric types that may arise from deserialization.
This commit is contained in:
parent
44ea4e1c51
commit
eb410f712d
|
@ -28,7 +28,7 @@ import java.util.Comparator;
|
|||
*/
|
||||
public class FloatSumAggregator implements Aggregator
|
||||
{
|
||||
static final Comparator COMPARATOR = new Ordering()
|
||||
public static final Comparator COMPARATOR = new Ordering()
|
||||
{
|
||||
@Override
|
||||
public int compare(Object o, Object o1)
|
||||
|
|
|
@ -29,7 +29,7 @@ import java.util.Comparator;
|
|||
*/
|
||||
public class LongSumAggregator implements Aggregator
|
||||
{
|
||||
static final Comparator COMPARATOR = new Ordering()
|
||||
public static final Comparator COMPARATOR = new Ordering()
|
||||
{
|
||||
@Override
|
||||
public int compare(Object o, Object o1)
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.druid.query.aggregation.Aggregator;
|
|||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.AggregatorUtil;
|
||||
import org.apache.druid.query.aggregation.BufferAggregator;
|
||||
import org.apache.druid.query.aggregation.DoubleSumAggregator;
|
||||
import org.apache.druid.query.aggregation.VectorAggregator;
|
||||
import org.apache.druid.query.cache.CacheKeyBuilder;
|
||||
import org.apache.druid.segment.BaseDoubleColumnValueSelector;
|
||||
|
@ -48,8 +49,6 @@ import java.util.Objects;
|
|||
|
||||
public class DoubleAnyAggregatorFactory extends AggregatorFactory
|
||||
{
|
||||
private static final Comparator<Double> VALUE_COMPARATOR = Comparator.nullsFirst(Double::compare);
|
||||
|
||||
private static final Aggregator NIL_AGGREGATOR = new DoubleAnyAggregator(
|
||||
NilColumnValueSelector.instance()
|
||||
)
|
||||
|
@ -136,7 +135,7 @@ public class DoubleAnyAggregatorFactory extends AggregatorFactory
|
|||
@Override
|
||||
public Comparator getComparator()
|
||||
{
|
||||
return VALUE_COMPARATOR;
|
||||
return DoubleSumAggregator.COMPARATOR;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.druid.query.aggregation.Aggregator;
|
|||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.AggregatorUtil;
|
||||
import org.apache.druid.query.aggregation.BufferAggregator;
|
||||
import org.apache.druid.query.aggregation.FloatSumAggregator;
|
||||
import org.apache.druid.query.aggregation.VectorAggregator;
|
||||
import org.apache.druid.query.cache.CacheKeyBuilder;
|
||||
import org.apache.druid.segment.BaseFloatColumnValueSelector;
|
||||
|
@ -47,8 +48,6 @@ import java.util.Objects;
|
|||
|
||||
public class FloatAnyAggregatorFactory extends AggregatorFactory
|
||||
{
|
||||
private static final Comparator<Float> VALUE_COMPARATOR = Comparator.nullsFirst(Float::compare);
|
||||
|
||||
private static final Aggregator NIL_AGGREGATOR = new FloatAnyAggregator(
|
||||
NilColumnValueSelector.instance()
|
||||
)
|
||||
|
@ -133,7 +132,7 @@ public class FloatAnyAggregatorFactory extends AggregatorFactory
|
|||
@Override
|
||||
public Comparator getComparator()
|
||||
{
|
||||
return VALUE_COMPARATOR;
|
||||
return FloatSumAggregator.COMPARATOR;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.druid.query.aggregation.Aggregator;
|
|||
import org.apache.druid.query.aggregation.AggregatorFactory;
|
||||
import org.apache.druid.query.aggregation.AggregatorUtil;
|
||||
import org.apache.druid.query.aggregation.BufferAggregator;
|
||||
import org.apache.druid.query.aggregation.LongSumAggregator;
|
||||
import org.apache.druid.query.aggregation.VectorAggregator;
|
||||
import org.apache.druid.query.cache.CacheKeyBuilder;
|
||||
import org.apache.druid.segment.BaseLongColumnValueSelector;
|
||||
|
@ -46,8 +47,6 @@ import java.util.List;
|
|||
|
||||
public class LongAnyAggregatorFactory extends AggregatorFactory
|
||||
{
|
||||
private static final Comparator<Long> VALUE_COMPARATOR = Comparator.nullsFirst(Long::compare);
|
||||
|
||||
private static final Aggregator NIL_AGGREGATOR = new LongAnyAggregator(
|
||||
NilColumnValueSelector.instance()
|
||||
)
|
||||
|
@ -132,7 +131,7 @@ public class LongAnyAggregatorFactory extends AggregatorFactory
|
|||
@Override
|
||||
public Comparator getComparator()
|
||||
{
|
||||
return VALUE_COMPARATOR;
|
||||
return LongSumAggregator.COMPARATOR;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -117,6 +117,17 @@ public class DoubleAnyAggregationTest extends InitializedNullHandlingTest
|
|||
Assert.assertEquals(-1, comparator.compare(d2, d1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComparatorWithTypeMismatch()
|
||||
{
|
||||
Long n1 = 3L;
|
||||
Double n2 = 4.0;
|
||||
Comparator comparator = doubleAnyAggFactory.getComparator();
|
||||
Assert.assertEquals(0, comparator.compare(n1, n1));
|
||||
Assert.assertEquals(-1, comparator.compare(n1, n2));
|
||||
Assert.assertEquals(1, comparator.compare(n2, n1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleAnyCombiningAggregator()
|
||||
{
|
||||
|
|
|
@ -117,6 +117,17 @@ public class FloatAnyAggregationTest extends InitializedNullHandlingTest
|
|||
Assert.assertEquals(-1, comparator.compare(f2, f1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComparatorWithTypeMismatch()
|
||||
{
|
||||
Long n1 = 3L;
|
||||
Float n2 = 4.0f;
|
||||
Comparator comparator = floatAnyAggFactory.getComparator();
|
||||
Assert.assertEquals(0, comparator.compare(n1, n1));
|
||||
Assert.assertEquals(-1, comparator.compare(n1, n2));
|
||||
Assert.assertEquals(1, comparator.compare(n2, n1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFloatAnyCombiningAggregator()
|
||||
{
|
||||
|
|
|
@ -118,6 +118,17 @@ public class LongAnyAggregationTest extends InitializedNullHandlingTest
|
|||
Assert.assertEquals(-1, comparator.compare(l2, l1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComparatorWithTypeMismatch()
|
||||
{
|
||||
Integer n1 = 3;
|
||||
Long n2 = 4L;
|
||||
Comparator comparator = longAnyAggFactory.getComparator();
|
||||
Assert.assertEquals(0, comparator.compare(n1, n1));
|
||||
Assert.assertEquals(-1, comparator.compare(n1, n2));
|
||||
Assert.assertEquals(1, comparator.compare(n2, n1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLongAnyCombiningAggregator()
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue