Fix DefaultLimitSpec to respect sortByDimsFirst (#5385)

* Fix DefaultLimitSpec to respect sortByDimsFirst

* fix bug

* address comment
This commit is contained in:
Jihoon Son 2018-02-16 15:26:32 -08:00 committed by Gian Merlino
parent fba13d8978
commit deeda0dff2
6 changed files with 186 additions and 28 deletions

View File

@ -136,8 +136,13 @@ public class GroupByQuery extends BaseQuery<Row>
private Function<Sequence<Row>, Sequence<Row>> makePostProcessingFn() private Function<Sequence<Row>, Sequence<Row>> makePostProcessingFn()
{ {
Function<Sequence<Row>, Sequence<Row>> postProcessingFn = Function<Sequence<Row>, Sequence<Row>> postProcessingFn = limitSpec.build(
limitSpec.build(dimensions, aggregatorSpecs, postAggregatorSpecs); dimensions,
aggregatorSpecs,
postAggregatorSpecs,
getGranularity(),
getContextSortByDimsFirst()
);
if (havingSpec != null) { if (havingSpec != null) {
postProcessingFn = Functions.compose( postProcessingFn = Functions.compose(

View File

@ -32,6 +32,8 @@ import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs; import com.google.common.primitives.Longs;
import io.druid.data.input.Row; import io.druid.data.input.Row;
import io.druid.java.util.common.ISE; import io.druid.java.util.common.ISE;
import io.druid.java.util.common.granularity.Granularities;
import io.druid.java.util.common.granularity.Granularity;
import io.druid.java.util.common.guava.Sequence; import io.druid.java.util.common.guava.Sequence;
import io.druid.java.util.common.guava.Sequences; import io.druid.java.util.common.guava.Sequences;
import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.AggregatorFactory;
@ -119,16 +121,14 @@ public class DefaultLimitSpec implements LimitSpec
public Function<Sequence<Row>, Sequence<Row>> build( public Function<Sequence<Row>, Sequence<Row>> build(
List<DimensionSpec> dimensions, List<DimensionSpec> dimensions,
List<AggregatorFactory> aggs, List<AggregatorFactory> aggs,
List<PostAggregator> postAggs List<PostAggregator> postAggs,
Granularity granularity,
boolean sortByDimsFirst
) )
{ {
// Can avoid re-sorting if the natural ordering is good enough. // Can avoid re-sorting if the natural ordering is good enough.
boolean sortingNeeded = false; boolean sortingNeeded = dimensions.size() < columns.size();
if (dimensions.size() < columns.size()) {
sortingNeeded = true;
}
final Set<String> aggAndPostAggNames = Sets.newHashSet(); final Set<String> aggAndPostAggNames = Sets.newHashSet();
for (AggregatorFactory agg : aggs) { for (AggregatorFactory agg : aggs) {
@ -167,12 +167,17 @@ public class DefaultLimitSpec implements LimitSpec
} }
} }
if (!sortingNeeded) {
// If granularity is ALL, sortByDimsFirst doesn't change the sorting order.
sortingNeeded = !granularity.equals(Granularities.ALL) && sortByDimsFirst;
}
if (!sortingNeeded) { if (!sortingNeeded) {
return isLimited() ? new LimitingFn(limit) : Functions.identity(); return isLimited() ? new LimitingFn(limit) : Functions.identity();
} }
// Materialize the Comparator first for fast-fail error checking. // Materialize the Comparator first for fast-fail error checking.
final Ordering<Row> ordering = makeComparator(dimensions, aggs, postAggs); final Ordering<Row> ordering = makeComparator(dimensions, aggs, postAggs, sortByDimsFirst);
if (isLimited()) { if (isLimited()) {
return new TopNFunction(ordering, limit); return new TopNFunction(ordering, limit);
@ -199,10 +204,13 @@ public class DefaultLimitSpec implements LimitSpec
} }
private Ordering<Row> makeComparator( private Ordering<Row> makeComparator(
List<DimensionSpec> dimensions, List<AggregatorFactory> aggs, List<PostAggregator> postAggs List<DimensionSpec> dimensions,
List<AggregatorFactory> aggs,
List<PostAggregator> postAggs,
boolean sortByDimsFirst
) )
{ {
Ordering<Row> ordering = new Ordering<Row>() Ordering<Row> timeOrdering = new Ordering<Row>()
{ {
@Override @Override
public int compare(Row left, Row right) public int compare(Row left, Row right)
@ -226,6 +234,7 @@ public class DefaultLimitSpec implements LimitSpec
postAggregatorsMap.put(postAgg.getName(), postAgg); postAggregatorsMap.put(postAgg.getName(), postAgg);
} }
Ordering<Row> ordering = null;
for (OrderByColumnSpec columnSpec : columns) { for (OrderByColumnSpec columnSpec : columns) {
String columnName = columnSpec.getDimension(); String columnName = columnSpec.getDimension();
Ordering<Row> nextOrdering = null; Ordering<Row> nextOrdering = null;
@ -246,7 +255,13 @@ public class DefaultLimitSpec implements LimitSpec
nextOrdering = nextOrdering.reverse(); nextOrdering = nextOrdering.reverse();
} }
ordering = ordering.compound(nextOrdering); ordering = ordering == null ? nextOrdering : ordering.compound(nextOrdering);
}
if (ordering != null) {
ordering = sortByDimsFirst ? ordering.compound(timeOrdering) : timeOrdering.compound(ordering);
} else {
ordering = timeOrdering;
} }
return ordering; return ordering;

View File

@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.google.common.base.Function; import com.google.common.base.Function;
import io.druid.data.input.Row; import io.druid.data.input.Row;
import io.druid.java.util.common.Cacheable; import io.druid.java.util.common.Cacheable;
import io.druid.java.util.common.granularity.Granularity;
import io.druid.java.util.common.guava.Sequence; import io.druid.java.util.common.guava.Sequence;
import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.PostAggregator; import io.druid.query.aggregation.PostAggregator;
@ -48,16 +49,20 @@ public interface LimitSpec extends Cacheable
/** /**
* Returns a function that applies a limit to an input sequence that is assumed to be sorted on dimensions. * Returns a function that applies a limit to an input sequence that is assumed to be sorted on dimensions.
* *
* @param dimensions query dimensions * @param dimensions query dimensions
* @param aggs query aggregators * @param aggs query aggregators
* @param postAggs query postAggregators * @param postAggs query postAggregators
* @param granularity query granularity
* @param sortByDimsFirst 'sortByDimsFirst' value in queryContext
* *
* @return limit function * @return limit function
*/ */
Function<Sequence<Row>, Sequence<Row>> build( Function<Sequence<Row>, Sequence<Row>> build(
List<DimensionSpec> dimensions, List<DimensionSpec> dimensions,
List<AggregatorFactory> aggs, List<AggregatorFactory> aggs,
List<PostAggregator> postAggs List<PostAggregator> postAggs,
Granularity granularity,
boolean sortByDimsFirst
); );
LimitSpec merge(LimitSpec other); LimitSpec merge(LimitSpec other);

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.google.common.base.Function; import com.google.common.base.Function;
import com.google.common.base.Functions; import com.google.common.base.Functions;
import io.druid.data.input.Row; import io.druid.data.input.Row;
import io.druid.java.util.common.granularity.Granularity;
import io.druid.java.util.common.guava.Sequence; import io.druid.java.util.common.guava.Sequence;
import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.PostAggregator; import io.druid.query.aggregation.PostAggregator;
@ -52,7 +53,9 @@ public final class NoopLimitSpec implements LimitSpec
public Function<Sequence<Row>, Sequence<Row>> build( public Function<Sequence<Row>, Sequence<Row>> build(
List<DimensionSpec> dimensions, List<DimensionSpec> dimensions,
List<AggregatorFactory> aggs, List<AggregatorFactory> aggs,
List<PostAggregator> postAggs List<PostAggregator> postAggs,
Granularity granularity,
boolean sortByDimsFirst
) )
{ {
return Functions.identity(); return Functions.identity();

View File

@ -3531,6 +3531,68 @@ public class GroupByQueryRunnerTest
TestHelper.assertExpectedObjects(expectedResults, results, ""); TestHelper.assertExpectedObjects(expectedResults, results, "");
} }
@Test
public void testGroupByWithLookupAndLimitAndSortByDimsFirst()
{
Map<String, String> map = new HashMap<>();
map.put("automotive", "9");
map.put("business", "8");
map.put("entertainment", "7");
map.put("health", "6");
map.put("mezzanine", "5");
map.put("news", "4");
map.put("premium", "3");
map.put("technology", "2");
map.put("travel", "1");
GroupByQuery query = GroupByQuery
.builder()
.setDataSource(QueryRunnerTestHelper.dataSource)
.setQuerySegmentSpec(QueryRunnerTestHelper.firstToThird)
.setDimensions(
Lists.<DimensionSpec>newArrayList(
new ExtractionDimensionSpec(
"quality",
"alias",
new LookupExtractionFn(new MapLookupExtractor(map, false), false, null, false, false)
)
)
)
.setAggregatorSpecs(
Arrays.asList(
QueryRunnerTestHelper.rowsCount,
new LongSumAggregatorFactory("idx", "index")
)
)
.setLimitSpec(new DefaultLimitSpec(Lists.<OrderByColumnSpec>newArrayList(
new OrderByColumnSpec("alias", null, StringComparators.ALPHANUMERIC)), 11))
.setGranularity(QueryRunnerTestHelper.dayGran)
.setContext(ImmutableMap.of("sortByDimsFirst", true))
.build();
List<Row> expectedResults = Arrays.asList(
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "1", "rows", 1L, "idx", 119L),
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "1", "rows", 1L, "idx", 126L),
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "2", "rows", 1L, "idx", 78L),
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "2", "rows", 1L, "idx", 97L),
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "3", "rows", 3L, "idx", 2900L),
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "3", "rows", 3L, "idx", 2505L),
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "4", "rows", 1L, "idx", 121L),
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "4", "rows", 1L, "idx", 114L),
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "5", "rows", 3L, "idx", 2870L),
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "5", "rows", 3L, "idx", 2447L),
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "6", "rows", 1L, "idx", 120L)
);
Iterable<Row> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
TestHelper.assertExpectedObjects(expectedResults, results, "");
}
@Ignore @Ignore
@Test @Test
// This is a test to verify per limit groupings, but Druid currently does not support this functionality. At a point // This is a test to verify per limit groupings, but Druid currently does not support this functionality. At a point
@ -7436,7 +7498,9 @@ public class GroupByQueryRunnerTest
query.getLimitSpec().build( query.getLimitSpec().build(
query.getDimensions(), query.getDimensions(),
query.getAggregatorSpecs(), query.getAggregatorSpecs(),
query.getPostAggregatorSpecs() query.getPostAggregatorSpecs(),
query.getGranularity(),
query.getContextSortByDimsFirst()
) )
); );
@ -7497,7 +7561,9 @@ public class GroupByQueryRunnerTest
query.getLimitSpec().build( query.getLimitSpec().build(
query.getDimensions(), query.getDimensions(),
query.getAggregatorSpecs(), query.getAggregatorSpecs(),
query.getPostAggregatorSpecs() query.getPostAggregatorSpecs(),
query.getGranularity(),
query.getContextSortByDimsFirst()
) )
); );
@ -7700,7 +7766,9 @@ public class GroupByQueryRunnerTest
query.getLimitSpec().build( query.getLimitSpec().build(
query.getDimensions(), query.getDimensions(),
query.getAggregatorSpecs(), query.getAggregatorSpecs(),
query.getPostAggregatorSpecs() query.getPostAggregatorSpecs(),
query.getGranularity(),
query.getContextSortByDimsFirst()
) )
); );
@ -7762,7 +7830,9 @@ public class GroupByQueryRunnerTest
query.getLimitSpec().build( query.getLimitSpec().build(
query.getDimensions(), query.getDimensions(),
query.getAggregatorSpecs(), query.getAggregatorSpecs(),
query.getPostAggregatorSpecs() query.getPostAggregatorSpecs(),
query.getGranularity(),
query.getContextSortByDimsFirst()
) )
); );
@ -7823,7 +7893,9 @@ public class GroupByQueryRunnerTest
query.getLimitSpec().build( query.getLimitSpec().build(
query.getDimensions(), query.getDimensions(),
query.getAggregatorSpecs(), query.getAggregatorSpecs(),
query.getPostAggregatorSpecs() query.getPostAggregatorSpecs(),
query.getGranularity(),
query.getContextSortByDimsFirst()
) )
); );

