GroupByQuery: Fix type-spanning comparisons. (#4317)

Jackson deserializes integers sometimes as int and sometimes as long,
depending on how big they are. This leads to ClassCastException
when comparing deserialized values as part of groupBy merging on the
broker.
This commit is contained in:
Gian Merlino 2017-05-24 02:06:04 +09:00 committed by Fangjin Yang
parent 22977780aa
commit 9283807ad7
2 changed files with 67 additions and 24 deletions

View File

@ -58,6 +58,7 @@ import io.druid.query.spec.QuerySegmentSpec;
import io.druid.segment.VirtualColumn;
import io.druid.segment.VirtualColumns;
import io.druid.segment.column.Column;
import io.druid.segment.column.ValueType;
import org.joda.time.Interval;
import javax.annotation.Nullable;
@ -331,10 +332,23 @@ public class GroupByQuery extends BaseQuery<Row>
private static int compareDims(List<DimensionSpec> dimensions, Row lhs, Row rhs)
{
for (DimensionSpec dimension : dimensions) {
final int dimCompare = NATURAL_NULLS_FIRST.compare(
lhs.getRaw(dimension.getOutputName()),
rhs.getRaw(dimension.getOutputName())
);
final int dimCompare;
if (dimension.getOutputType() == ValueType.LONG) {
dimCompare = Long.compare(
((Number) lhs.getRaw(dimension.getOutputName())).longValue(),
((Number) rhs.getRaw(dimension.getOutputName())).longValue()
);
} else if (dimension.getOutputType() == ValueType.FLOAT) {
dimCompare = Double.compare(
((Number) lhs.getRaw(dimension.getOutputName())).doubleValue(),
((Number) rhs.getRaw(dimension.getOutputName())).doubleValue()
);
} else {
dimCompare = NATURAL_NULLS_FIRST.compare(
lhs.getRaw(dimension.getOutputName()),
rhs.getRaw(dimension.getOutputName())
);
}
if (dimCompare != 0) {
return dimCompare;
}
@ -735,17 +749,17 @@ public class GroupByQuery extends BaseQuery<Row>
public String toString()
{
return "GroupByQuery{" +
"dataSource='" + getDataSource() + '\'' +
", querySegmentSpec=" + getQuerySegmentSpec() +
", virtualColumns=" + virtualColumns +
", limitSpec=" + limitSpec +
", dimFilter=" + dimFilter +
", granularity=" + granularity +
", dimensions=" + dimensions +
", aggregatorSpecs=" + aggregatorSpecs +
", postAggregatorSpecs=" + postAggregatorSpecs +
", havingSpec=" + havingSpec +
'}';
"dataSource='" + getDataSource() + '\'' +
", querySegmentSpec=" + getQuerySegmentSpec() +
", virtualColumns=" + virtualColumns +
", limitSpec=" + limitSpec +
", dimFilter=" + dimFilter +
", granularity=" + granularity +
", dimensions=" + dimensions +
", aggregatorSpecs=" + aggregatorSpecs +
", postAggregatorSpecs=" + postAggregatorSpecs +
", havingSpec=" + havingSpec +
'}';
}
@Override
@ -762,13 +776,13 @@ public class GroupByQuery extends BaseQuery<Row>
}
final GroupByQuery that = (GroupByQuery) o;
return Objects.equals(virtualColumns, that.virtualColumns) &&
Objects.equals(limitSpec, that.limitSpec) &&
Objects.equals(havingSpec, that.havingSpec) &&
Objects.equals(dimFilter, that.dimFilter) &&
Objects.equals(granularity, that.granularity) &&
Objects.equals(dimensions, that.dimensions) &&
Objects.equals(aggregatorSpecs, that.aggregatorSpecs) &&
Objects.equals(postAggregatorSpecs, that.postAggregatorSpecs);
Objects.equals(limitSpec, that.limitSpec) &&
Objects.equals(havingSpec, that.havingSpec) &&
Objects.equals(dimFilter, that.dimFilter) &&
Objects.equals(granularity, that.granularity) &&
Objects.equals(dimensions, that.dimensions) &&
Objects.equals(aggregatorSpecs, that.aggregatorSpecs) &&
Objects.equals(postAggregatorSpecs, that.postAggregatorSpecs);
}
@Override

View File

@ -21,8 +21,13 @@ package io.druid.query.groupby;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import io.druid.data.input.MapBasedRow;
import io.druid.data.input.Row;
import io.druid.jackson.DefaultObjectMapper;
import io.druid.java.util.common.granularity.Granularities;
import io.druid.query.Query;
import io.druid.query.QueryRunnerTestHelper;
import io.druid.query.aggregation.AggregatorFactory;
@ -34,6 +39,7 @@ import io.druid.query.dimension.DimensionSpec;
import io.druid.query.groupby.orderby.DefaultLimitSpec;
import io.druid.query.groupby.orderby.OrderByColumnSpec;
import io.druid.query.ordering.StringComparators;
import io.druid.segment.column.ValueType;
import org.junit.Assert;
import org.junit.Test;
@ -62,7 +68,11 @@ public class GroupByQueryTest
.setPostAggregatorSpecs(ImmutableList.<PostAggregator>of(new FieldAccessPostAggregator("x", "idx")))
.setLimitSpec(
new DefaultLimitSpec(
ImmutableList.of(new OrderByColumnSpec("alias", OrderByColumnSpec.Direction.ASCENDING, StringComparators.LEXICOGRAPHIC)),
ImmutableList.of(new OrderByColumnSpec(
"alias",
OrderByColumnSpec.Direction.ASCENDING,
StringComparators.LEXICOGRAPHIC
)),
100
)
)
@ -74,4 +84,23 @@ public class GroupByQueryTest
Assert.assertEquals(query, serdeQuery);
}
@Test
public void testRowOrderingMixTypes()
{
final GroupByQuery query = GroupByQuery.builder()
.setDataSource("dummy")
.setGranularity(Granularities.ALL)
.setInterval("2000/2001")
.addDimension(new DefaultDimensionSpec("foo", "foo", ValueType.LONG))
.addDimension(new DefaultDimensionSpec("bar", "bar", ValueType.FLOAT))
.addDimension(new DefaultDimensionSpec("baz", "baz", ValueType.STRING))
.build();
final Ordering<Row> rowOrdering = query.getRowOrdering(false);
final int compare = rowOrdering.compare(
new MapBasedRow(0L, ImmutableMap.of("foo", 1, "bar", 1f, "baz", "a")),
new MapBasedRow(0L, ImmutableMap.of("foo", 1L, "bar", 1d, "baz", "b"))
);
Assert.assertEquals(-1, compare);
}
}