Fix GroupBy limit push down descending sorting on numeric columns (#5453)

Jonathan Wei 2018-03-01 18:43:45 -08:00 committed by Fangjin Yang
parent e0d456b1ba
commit cf5f74b013
2 changed files with 205 additions and 20 deletions
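
In short: when a groupBy limit is pushed down, the push-down row comparator in GroupByQuery handled descending columns by swapping the left- and right-hand values before comparing, but the numeric branch then re-read the raw values directly and bypassed the swap, so a DESCENDING sort with the NUMERIC comparator came back in ascending order. The fix compares every column in natural order and negates the result once when the column is marked descending, and a new multi-node merge test exercises the case end to end.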

GroupByQuery.java

@@ -574,22 +574,14 @@ public class GroupByQuery extends BaseQuery&lt;Row&gt;
       final StringComparator comparator = comparators.get(i);
       final int dimCompare;

-      Object lhsObj;
-      Object rhsObj;
-
-      if (needsReverseList.get(i)) {
-        lhsObj = rhs.getRaw(fieldName);
-        rhsObj = lhs.getRaw(fieldName);
-      } else {
-        lhsObj = lhs.getRaw(fieldName);
-        rhsObj = rhs.getRaw(fieldName);
-      }
+      final Object lhsObj = lhs.getRaw(fieldName);
+      final Object rhsObj = rhs.getRaw(fieldName);

       if (isNumericField.get(i)) {
         if (comparator.equals(StringComparators.NUMERIC)) {
           dimCompare = ((Ordering) Comparators.naturalNullsFirst()).compare(
-              lhs.getRaw(fieldName),
-              rhs.getRaw(fieldName)
+              lhsObj,
+              rhsObj
           );
         } else {
           dimCompare = comparator.compare(String.valueOf(lhsObj), String.valueOf(rhsObj));
@@ -599,7 +591,7 @@ public class GroupByQuery extends BaseQuery&lt;Row&gt;
       }

       if (dimCompare != 0) {
-        return dimCompare;
+        return needsReverseList.get(i) ? -dimCompare : dimCompare;
       }
     }
     return 0;
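
To see the shape of the fix outside of Druid, here is a minimal self-contained sketch of the same pattern; the class and variable names are illustrative, not from the patch. Negating the comparison result is equivalent to swapping the operands, but it is applied in exactly one place, after all branches have run, so no branch can accidentally compare the un-swapped raw values the way the old numeric path did:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ReverseByNegation
{
  // Compare in natural (ascending) order, then flip the sign once for
  // descending columns -- the same idea as the patched comparator above.
  static int compareColumn(long lhs, long rhs, boolean descending)
  {
    final int cmp = Long.compare(lhs, rhs);
    return descending ? -cmp : cmp;
  }

  public static void main(String[] args)
  {
    List<Long> years = new ArrayList<>(Arrays.asList(2017L, 2027L, 2020L, 2024L));
    years.sort((a, b) -> compareColumn(a, b, true)); // descending
    System.out.println(years);                       // [2027, 2024, 2020, 2017]
  }
}
```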

GroupByLimitPushDownMultiNodeMergeTest.java

@@ -49,6 +49,9 @@ import io.druid.java.util.common.guava.Sequence;
 import io.druid.java.util.common.guava.Sequences;
 import io.druid.java.util.common.logger.Logger;
 import io.druid.math.expr.ExprMacroTable;
+import io.druid.query.aggregation.CountAggregatorFactory;
+import io.druid.query.expression.TestExprMacroTable;
+import io.druid.segment.virtual.ExpressionVirtualColumn;
 import io.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
 import io.druid.query.BySegmentQueryRunner;
 import io.druid.query.DruidProcessingConfig;
@@ -240,7 +243,78 @@ public class GroupByLimitPushDownMultiNodeMergeTest
     );
     QueryableIndex qindexB = INDEX_IO.loadIndex(fileB);

-    groupByIndices = Arrays.asList(qindexA, qindexB);
+    final IncrementalIndex indexC = makeIncIndex(false);
+    incrementalIndices.add(indexC);
+
+    event = new HashMap<>();
+    event.put("dimA", "pomegranate");
+    event.put("metA", 2395L);
+    row = new MapBasedInputRow(1505260800000L, dimNames, event);
+    indexC.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "mango");
+    event.put("metA", 8L);
+    row = new MapBasedInputRow(1605260800000L, dimNames, event);
+    indexC.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "pomegranate");
+    event.put("metA", 5028L);
+    row = new MapBasedInputRow(1705264400000L, dimNames, event);
+    indexC.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "mango");
+    event.put("metA", 7L);
+    row = new MapBasedInputRow(1805264400000L, dimNames, event);
+    indexC.add(row);
+
+    final File fileC = INDEX_MERGER_V9.persist(
+        indexC,
+        new File(tmpDir, "C"),
+        new IndexSpec(),
+        null
+    );
+    QueryableIndex qindexC = INDEX_IO.loadIndex(fileC);
+
+    final IncrementalIndex indexD = makeIncIndex(false);
+    incrementalIndices.add(indexD);
+
+    event = new HashMap<>();
+    event.put("dimA", "pomegranate");
+    event.put("metA", 4718L);
+    row = new MapBasedInputRow(1505260800000L, dimNames, event);
+    indexD.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "mango");
+    event.put("metA", 18L);
+    row = new MapBasedInputRow(1605260800000L, dimNames, event);
+    indexD.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "pomegranate");
+    event.put("metA", 2698L);
+    row = new MapBasedInputRow(1705264400000L, dimNames, event);
+    indexD.add(row);
+
+    event = new HashMap<>();
+    event.put("dimA", "mango");
+    event.put("metA", 3L);
+    row = new MapBasedInputRow(1805264400000L, dimNames, event);
+    indexD.add(row);
+
+    final File fileD = INDEX_MERGER_V9.persist(
+        indexD,
+        new File(tmpDir, "D"),
+        new IndexSpec(),
+        null
+    );
+    QueryableIndex qindexD = INDEX_IO.loadIndex(fileD);
+
+    groupByIndices = Arrays.asList(qindexA, qindexB, qindexC, qindexD);

     setupGroupByFactory();
   }
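
A quick way to sanity-check the fixture data: interpreted as UTC epoch milliseconds, the four row timestamps shared by segments C and D fall in four different years, which is what gives the descending-numerics test below distinct year/month/day groupings. A small standalone check (plain java.time, not Druid code):

```java
import java.time.Instant;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;

public class FixtureTimestamps
{
  public static void main(String[] args)
  {
    // The four timestamps used for both segment C and segment D above.
    long[] millis = {1505260800000L, 1605260800000L, 1705264400000L, 1805264400000L};
    for (long ms : millis) {
      ZonedDateTime t = Instant.ofEpochMilli(ms).atZone(ZoneOffset.UTC);
      // Prints 2017-9-13, 2020-11-13, 2024-1-14 and 2027-3-17: the
      // (d0, d1, d2) triples asserted in testDescendingNumerics.
      System.out.printf("%d-%d-%d%n", t.getYear(), t.getMonthValue(), t.getDayOfMonth());
    }
  }
}
```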
@@ -376,6 +450,125 @@ public class GroupByLimitPushDownMultiNodeMergeTest
     }
   }

+  @Test
+  public void testDescendingNumerics() throws Exception
+  {
+    QueryToolChest<Row, GroupByQuery> toolChest = groupByFactory.getToolchest();
+    QueryRunner<Row> theRunner = new FinalizeResultsQueryRunner<>(
+        toolChest.mergeResults(
+            groupByFactory.mergeRunners(executorService, getRunner1(2))
+        ),
+        (QueryToolChest) toolChest
+    );
+
+    QueryRunner<Row> theRunner2 = new FinalizeResultsQueryRunner<>(
+        toolChest.mergeResults(
+            groupByFactory2.mergeRunners(executorService, getRunner2(3))
+        ),
+        (QueryToolChest) toolChest
+    );
+
+    QueryRunner<Row> finalRunner = new FinalizeResultsQueryRunner<>(
+        toolChest.mergeResults(
+            new QueryRunner<Row>()
+            {
+              @Override
+              public Sequence<Row> run(QueryPlus<Row> queryPlus, Map<String, Object> responseContext)
+              {
+                return Sequences
+                    .simple(
+                        ImmutableList.of(
+                            theRunner.run(queryPlus, responseContext),
+                            theRunner2.run(queryPlus, responseContext)
+                        )
+                    )
+                    .flatMerge(Function.identity(), queryPlus.getQuery().getResultOrdering());
+              }
+            }
+        ),
+        (QueryToolChest) toolChest
+    );
+
+    QuerySegmentSpec intervalSpec = new MultipleIntervalSegmentSpec(
+        Collections.singletonList(Intervals.utc(1500000000000L, 1900000000000L))
+    );
+
+    DefaultLimitSpec ls2 = new DefaultLimitSpec(
+        Arrays.asList(
+            new OrderByColumnSpec("d0", OrderByColumnSpec.Direction.DESCENDING, StringComparators.NUMERIC),
+            new OrderByColumnSpec("d1", OrderByColumnSpec.Direction.DESCENDING, StringComparators.NUMERIC),
+            new OrderByColumnSpec("d2", OrderByColumnSpec.Direction.DESCENDING, StringComparators.NUMERIC)
+        ),
+        100
+    );
+
+    GroupByQuery query = GroupByQuery
+        .builder()
+        .setDataSource("blah")
+        .setQuerySegmentSpec(intervalSpec)
+        .setVirtualColumns(
+            new ExpressionVirtualColumn("d0:v", "timestamp_extract(\"__time\",'YEAR','UTC')", ValueType.LONG, TestExprMacroTable.INSTANCE),
+            new ExpressionVirtualColumn("d1:v", "timestamp_extract(\"__time\",'MONTH','UTC')", ValueType.LONG, TestExprMacroTable.INSTANCE),
+            new ExpressionVirtualColumn("d2:v", "timestamp_extract(\"__time\",'DAY','UTC')", ValueType.LONG, TestExprMacroTable.INSTANCE)
+        )
+        .setDimensions(Lists.<DimensionSpec>newArrayList(
+            new DefaultDimensionSpec("d0:v", "d0", ValueType.LONG),
+            new DefaultDimensionSpec("d1:v", "d1", ValueType.LONG),
+            new DefaultDimensionSpec("d2:v", "d2", ValueType.LONG)
+        ))
+        .setAggregatorSpecs(
+            Arrays.asList(new CountAggregatorFactory("a0"))
+        )
+        .setLimitSpec(
+            ls2
+        )
+        .setContext(
+            ImmutableMap.of(
+                GroupByQueryConfig.CTX_KEY_APPLY_LIMIT_PUSH_DOWN, true
+            )
+        )
+        .setGranularity(Granularities.ALL)
+        .build();
+
+    Sequence<Row> queryResult = finalRunner.run(QueryPlus.wrap(query), Maps.newHashMap());
+    List<Row> results = queryResult.toList();
+
+    Row expectedRow0 = GroupByQueryRunnerTestHelper.createExpectedRow(
+        "2017-07-14T02:40:00.000Z",
+        "d0", 2027L,
+        "d1", 3L,
+        "d2", 17L,
+        "a0", 2L
+    );
+    Row expectedRow1 = GroupByQueryRunnerTestHelper.createExpectedRow(
+        "2017-07-14T02:40:00.000Z",
+        "d0", 2024L,
+        "d1", 1L,
+        "d2", 14L,
+        "a0", 2L
+    );
+    Row expectedRow2 = GroupByQueryRunnerTestHelper.createExpectedRow(
+        "2017-07-14T02:40:00.000Z",
+        "d0", 2020L,
+        "d1", 11L,
+        "d2", 13L,
+        "a0", 2L
+    );
+    Row expectedRow3 = GroupByQueryRunnerTestHelper.createExpectedRow(
+        "2017-07-14T02:40:00.000Z",
+        "d0", 2017L,
+        "d1", 9L,
+        "d2", 13L,
+        "a0", 2L
+    );
+
+    Assert.assertEquals(4, results.size());
+    Assert.assertEquals(expectedRow0, results.get(0));
+    Assert.assertEquals(expectedRow1, results.get(1));
+    Assert.assertEquals(expectedRow2, results.get(2));
+    Assert.assertEquals(expectedRow3, results.get(3));
+  }
+
   @Test
   public void testPartialLimitPushDownMerge() throws Exception
   {
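
Note how testDescendingNumerics wires the topology: getRunner1(2) and getRunner2(3) point the two per-"historical" merge stacks at segments C and D, while finalRunner plays the broker, flat-merging the two node-level sequences with the query's own result ordering, which is exactly where a comparator that ignores the descending flag would scramble the merged output. Each (d0, d1, d2) group occurs once in C and once in D, hence a0 = 2 on every expected row, and the asserted order 2027, 2024, 2020, 2017 is descending NUMERIC on d0.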
@@ -384,14 +577,14 @@ public class GroupByLimitPushDownMultiNodeMergeTest
     QueryToolChest<Row, GroupByQuery> toolChest = groupByFactory.getToolchest();
     QueryRunner<Row> theRunner = new FinalizeResultsQueryRunner<>(
         toolChest.mergeResults(
-            groupByFactory.mergeRunners(executorService, getRunner1())
+            groupByFactory.mergeRunners(executorService, getRunner1(0))
        ),
        (QueryToolChest) toolChest
    );

     QueryRunner<Row> theRunner2 = new FinalizeResultsQueryRunner<>(
         toolChest.mergeResults(
-            groupByFactory2.mergeRunners(executorService, getRunner2())
+            groupByFactory2.mergeRunners(executorService, getRunner2(1))
         ),
         (QueryToolChest) toolChest
     );
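
With the segment index now a parameter, this existing test stays pinned to the original segments A and B (indices 0 and 1), preserving its previous behavior.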
@@ -495,10 +688,10 @@ public class GroupByLimitPushDownMultiNodeMergeTest
     Assert.assertEquals(expectedRow3, results.get(3));
   }

-  private List<QueryRunner<Row>> getRunner1()
+  private List<QueryRunner<Row>> getRunner1(int qIndexNumber)
   {
     List<QueryRunner<Row>> runners = Lists.newArrayList();
-    QueryableIndex index = groupByIndices.get(0);
+    QueryableIndex index = groupByIndices.get(qIndexNumber);
     QueryRunner<Row> runner = makeQueryRunner(
         groupByFactory,
         index.toString(),
@@ -508,10 +701,10 @@ public class GroupByLimitPushDownMultiNodeMergeTest
     return runners;
   }

-  private List<QueryRunner<Row>> getRunner2()
+  private List<QueryRunner<Row>> getRunner2(int qIndexNumber)
   {
     List<QueryRunner<Row>> runners = Lists.newArrayList();
-    QueryableIndex index2 = groupByIndices.get(1);
+    QueryableIndex index2 = groupByIndices.get(qIndexNumber);
     QueryRunner<Row> tooSmallRunner = makeQueryRunner(
         groupByFactory2,
         index2.toString(),