diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java index c4c9a7875ef..4d12327c896 100755 --- a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java @@ -454,4 +454,13 @@ public class AggregatorUtil } return false; } + + public static List getCombiningAggregators(List aggs) + { + List combining = new ArrayList<>(aggs.size()); + for (AggregatorFactory agg : aggs) { + combining.add(agg.getCombiningFactory()); + } + return combining; + } } diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupingEngine.java b/processing/src/main/java/org/apache/druid/query/groupby/GroupingEngine.java index 0962e5400e2..ce63050a7e6 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/GroupingEngine.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupingEngine.java @@ -61,6 +61,7 @@ import org.apache.druid.query.QueryWatcher; import org.apache.druid.query.ResourceLimitExceededException; import org.apache.druid.query.ResultMergeQueryRunner; import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.query.context.ResponseContext; import org.apache.druid.query.dimension.DefaultDimensionSpec; @@ -508,6 +509,9 @@ public class GroupingEngine final CursorBuildSpec buildSpec = makeCursorBuildSpec(query, groupByQueryMetrics); final CursorHolder cursorHolder = closer.register(cursorFactory.makeCursorHolder(buildSpec)); + if (cursorHolder.isPreAggregated()) { + query = query.withAggregatorSpecs(AggregatorUtil.getCombiningAggregators(query.getAggregatorSpecs())); + } final ColumnInspector inspector = query.getVirtualColumns().wrapInspector(cursorFactory); // group by specific vectorization check diff --git a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java index 6e2cb62adcf..88d488f85b9 100644 --- a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java +++ b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java @@ -243,6 +243,11 @@ public class TimeseriesQuery extends BaseQuery> return Druids.TimeseriesQueryBuilder.copy(this).filters(dimFilter).build(); } + public TimeseriesQuery withAggregatorSpecs(List aggregatorSpecs) + { + return Druids.TimeseriesQueryBuilder.copy(this).aggregators(aggregatorSpecs).build(); + } + public TimeseriesQuery withPostAggregatorSpecs(final List postAggregatorSpecs) { return Druids.TimeseriesQueryBuilder.copy(this).postAggregators(postAggregatorSpecs).build(); diff --git a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java index dd5a8cb2b58..dbec221248e 100644 --- a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java +++ b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java @@ -38,6 +38,7 @@ import org.apache.druid.query.Result; import org.apache.druid.query.aggregation.Aggregator; import org.apache.druid.query.aggregation.AggregatorAdapters; import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.vector.VectorCursorGranularizer; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.Cursor; @@ -86,7 +87,7 @@ public class TimeseriesQueryEngine * scoped down to a single interval before calling this method. */ public Sequence> process( - final TimeseriesQuery query, + TimeseriesQuery query, final CursorFactory cursorFactory, @Nullable TimeBoundaryInspector timeBoundaryInspector, @Nullable final TimeseriesQueryMetrics timeseriesQueryMetrics @@ -102,6 +103,9 @@ public class TimeseriesQueryEngine final Granularity gran = query.getGranularity(); final CursorHolder cursorHolder = cursorFactory.makeCursorHolder(makeCursorBuildSpec(query, timeseriesQueryMetrics)); + if (cursorHolder.isPreAggregated()) { + query = query.withAggregatorSpecs(AggregatorUtil.getCombiningAggregators(query.getAggregatorSpecs())); + } try { final Sequence> result; diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java index 442e04552f1..d10d26242e3 100644 --- a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java +++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java @@ -30,6 +30,7 @@ import org.apache.druid.query.CursorGranularizer; import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.Result; import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.AggregatorUtil; import org.apache.druid.query.extraction.ExtractionFn; import org.apache.druid.query.topn.types.TopNColumnAggregatesProcessor; import org.apache.druid.query.topn.types.TopNColumnAggregatesProcessorFactory; @@ -73,7 +74,7 @@ public class TopNQueryEngine * update {@link TopNResultValue} */ public Sequence> query( - final TopNQuery query, + TopNQuery query, final Segment segment, @Nullable final TopNQueryMetrics queryMetrics ) @@ -87,6 +88,9 @@ public class TopNQueryEngine final CursorBuildSpec buildSpec = makeCursorBuildSpec(query, queryMetrics); final CursorHolder cursorHolder = cursorFactory.makeCursorHolder(buildSpec); + if (cursorHolder.isPreAggregated()) { + query = query.withAggregatorSpecs(AggregatorUtil.getCombiningAggregators(query.getAggregatorSpecs())); + } final Cursor cursor = cursorHolder.asCursor(); if (cursor == null) { return Sequences.withBaggage(Sequences.empty(), cursorHolder); @@ -127,7 +131,6 @@ public class TopNQueryEngine return Sequences.withBaggage(Sequences.empty(), cursorHolder); } - if (queryMetrics != null) { queryMetrics.cursor(cursor); } diff --git a/processing/src/main/java/org/apache/druid/segment/CursorHolder.java b/processing/src/main/java/org/apache/druid/segment/CursorHolder.java index a70fd8757e1..79bf2b4e557 100644 --- a/processing/src/main/java/org/apache/druid/segment/CursorHolder.java +++ b/processing/src/main/java/org/apache/druid/segment/CursorHolder.java @@ -22,6 +22,7 @@ package org.apache.druid.segment; import org.apache.druid.java.util.common.UOE; import org.apache.druid.query.Order; import org.apache.druid.query.OrderBy; +import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.vector.VectorCursor; @@ -58,6 +59,22 @@ public interface CursorHolder extends Closeable return false; } + /** + * Returns true if the {@link Cursor} or {@link VectorCursor} contains pre-aggregated columns for all + * {@link AggregatorFactory} specified in {@link CursorBuildSpec#getAggregators()}. + *

