GroupByQuery: Fix type-spanning comparisons. (#4317)

Jackson deserializes integers sometimes as int and sometimes as long,
depending on how big they are. This leads to ClassCastException
when comparing deserialized values as part of groupBy merging on the
broker.
This commit is contained in:
Gian Merlino 2017-05-24 02:06:04 +09:00 committed by Fangjin Yang
parent 22977780aa
commit 9283807ad7
2 changed files with 67 additions and 24 deletions

View File

@ -58,6 +58,7 @@ import io.druid.query.spec.QuerySegmentSpec;
import io.druid.segment.VirtualColumn; import io.druid.segment.VirtualColumn;
import io.druid.segment.VirtualColumns; import io.druid.segment.VirtualColumns;
import io.druid.segment.column.Column; import io.druid.segment.column.Column;
import io.druid.segment.column.ValueType;
import org.joda.time.Interval; import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -331,10 +332,23 @@ public class GroupByQuery extends BaseQuery<Row>
private static int compareDims(List<DimensionSpec> dimensions, Row lhs, Row rhs) private static int compareDims(List<DimensionSpec> dimensions, Row lhs, Row rhs)
{ {
for (DimensionSpec dimension : dimensions) { for (DimensionSpec dimension : dimensions) {
final int dimCompare = NATURAL_NULLS_FIRST.compare( final int dimCompare;
if (dimension.getOutputType() == ValueType.LONG) {
dimCompare = Long.compare(
((Number) lhs.getRaw(dimension.getOutputName())).longValue(),
((Number) rhs.getRaw(dimension.getOutputName())).longValue()
);
} else if (dimension.getOutputType() == ValueType.FLOAT) {
dimCompare = Double.compare(
((Number) lhs.getRaw(dimension.getOutputName())).doubleValue(),
((Number) rhs.getRaw(dimension.getOutputName())).doubleValue()
);
} else {
dimCompare = NATURAL_NULLS_FIRST.compare(
lhs.getRaw(dimension.getOutputName()), lhs.getRaw(dimension.getOutputName()),
rhs.getRaw(dimension.getOutputName()) rhs.getRaw(dimension.getOutputName())
); );
}
if (dimCompare != 0) { if (dimCompare != 0) {
return dimCompare; return dimCompare;
} }

View File

@ -21,8 +21,13 @@ package io.druid.query.groupby;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import io.druid.data.input.MapBasedRow;
import io.druid.data.input.Row;
import io.druid.jackson.DefaultObjectMapper; import io.druid.jackson.DefaultObjectMapper;
import io.druid.java.util.common.granularity.Granularities;
import io.druid.query.Query; import io.druid.query.Query;
import io.druid.query.QueryRunnerTestHelper; import io.druid.query.QueryRunnerTestHelper;
import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.AggregatorFactory;
@ -34,6 +39,7 @@ import io.druid.query.dimension.DimensionSpec;
import io.druid.query.groupby.orderby.DefaultLimitSpec; import io.druid.query.groupby.orderby.DefaultLimitSpec;
import io.druid.query.groupby.orderby.OrderByColumnSpec; import io.druid.query.groupby.orderby.OrderByColumnSpec;
import io.druid.query.ordering.StringComparators; import io.druid.query.ordering.StringComparators;
import io.druid.segment.column.ValueType;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
@ -62,7 +68,11 @@ public class GroupByQueryTest
.setPostAggregatorSpecs(ImmutableList.<PostAggregator>of(new FieldAccessPostAggregator("x", "idx"))) .setPostAggregatorSpecs(ImmutableList.<PostAggregator>of(new FieldAccessPostAggregator("x", "idx")))
.setLimitSpec( .setLimitSpec(
new DefaultLimitSpec( new DefaultLimitSpec(
ImmutableList.of(new OrderByColumnSpec("alias", OrderByColumnSpec.Direction.ASCENDING, StringComparators.LEXICOGRAPHIC)), ImmutableList.of(new OrderByColumnSpec(
"alias",
OrderByColumnSpec.Direction.ASCENDING,
StringComparators.LEXICOGRAPHIC
)),
100 100
) )
) )
@ -74,4 +84,23 @@ public class GroupByQueryTest
Assert.assertEquals(query, serdeQuery); Assert.assertEquals(query, serdeQuery);
} }
@Test
public void testRowOrderingMixTypes()
{
final GroupByQuery query = GroupByQuery.builder()
.setDataSource("dummy")
.setGranularity(Granularities.ALL)
.setInterval("2000/2001")
.addDimension(new DefaultDimensionSpec("foo", "foo", ValueType.LONG))
.addDimension(new DefaultDimensionSpec("bar", "bar", ValueType.FLOAT))
.addDimension(new DefaultDimensionSpec("baz", "baz", ValueType.STRING))
.build();
final Ordering<Row> rowOrdering = query.getRowOrdering(false);
final int compare = rowOrdering.compare(
new MapBasedRow(0L, ImmutableMap.of("foo", 1, "bar", 1f, "baz", "a")),
new MapBasedRow(0L, ImmutableMap.of("foo", 1L, "bar", 1d, "baz", "b"))
);
Assert.assertEquals(-1, compare);
}
} }