Use typecasting comparator for numeric "any" aggregations. (#16494)

This brings them in line with the behavior of other numeric aggregations.
It is important because otherwise ClassCastExceptions can arise if comparing
different numeric types that may arise from deserialization.
This commit is contained in:
Gian Merlino 2024-05-22 12:38:51 -07:00 committed by GitHub
parent 44ea4e1c51
commit eb410f712d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 41 additions and 11 deletions

View File

@ -28,7 +28,7 @@ import java.util.Comparator;
*/
public class FloatSumAggregator implements Aggregator
{
static final Comparator COMPARATOR = new Ordering()
public static final Comparator COMPARATOR = new Ordering()
{
@Override
public int compare(Object o, Object o1)

View File

@ -29,7 +29,7 @@ import java.util.Comparator;
*/
public class LongSumAggregator implements Aggregator
{
static final Comparator COMPARATOR = new Ordering()
public static final Comparator COMPARATOR = new Ordering()
{
@Override
public int compare(Object o, Object o1)

View File

@ -28,6 +28,7 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.DoubleSumAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.BaseDoubleColumnValueSelector;
@ -48,8 +49,6 @@ import java.util.Objects;
public class DoubleAnyAggregatorFactory extends AggregatorFactory
{
private static final Comparator<Double> VALUE_COMPARATOR = Comparator.nullsFirst(Double::compare);
private static final Aggregator NIL_AGGREGATOR = new DoubleAnyAggregator(
NilColumnValueSelector.instance()
)
@ -136,7 +135,7 @@ public class DoubleAnyAggregatorFactory extends AggregatorFactory
@Override
public Comparator getComparator()
{
return VALUE_COMPARATOR;
return DoubleSumAggregator.COMPARATOR;
}
@Override

View File

@ -28,6 +28,7 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.FloatSumAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
@ -47,8 +48,6 @@ import java.util.Objects;
public class FloatAnyAggregatorFactory extends AggregatorFactory
{
private static final Comparator<Float> VALUE_COMPARATOR = Comparator.nullsFirst(Float::compare);
private static final Aggregator NIL_AGGREGATOR = new FloatAnyAggregator(
NilColumnValueSelector.instance()
)
@ -133,7 +132,7 @@ public class FloatAnyAggregatorFactory extends AggregatorFactory
@Override
public Comparator getComparator()
{
return VALUE_COMPARATOR;
return FloatSumAggregator.COMPARATOR;
}
@Override

View File

@ -28,6 +28,7 @@ import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.AggregatorUtil;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.aggregation.LongSumAggregator;
import org.apache.druid.query.aggregation.VectorAggregator;
import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.segment.BaseLongColumnValueSelector;
@ -46,8 +47,6 @@ import java.util.List;
public class LongAnyAggregatorFactory extends AggregatorFactory
{
private static final Comparator<Long> VALUE_COMPARATOR = Comparator.nullsFirst(Long::compare);
private static final Aggregator NIL_AGGREGATOR = new LongAnyAggregator(
NilColumnValueSelector.instance()
)
@ -132,7 +131,7 @@ public class LongAnyAggregatorFactory extends AggregatorFactory
@Override
public Comparator getComparator()
{
return VALUE_COMPARATOR;
return LongSumAggregator.COMPARATOR;
}
@Override

View File

@ -117,6 +117,17 @@ public class DoubleAnyAggregationTest extends InitializedNullHandlingTest
Assert.assertEquals(-1, comparator.compare(d2, d1));
}
@Test
public void testComparatorWithTypeMismatch()
{
Long n1 = 3L;
Double n2 = 4.0;
Comparator comparator = doubleAnyAggFactory.getComparator();
Assert.assertEquals(0, comparator.compare(n1, n1));
Assert.assertEquals(-1, comparator.compare(n1, n2));
Assert.assertEquals(1, comparator.compare(n2, n1));
}
@Test
public void testDoubleAnyCombiningAggregator()
{

View File

@ -117,6 +117,17 @@ public class FloatAnyAggregationTest extends InitializedNullHandlingTest
Assert.assertEquals(-1, comparator.compare(f2, f1));
}
@Test
public void testComparatorWithTypeMismatch()
{
Long n1 = 3L;
Float n2 = 4.0f;
Comparator comparator = floatAnyAggFactory.getComparator();
Assert.assertEquals(0, comparator.compare(n1, n1));
Assert.assertEquals(-1, comparator.compare(n1, n2));
Assert.assertEquals(1, comparator.compare(n2, n1));
}
@Test
public void testFloatAnyCombiningAggregator()
{

View File

@ -118,6 +118,17 @@ public class LongAnyAggregationTest extends InitializedNullHandlingTest
Assert.assertEquals(-1, comparator.compare(l2, l1));
}
@Test
public void testComparatorWithTypeMismatch()
{
Integer n1 = 3;
Long n2 = 4L;
Comparator comparator = longAnyAggFactory.getComparator();
Assert.assertEquals(0, comparator.compare(n1, n1));
Assert.assertEquals(-1, comparator.compare(n1, n2));
Assert.assertEquals(1, comparator.compare(n2, n1));
}
@Test
public void testLongAnyCombiningAggregator()
{