+ * If this method returns true, {@link ColumnSelectorFactory} and + * {@link org.apache.druid.segment.vector.VectorColumnSelectorFactory} created from {@link Cursor} and + * {@link VectorCursor} respectively will provide selectors for {@link AggregatorFactory#getName()}, and engines + * should rewrite the query using {@link AggregatorFactory#getCombiningFactory()}, since the values returned from + * these selectors will be of type {@link AggregatorFactory#getIntermediateType()}, so the cursor becomes a "fold" + * operation rather than a "build" operation. + */ + default boolean isPreAggregated() + { + return false; + } + /** * Returns cursor ordering, which may or may not match {@link CursorBuildSpec#getPreferredOrdering()}. If returns * an empty list then the cursor has no defined ordering. diff --git a/processing/src/test/java/org/apache/druid/segment/CursorHolderPreaggTest.java b/processing/src/test/java/org/apache/druid/segment/CursorHolderPreaggTest.java new file mode 100644 index 00000000000..82bba60821c --- /dev/null +++ b/processing/src/test/java/org/apache/druid/segment/CursorHolderPreaggTest.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.segment; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.collections.CloseableDefaultBlockingPool; +import org.apache.druid.collections.CloseableStupidPool; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.query.Druids; +import org.apache.druid.query.IterableRowsCursorHelper; +import org.apache.druid.query.Result; +import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.groupby.GroupByQuery; +import org.apache.druid.query.groupby.GroupByQueryConfig; +import org.apache.druid.query.groupby.GroupByResourcesReservationPool; +import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.query.groupby.ResultRow; +import org.apache.druid.query.timeseries.TimeseriesQuery; +import org.apache.druid.query.timeseries.TimeseriesQueryEngine; +import org.apache.druid.query.timeseries.TimeseriesResultValue; +import org.apache.druid.query.topn.TopNQuery; +import org.apache.druid.query.topn.TopNQueryBuilder; +import org.apache.druid.query.topn.TopNQueryEngine; +import org.apache.druid.query.topn.TopNResultValue; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.apache.druid.timeline.SegmentId; +import org.apache.druid.utils.CloseableUtils; +import org.joda.time.Interval; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; + +import javax.annotation.Nullable; +import java.io.Closeable; +import java.nio.ByteBuffer; +import java.util.List; + +public class CursorHolderPreaggTest extends InitializedNullHandlingTest +{ + private GroupingEngine groupingEngine; + private TopNQueryEngine topNQueryEngine; + private TimeseriesQueryEngine timeseriesQueryEngine; + + private CursorFactory cursorFactory; + private Segment segment; + + @Rule + public final CloserRule closer = new CloserRule(false); + + @Before + public void setup() + { + final CloseableStupidPool pool = closer.closeLater( + new CloseableStupidPool<>( + "CursorHolderPreaggTest-bufferPool", + () -> ByteBuffer.allocate(50000) + ) + ); + topNQueryEngine = new TopNQueryEngine(pool); + timeseriesQueryEngine = new TimeseriesQueryEngine(pool); + groupingEngine = new GroupingEngine( + new DruidProcessingConfig(), + GroupByQueryConfig::new, + pool, + new GroupByResourcesReservationPool( + closer.closeLater( + new CloseableDefaultBlockingPool<>( + () -> ByteBuffer.allocate(50000), + 4 + ) + ), + new GroupByQueryConfig() + ), + TestHelper.makeJsonMapper(), + TestHelper.makeSmileMapper(), + (query, future) -> { + } + ); + + this.cursorFactory = new CursorFactory() + { + private final RowSignature rowSignature = RowSignature.builder() + .add("a", ColumnType.STRING) + .add("b", ColumnType.STRING) + .add("cnt", ColumnType.LONG) + .build(); + + private final Pair cursorAndCloser = IterableRowsCursorHelper.getCursorFromIterable( + ImmutableList.of( + new Object[]{"a", "aa", 5L}, + new Object[]{"a", "aa", 6L}, + new Object[]{"b", "bb", 7L} + ), + rowSignature + ); + + @Override + public CursorHolder makeCursorHolder(CursorBuildSpec spec) + { + return new CursorHolder() + { + @Nullable + @Override + public Cursor asCursor() + { + return cursorAndCloser.lhs; + } + + @Override + public boolean isPreAggregated() + { + return true; + } + + @Override + public void close() + { + CloseableUtils.closeAndWrapExceptions(cursorAndCloser.rhs); + } + }; + } + + @Override + public RowSignature getRowSignature() + { + return rowSignature; + } + + @Override + @Nullable + public ColumnCapabilities getColumnCapabilities(String column) + { + return rowSignature.getColumnCapabilities(column); + } + }; + + segment = new Segment() + { + @Override + public SegmentId getId() + { + return SegmentId.dummy("test"); + } + + @Override + public Interval getDataInterval() + { + return Intervals.ETERNITY; + } + + @Nullable + @Override + public QueryableIndex asQueryableIndex() + { + return null; + } + + @Override + public CursorFactory asCursorFactory() + { + return cursorFactory; + } + + @Override + public void close() + { + + } + }; + } + + @Test + public void testTopn() + { + final TopNQuery topNQuery = new TopNQueryBuilder().dataSource("test") + .granularity(Granularities.ALL) + .intervals(ImmutableList.of(Intervals.ETERNITY)) + .dimension("a") + .aggregators(new CountAggregatorFactory("cnt")) + .metric("cnt") + .threshold(10) + .build(); + Sequence> results = topNQueryEngine.query( + topNQuery, + segment, + null + ); + + List> rows = results.toList(); + Assert.assertEquals(1, rows.size()); + // the cnt column is treated as pre-aggregated, so the values of the rows are summed + Assert.assertEquals(2, rows.get(0).getValue().getValue().size()); + Assert.assertEquals(11L, rows.get(0).getValue().getValue().get(0).getLongMetric("cnt").longValue()); + Assert.assertEquals(7L, rows.get(0).getValue().getValue().get(1).getLongMetric("cnt").longValue()); + } + + @Test + public void testGroupBy() + { + final GroupByQuery query = GroupByQuery.builder() + .setDataSource("test") + .setGranularity(Granularities.ALL) + .setInterval(Intervals.ETERNITY) + .addDimension("a") + .addDimension("b") + .addAggregator(new CountAggregatorFactory("cnt")) + .build(); + + Sequence results = groupingEngine.process( + query, + cursorFactory, + null, + null + ); + List rows = results.toList(); + Assert.assertEquals(2, rows.size()); + // the cnt column is treated as pre-aggregated, so the values of the rows are summed + Assert.assertArrayEquals(new Object[]{"a", "aa", 11L}, rows.get(0).getArray()); + Assert.assertArrayEquals(new Object[]{"b", "bb", 7L}, rows.get(1).getArray()); + } + + @Test + public void testTimeseries() + { + TimeseriesQuery timeseriesQuery = Druids.newTimeseriesQueryBuilder() + .dataSource("test") + .intervals(ImmutableList.of(Intervals.ETERNITY)) + .granularity(Granularities.ALL) + .aggregators(new CountAggregatorFactory("cnt")) + .build(); + Sequence> results = timeseriesQueryEngine.process( + timeseriesQuery, + cursorFactory, + null, + null + ); + List> rows = results.toList(); + Assert.assertEquals(1, rows.size()); + // the cnt column is treated as pre-aggregated, so the values of the rows are summed + Assert.assertEquals(18L, (long) rows.get(0).getValue().getLongMetric("cnt")); + } +}