View File

@ -27,6 +27,7 @@ import com.google.common.collect.Maps;
import io.druid.data.input.MapBasedRow; import io.druid.data.input.MapBasedRow;
import io.druid.data.input.Row; import io.druid.data.input.Row;
import io.druid.java.util.common.DateTimes; import io.druid.java.util.common.DateTimes;
import io.druid.java.util.common.granularity.Granularities;
import io.druid.java.util.common.guava.Sequence; import io.druid.java.util.common.guava.Sequence;
import io.druid.java.util.common.guava.Sequences; import io.druid.java.util.common.guava.Sequences;
import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.AggregatorFactory;
@ -40,6 +41,7 @@ import io.druid.query.dimension.DimensionSpec;
import io.druid.query.expression.TestExprMacroTable; import io.druid.query.expression.TestExprMacroTable;
import io.druid.query.ordering.StringComparators; import io.druid.query.ordering.StringComparators;
import io.druid.segment.TestHelper; import io.druid.segment.TestHelper;
import io.druid.segment.column.ValueType;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
@ -159,7 +161,9 @@ public class DefaultLimitSpecTest
Function<Sequence<Row>, Sequence<Row>> limitFn = limitSpec.build( Function<Sequence<Row>, Sequence<Row>> limitFn = limitSpec.build(
ImmutableList.<DimensionSpec>of(), ImmutableList.<DimensionSpec>of(),
ImmutableList.<AggregatorFactory>of(), ImmutableList.<AggregatorFactory>of(),
ImmutableList.<PostAggregator>of() ImmutableList.<PostAggregator>of(),
Granularities.NONE,
false
); );
Assert.assertEquals( Assert.assertEquals(
@ -168,6 +172,50 @@ public class DefaultLimitSpecTest
); );
} }
@Test
public void testWithAllGranularity()
{
DefaultLimitSpec limitSpec = new DefaultLimitSpec(
ImmutableList.of(new OrderByColumnSpec("k1", OrderByColumnSpec.Direction.ASCENDING, StringComparators.NUMERIC)),
2
);
Function<Sequence<Row>, Sequence<Row>> limitFn = limitSpec.build(
ImmutableList.of(new DefaultDimensionSpec("k1", "k1", ValueType.DOUBLE)),
ImmutableList.of(),
ImmutableList.of(),
Granularities.ALL,
true
);
Assert.assertEquals(
ImmutableList.of(testRowsList.get(0), testRowsList.get(1)),
limitFn.apply(testRowsSequence).toList()
);
}
@Test
public void testWithSortByDimsFirst()
{
DefaultLimitSpec limitSpec = new DefaultLimitSpec(
ImmutableList.of(new OrderByColumnSpec("k1", OrderByColumnSpec.Direction.ASCENDING, StringComparators.NUMERIC)),
2
);
Function<Sequence<Row>, Sequence<Row>> limitFn = limitSpec.build(
ImmutableList.of(new DefaultDimensionSpec("k1", "k1", ValueType.DOUBLE)),
ImmutableList.of(),
ImmutableList.of(),
Granularities.NONE,
true
);
Assert.assertEquals(
ImmutableList.of(testRowsList.get(2), testRowsList.get(0)),
limitFn.apply(testRowsSequence).toList()
);
}
@Test @Test
public void testSortDimensionDescending() public void testSortDimensionDescending()
{ {
@ -179,7 +227,9 @@ public class DefaultLimitSpecTest
Function<Sequence<Row>, Sequence<Row>> limitFn = limitSpec.build( Function<Sequence<Row>, Sequence<Row>> limitFn = limitSpec.build(
ImmutableList.<DimensionSpec>of(new DefaultDimensionSpec("k1", "k1")), ImmutableList.<DimensionSpec>of(new DefaultDimensionSpec("k1", "k1")),
ImmutableList.<AggregatorFactory>of(), ImmutableList.<AggregatorFactory>of(),
ImmutableList.<PostAggregator>of() ImmutableList.<PostAggregator>of(),
Granularities.NONE,
false
); );
// Note: This test encodes the fact that limitSpec sorts numbers like strings; we might want to change this // Note: This test encodes the fact that limitSpec sorts numbers like strings; we might want to change this
@ -209,7 +259,9 @@ public class DefaultLimitSpecTest
), ),
ImmutableList.<PostAggregator>of( ImmutableList.<PostAggregator>of(
new ConstantPostAggregator("k3", 1L) new ConstantPostAggregator("k3", 1L)
) ),
Granularities.NONE,
false
); );
Assert.assertEquals( Assert.assertEquals(
ImmutableList.of(testRowsList.get(0), testRowsList.get(1)), ImmutableList.of(testRowsList.get(0), testRowsList.get(1)),
@ -226,7 +278,9 @@ public class DefaultLimitSpecTest
), ),
ImmutableList.<PostAggregator>of( ImmutableList.<PostAggregator>of(
new ConstantPostAggregator("k3", 1L) new ConstantPostAggregator("k3", 1L)
) ),
Granularities.NONE,
false
); );
Assert.assertEquals( Assert.assertEquals(
ImmutableList.of(testRowsList.get(2), testRowsList.get(0)), ImmutableList.of(testRowsList.get(2), testRowsList.get(0)),
@ -249,7 +303,9 @@ public class DefaultLimitSpecTest
new ConstantPostAggregator("x", 1), new ConstantPostAggregator("x", 1),
new ConstantPostAggregator("y", 1)) new ConstantPostAggregator("y", 1))
) )
) ),
Granularities.NONE,
false
); );
Assert.assertEquals( Assert.assertEquals(
(List) ImmutableList.of(testRowsList.get(2), testRowsList.get(0)), (List) ImmutableList.of(testRowsList.get(2), testRowsList.get(0)),
@ -260,7 +316,9 @@ public class DefaultLimitSpecTest
limitFn = limitSpec.build( limitFn = limitSpec.build(
ImmutableList.<DimensionSpec>of(new DefaultDimensionSpec("k1", "k1")), ImmutableList.<DimensionSpec>of(new DefaultDimensionSpec("k1", "k1")),
ImmutableList.<AggregatorFactory>of(new LongSumAggregatorFactory("k2", "k2")), ImmutableList.<AggregatorFactory>of(new LongSumAggregatorFactory("k2", "k2")),
ImmutableList.<PostAggregator>of(new ExpressionPostAggregator("k1", "1 + 1", null, TestExprMacroTable.INSTANCE)) ImmutableList.<PostAggregator>of(new ExpressionPostAggregator("k1", "1 + 1", null, TestExprMacroTable.INSTANCE)),
Granularities.NONE,
false
); );
Assert.assertEquals( Assert.assertEquals(
(List) ImmutableList.of(testRowsList.get(2), testRowsList.get(0)), (List) ImmutableList.of(testRowsList.get(2), testRowsList.get(0)),