mirror of https://github.com/apache/druid.git
GroupByQuery: Fix type-spanning comparisons. (#4317)
Jackson deserializes integers sometimes as int and sometimes as long, depending on how big they are. This leads to ClassCastException when comparing deserialized values as part of groupBy merging on the broker.
This commit is contained in:
parent
22977780aa
commit
9283807ad7
|
@ -58,6 +58,7 @@ import io.druid.query.spec.QuerySegmentSpec;
|
|||
import io.druid.segment.VirtualColumn;
|
||||
import io.druid.segment.VirtualColumns;
|
||||
import io.druid.segment.column.Column;
|
||||
import io.druid.segment.column.ValueType;
|
||||
import org.joda.time.Interval;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
|
@ -331,10 +332,23 @@ public class GroupByQuery extends BaseQuery<Row>
|
|||
private static int compareDims(List<DimensionSpec> dimensions, Row lhs, Row rhs)
|
||||
{
|
||||
for (DimensionSpec dimension : dimensions) {
|
||||
final int dimCompare = NATURAL_NULLS_FIRST.compare(
|
||||
lhs.getRaw(dimension.getOutputName()),
|
||||
rhs.getRaw(dimension.getOutputName())
|
||||
);
|
||||
final int dimCompare;
|
||||
if (dimension.getOutputType() == ValueType.LONG) {
|
||||
dimCompare = Long.compare(
|
||||
((Number) lhs.getRaw(dimension.getOutputName())).longValue(),
|
||||
((Number) rhs.getRaw(dimension.getOutputName())).longValue()
|
||||
);
|
||||
} else if (dimension.getOutputType() == ValueType.FLOAT) {
|
||||
dimCompare = Double.compare(
|
||||
((Number) lhs.getRaw(dimension.getOutputName())).doubleValue(),
|
||||
((Number) rhs.getRaw(dimension.getOutputName())).doubleValue()
|
||||
);
|
||||
} else {
|
||||
dimCompare = NATURAL_NULLS_FIRST.compare(
|
||||
lhs.getRaw(dimension.getOutputName()),
|
||||
rhs.getRaw(dimension.getOutputName())
|
||||
);
|
||||
}
|
||||
if (dimCompare != 0) {
|
||||
return dimCompare;
|
||||
}
|
||||
|
@ -735,17 +749,17 @@ public class GroupByQuery extends BaseQuery<Row>
|
|||
public String toString()
|
||||
{
|
||||
return "GroupByQuery{" +
|
||||
"dataSource='" + getDataSource() + '\'' +
|
||||
", querySegmentSpec=" + getQuerySegmentSpec() +
|
||||
", virtualColumns=" + virtualColumns +
|
||||
", limitSpec=" + limitSpec +
|
||||
", dimFilter=" + dimFilter +
|
||||
", granularity=" + granularity +
|
||||
", dimensions=" + dimensions +
|
||||
", aggregatorSpecs=" + aggregatorSpecs +
|
||||
", postAggregatorSpecs=" + postAggregatorSpecs +
|
||||
", havingSpec=" + havingSpec +
|
||||
'}';
|
||||
"dataSource='" + getDataSource() + '\'' +
|
||||
", querySegmentSpec=" + getQuerySegmentSpec() +
|
||||
", virtualColumns=" + virtualColumns +
|
||||
", limitSpec=" + limitSpec +
|
||||
", dimFilter=" + dimFilter +
|
||||
", granularity=" + granularity +
|
||||
", dimensions=" + dimensions +
|
||||
", aggregatorSpecs=" + aggregatorSpecs +
|
||||
", postAggregatorSpecs=" + postAggregatorSpecs +
|
||||
", havingSpec=" + havingSpec +
|
||||
'}';
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -762,13 +776,13 @@ public class GroupByQuery extends BaseQuery<Row>
|
|||
}
|
||||
final GroupByQuery that = (GroupByQuery) o;
|
||||
return Objects.equals(virtualColumns, that.virtualColumns) &&
|
||||
Objects.equals(limitSpec, that.limitSpec) &&
|
||||
Objects.equals(havingSpec, that.havingSpec) &&
|
||||
Objects.equals(dimFilter, that.dimFilter) &&
|
||||
Objects.equals(granularity, that.granularity) &&
|
||||
Objects.equals(dimensions, that.dimensions) &&
|
||||
Objects.equals(aggregatorSpecs, that.aggregatorSpecs) &&
|
||||
Objects.equals(postAggregatorSpecs, that.postAggregatorSpecs);
|
||||
Objects.equals(limitSpec, that.limitSpec) &&
|
||||
Objects.equals(havingSpec, that.havingSpec) &&
|
||||
Objects.equals(dimFilter, that.dimFilter) &&
|
||||
Objects.equals(granularity, that.granularity) &&
|
||||
Objects.equals(dimensions, that.dimensions) &&
|
||||
Objects.equals(aggregatorSpecs, that.aggregatorSpecs) &&
|
||||
Objects.equals(postAggregatorSpecs, that.postAggregatorSpecs);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -21,8 +21,13 @@ package io.druid.query.groupby;
|
|||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Ordering;
|
||||
import io.druid.data.input.MapBasedRow;
|
||||
import io.druid.data.input.Row;
|
||||
import io.druid.jackson.DefaultObjectMapper;
|
||||
import io.druid.java.util.common.granularity.Granularities;
|
||||
import io.druid.query.Query;
|
||||
import io.druid.query.QueryRunnerTestHelper;
|
||||
import io.druid.query.aggregation.AggregatorFactory;
|
||||
|
@ -34,6 +39,7 @@ import io.druid.query.dimension.DimensionSpec;
|
|||
import io.druid.query.groupby.orderby.DefaultLimitSpec;
|
||||
import io.druid.query.groupby.orderby.OrderByColumnSpec;
|
||||
import io.druid.query.ordering.StringComparators;
|
||||
import io.druid.segment.column.ValueType;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
|
@ -62,7 +68,11 @@ public class GroupByQueryTest
|
|||
.setPostAggregatorSpecs(ImmutableList.<PostAggregator>of(new FieldAccessPostAggregator("x", "idx")))
|
||||
.setLimitSpec(
|
||||
new DefaultLimitSpec(
|
||||
ImmutableList.of(new OrderByColumnSpec("alias", OrderByColumnSpec.Direction.ASCENDING, StringComparators.LEXICOGRAPHIC)),
|
||||
ImmutableList.of(new OrderByColumnSpec(
|
||||
"alias",
|
||||
OrderByColumnSpec.Direction.ASCENDING,
|
||||
StringComparators.LEXICOGRAPHIC
|
||||
)),
|
||||
100
|
||||
)
|
||||
)
|
||||
|
@ -74,4 +84,23 @@ public class GroupByQueryTest
|
|||
Assert.assertEquals(query, serdeQuery);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRowOrderingMixTypes()
|
||||
{
|
||||
final GroupByQuery query = GroupByQuery.builder()
|
||||
.setDataSource("dummy")
|
||||
.setGranularity(Granularities.ALL)
|
||||
.setInterval("2000/2001")
|
||||
.addDimension(new DefaultDimensionSpec("foo", "foo", ValueType.LONG))
|
||||
.addDimension(new DefaultDimensionSpec("bar", "bar", ValueType.FLOAT))
|
||||
.addDimension(new DefaultDimensionSpec("baz", "baz", ValueType.STRING))
|
||||
.build();
|
||||
|
||||
final Ordering<Row> rowOrdering = query.getRowOrdering(false);
|
||||
final int compare = rowOrdering.compare(
|
||||
new MapBasedRow(0L, ImmutableMap.of("foo", 1, "bar", 1f, "baz", "a")),
|
||||
new MapBasedRow(0L, ImmutableMap.of("foo", 1L, "bar", 1d, "baz", "b"))
|
||||
);
|
||||
Assert.assertEquals(-1, compare);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue