Optimize filtered aggs with interval filters in per-segment queries (#5857)

* Optimize per-segment queries

* Always optimize, add unit test

* PR comments

* Only run IntervalDimFilter optimization on __time column

* PR comments

* Checkstyle fix

* Add test for non __time column
Jonathan Wei 2018-08-01 14:39:38 -07:00 committed by GitHub
parent e270362767
commit b9c445c780
13 changed files with 1212 additions and 29 deletions
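In brief, the patch adds a per-segment optimization hook: PerSegmentOptimizingQueryRunner calls QueryPlus.optimizeForSegment(), which lets TimeseriesQuery and TopNQuery rebuild themselves with AggregatorFactory.optimizeForSegment() applied to each aggregator. For a FilteredAggregatorFactory whose filter is an IntervalDimFilter on __time, the filter is dropped when the segment's interval is fully covered, clipped when it partially overlaps, and the whole aggregator is suppressed when there is no overlap. A minimal sketch of the rewrite, reusing classes from this patch (the interval values and segment descriptor are illustrative; imports as in the new unit test below):

AggregatorFactory agg = new FilteredAggregatorFactory(
    new LongSumAggregatorFactory("sumLongSequential", "sumLongSequential"),
    new IntervalDimFilter(
        Column.TIME_COLUMN_NAME,
        Collections.singletonList(Intervals.utc(1000, 2000)),
        null
    )
);
// Rewrite against a segment covering [1500, 2500): the interval filter is clipped to [1500, 2000).
AggregatorFactory optimized = agg.optimizeForSegment(
    new PerSegmentQueryOptimizationContext(
        new SegmentDescriptor(Intervals.utc(1500, 2500), "0", 0)
    )
);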

View File

@@ -0,0 +1,457 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.benchmark.query.timecompare;
import com.fasterxml.jackson.databind.InjectableValues;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Files;
import io.druid.benchmark.datagen.BenchmarkDataGenerator;
import io.druid.benchmark.datagen.BenchmarkSchemaInfo;
import io.druid.benchmark.datagen.BenchmarkSchemas;
import io.druid.benchmark.query.QueryBenchmarkUtil;
import io.druid.collections.StupidPool;
import io.druid.data.input.InputRow;
import io.druid.hll.HyperLogLogHash;
import io.druid.jackson.DefaultObjectMapper;
import io.druid.java.util.common.Intervals;
import io.druid.java.util.common.concurrent.Execs;
import io.druid.java.util.common.granularity.Granularities;
import io.druid.java.util.common.guava.Sequence;
import io.druid.java.util.common.logger.Logger;
import io.druid.math.expr.ExprMacroTable;
import io.druid.offheap.OffheapBufferGenerator;
import io.druid.query.Druids;
import io.druid.query.FinalizeResultsQueryRunner;
import io.druid.query.PerSegmentOptimizingQueryRunner;
import io.druid.query.PerSegmentQueryOptimizationContext;
import io.druid.query.Query;
import io.druid.query.QueryPlus;
import io.druid.query.QueryRunner;
import io.druid.query.QueryRunnerFactory;
import io.druid.query.QueryToolChest;
import io.druid.query.Result;
import io.druid.query.SegmentDescriptor;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.FilteredAggregatorFactory;
import io.druid.query.aggregation.LongSumAggregatorFactory;
import io.druid.query.aggregation.hyperloglog.HyperUniquesSerde;
import io.druid.query.filter.IntervalDimFilter;
import io.druid.query.spec.MultipleIntervalSegmentSpec;
import io.druid.query.spec.QuerySegmentSpec;
import io.druid.query.timeseries.TimeseriesQueryEngine;
import io.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import io.druid.query.timeseries.TimeseriesQueryRunnerFactory;
import io.druid.query.timeseries.TimeseriesResultValue;
import io.druid.query.topn.TopNQueryBuilder;
import io.druid.query.topn.TopNQueryConfig;
import io.druid.query.topn.TopNQueryQueryToolChest;
import io.druid.query.topn.TopNQueryRunnerFactory;
import io.druid.query.topn.TopNResultValue;
import io.druid.segment.IndexIO;
import io.druid.segment.IndexMergerV9;
import io.druid.segment.IndexSpec;
import io.druid.segment.QueryableIndex;
import io.druid.segment.QueryableIndexSegment;
import io.druid.segment.column.Column;
import io.druid.segment.column.ColumnConfig;
import io.druid.segment.incremental.IncrementalIndex;
import io.druid.segment.serde.ComplexMetrics;
import io.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.commons.io.FileUtils;
import org.joda.time.Interval;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 50)
@Measurement(iterations = 200)
public class TimeCompareBenchmark
{
@Param({"10"})
private int numSegments;
@Param({"100000"})
private int rowsPerSegment;
@Param({"100"})
private int threshold;
protected static final Map<String, String> scriptDoubleSum = Maps.newHashMap();
static {
scriptDoubleSum.put("fnAggregate", "function aggregate(current, a) { return current + a }");
scriptDoubleSum.put("fnReset", "function reset() { return 0 }");
scriptDoubleSum.put("fnCombine", "function combine(a,b) { return a + b }");
}
private static final Logger log = new Logger(TimeCompareBenchmark.class);
private static final int RNG_SEED = 9999;
private static final IndexMergerV9 INDEX_MERGER_V9;
private static final IndexIO INDEX_IO;
public static final ObjectMapper JSON_MAPPER;
private List<IncrementalIndex> incIndexes;
private List<QueryableIndex> qIndexes;
private QueryRunnerFactory topNFactory;
private Query topNQuery;
private QueryRunner topNRunner;
private QueryRunnerFactory timeseriesFactory;
private Query timeseriesQuery;
private QueryRunner timeseriesRunner;
private BenchmarkSchemaInfo schemaInfo;
private File tmpDir;
private Interval[] segmentIntervals;
private ExecutorService executorService;
static {
JSON_MAPPER = new DefaultObjectMapper();
InjectableValues.Std injectableValues = new InjectableValues.Std();
injectableValues.addValue(ExprMacroTable.class, ExprMacroTable.nil());
JSON_MAPPER.setInjectableValues(injectableValues);
INDEX_IO = new IndexIO(
JSON_MAPPER,
OffHeapMemorySegmentWriteOutMediumFactory.instance(),
new ColumnConfig()
{
@Override
public int columnCacheSizeBytes()
{
return 0;
}
}
);
INDEX_MERGER_V9 = new IndexMergerV9(JSON_MAPPER, INDEX_IO, OffHeapMemorySegmentWriteOutMediumFactory.instance());
}
private static final Map<String, Map<String, Object>> SCHEMA_QUERY_MAP = new LinkedHashMap<>();
private void setupQueries()
{
// queries for the basic schema
Map<String, Object> basicQueries = new LinkedHashMap<>();
BenchmarkSchemaInfo basicSchema = BenchmarkSchemas.SCHEMA_MAP.get("basic");
QuerySegmentSpec intervalSpec = new MultipleIntervalSegmentSpec(Collections.singletonList(basicSchema.getDataInterval()));
long startMillis = basicSchema.getDataInterval().getStartMillis();
long endMillis = basicSchema.getDataInterval().getEndMillis();
long half = (endMillis - startMillis) / 2;
Interval recent = Intervals.utc(half, endMillis);
Interval previous = Intervals.utc(startMillis, half);
log.info("Recent interval: " + recent);
log.info("Previous interval: " + previous);
{ // basic.topNTimeCompare
List<AggregatorFactory> queryAggs = new ArrayList<>();
queryAggs.add(
new FilteredAggregatorFactory(
//jsAgg1,
new LongSumAggregatorFactory(
"sumLongSequential", "sumLongSequential"
),
new IntervalDimFilter(
Column.TIME_COLUMN_NAME,
Collections.singletonList(recent),
null
)
)
);
queryAggs.add(
new FilteredAggregatorFactory(
new LongSumAggregatorFactory(
"_cmp_sumLongSequential", "sumLongSequential"
),
new IntervalDimFilter(
Column.TIME_COLUMN_NAME,
Collections.singletonList(previous),
null
)
)
);
TopNQueryBuilder queryBuilderA = new TopNQueryBuilder()
.dataSource("blah")
.granularity(Granularities.ALL)
.dimension("dimUniform")
.metric("sumLongSequential")
.intervals(intervalSpec)
.aggregators(queryAggs)
.threshold(threshold);
topNQuery = queryBuilderA.build();
topNFactory = new TopNQueryRunnerFactory(
new StupidPool<>(
"TopNBenchmark-compute-bufferPool",
new OffheapBufferGenerator("compute", 250000000),
0,
Integer.MAX_VALUE
),
new TopNQueryQueryToolChest(new TopNQueryConfig(), QueryBenchmarkUtil.NoopIntervalChunkingQueryRunnerDecorator()),
QueryBenchmarkUtil.NOOP_QUERYWATCHER
);
basicQueries.put("topNTimeCompare", queryBuilderA);
}
{ // basic.timeseriesTimeCompare
List<AggregatorFactory> queryAggs = new ArrayList<>();
queryAggs.add(
new FilteredAggregatorFactory(
new LongSumAggregatorFactory(
"sumLongSequential", "sumLongSequential"
),
new IntervalDimFilter(
Column.TIME_COLUMN_NAME,
Collections.singletonList(recent),
null
)
)
);
queryAggs.add(
new FilteredAggregatorFactory(
new LongSumAggregatorFactory(
"_cmp_sumLongSequential", "sumLongSequential"
),
new IntervalDimFilter(
Column.TIME_COLUMN_NAME,
Collections.singletonList(previous),
null
)
)
);
Druids.TimeseriesQueryBuilder timeseriesQueryBuilder = Druids.newTimeseriesQueryBuilder()
.dataSource("blah")
.granularity(Granularities.ALL)
.intervals(intervalSpec)
.aggregators(queryAggs)
.descending(false);
timeseriesQuery = timeseriesQueryBuilder.build();
timeseriesFactory = new TimeseriesQueryRunnerFactory(
new TimeseriesQueryQueryToolChest(
QueryBenchmarkUtil.NoopIntervalChunkingQueryRunnerDecorator()
),
new TimeseriesQueryEngine(),
QueryBenchmarkUtil.NOOP_QUERYWATCHER
);
}
SCHEMA_QUERY_MAP.put("basic", basicQueries);
}
@Setup
public void setup() throws IOException
{
log.info("SETUP CALLED AT " + System.currentTimeMillis());
if (ComplexMetrics.getSerdeForType("hyperUnique") == null) {
ComplexMetrics.registerSerde("hyperUnique", new HyperUniquesSerde(HyperLogLogHash.getDefault()));
}
executorService = Execs.multiThreaded(numSegments, "TopNThreadPool");
setupQueries();
String schemaName = "basic";
schemaInfo = BenchmarkSchemas.SCHEMA_MAP.get(schemaName);
segmentIntervals = new Interval[numSegments];
long startMillis = schemaInfo.getDataInterval().getStartMillis();
long endMillis = schemaInfo.getDataInterval().getEndMillis();
long partialIntervalMillis = (endMillis - startMillis) / numSegments;
for (int i = 0; i < numSegments; i++) {
long partialEndMillis = startMillis + partialIntervalMillis;
segmentIntervals[i] = Intervals.utc(startMillis, partialEndMillis);
log.info("Segment [%d] with interval [%s]", i, segmentIntervals[i]);
startMillis = partialEndMillis;
}
incIndexes = new ArrayList<>();
for (int i = 0; i < numSegments; i++) {
log.info("Generating rows for segment " + i);
BenchmarkDataGenerator gen = new BenchmarkDataGenerator(
schemaInfo.getColumnSchemas(),
RNG_SEED + i,
segmentIntervals[i],
rowsPerSegment
);
IncrementalIndex incIndex = makeIncIndex();
for (int j = 0; j < rowsPerSegment; j++) {
InputRow row = gen.nextRow();
if (j % 10000 == 0) {
log.info(j + " rows generated.");
}
incIndex.add(row);
}
incIndexes.add(incIndex);
}
tmpDir = Files.createTempDir();
log.info("Using temp dir: " + tmpDir.getAbsolutePath());
qIndexes = new ArrayList<>();
for (int i = 0; i < numSegments; i++) {
File indexFile = INDEX_MERGER_V9.persist(
incIndexes.get(i),
tmpDir,
new IndexSpec(),
null
);
QueryableIndex qIndex = INDEX_IO.loadIndex(indexFile);
qIndexes.add(qIndex);
}
List<QueryRunner<Result<TopNResultValue>>> singleSegmentRunners = Lists.newArrayList();
QueryToolChest toolChest = topNFactory.getToolchest();
for (int i = 0; i < numSegments; i++) {
String segmentName = "qIndex" + i;
QueryRunner<Result<TopNResultValue>> runner = QueryBenchmarkUtil.makeQueryRunner(
topNFactory,
segmentName,
new QueryableIndexSegment(segmentName, qIndexes.get(i))
);
singleSegmentRunners.add(
new PerSegmentOptimizingQueryRunner<>(
toolChest.preMergeQueryDecoration(runner),
new PerSegmentQueryOptimizationContext(
new SegmentDescriptor(segmentIntervals[i], "1", 0)
)
)
);
}
topNRunner = toolChest.postMergeQueryDecoration(
new FinalizeResultsQueryRunner<>(
toolChest.mergeResults(topNFactory.mergeRunners(executorService, singleSegmentRunners)),
toolChest
)
);
List<QueryRunner<Result<TimeseriesResultValue>>> singleSegmentRunnersT = Lists.newArrayList();
QueryToolChest toolChestT = timeseriesFactory.getToolchest();
for (int i = 0; i < numSegments; i++) {
String segmentName = "qIndex" + i;
QueryRunner<Result<TimeseriesResultValue>> runner = QueryBenchmarkUtil.makeQueryRunner(
timeseriesFactory,
segmentName,
new QueryableIndexSegment(segmentName, qIndexes.get(i))
);
singleSegmentRunnersT.add(
new PerSegmentOptimizingQueryRunner<>(
toolChestT.preMergeQueryDecoration(runner),
new PerSegmentQueryOptimizationContext(
new SegmentDescriptor(segmentIntervals[i], "1", 0)
)
)
);
}
timeseriesRunner = toolChestT.postMergeQueryDecoration(
new FinalizeResultsQueryRunner<>(
toolChestT.mergeResults(timeseriesFactory.mergeRunners(executorService, singleSegmentRunnersT)),
toolChestT
)
);
}
@TearDown
public void tearDown() throws IOException
{
FileUtils.deleteDirectory(tmpDir);
}
private IncrementalIndex makeIncIndex()
{
return new IncrementalIndex.Builder()
.setSimpleTestingIndexSchema(schemaInfo.getAggsArray())
.setReportParseExceptions(false)
.setMaxRowCount(rowsPerSegment)
.buildOnheap();
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public void queryMultiQueryableIndexTopN(Blackhole blackhole)
{
Sequence<Result<TopNResultValue>> queryResult = topNRunner.run(
QueryPlus.wrap(topNQuery),
Maps.<String, Object>newHashMap()
);
List<Result<TopNResultValue>> results = queryResult.toList();
for (Result<TopNResultValue> result : results) {
blackhole.consume(result);
}
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public void queryMultiQueryableIndexTimeseries(Blackhole blackhole)
{
Sequence<Result<TimeseriesResultValue>> queryResult = timeseriesRunner.run(
QueryPlus.wrap(timeseriesQuery),
Maps.<String, Object>newHashMap()
);
List<Result<TimeseriesResultValue>> results = queryResult.toList();
for (Result<TimeseriesResultValue> result : results) {
blackhole.consume(result);
}
}
}
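The benchmark above is a standard JMH @State class; a hedged sketch of launching it directly through the JMH runner API follows (the Druid build may instead package benchmarks into its own runnable jar, so the launcher class below is an illustrative assumption):

import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class TimeCompareBenchmarkLauncher
{
  public static void main(String[] args) throws RunnerException
  {
    // Run only TimeCompareBenchmark; fork/warmup/measurement settings come from the class annotations.
    Options opt = new OptionsBuilder()
        .include(TimeCompareBenchmark.class.getSimpleName())
        .build();
    new Runner(opt).run();
  }
}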

View File

@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query;
import io.druid.java.util.common.guava.Sequence;
import java.util.Map;
/**
* This runner optimizes queries made on a single segment, using per-segment information,
* before submitting the queries to the base runner.
*
* Example optimizations include adjusting query filters based on per-segment information, such as intervals.
*
* This query runner should only wrap base query runners that will
* be used to query a single segment (i.e., when the query reaches a historical node).
*
* @param <T>
*/
public class PerSegmentOptimizingQueryRunner<T> implements QueryRunner<T>
{
private final QueryRunner<T> base;
private final PerSegmentQueryOptimizationContext optimizationContext;
public PerSegmentOptimizingQueryRunner(
QueryRunner<T> base,
PerSegmentQueryOptimizationContext optimizationContext
)
{
this.base = base;
this.optimizationContext = optimizationContext;
}
@Override
public Sequence<T> run(final QueryPlus<T> input, final Map<String, Object> responseContext)
{
return base.run(
input.optimizeForSegment(optimizationContext),
responseContext
);
}
}
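Usage is a thin wrapper around an existing per-segment runner, as in the benchmark above and ServerManager below; a minimal sketch (baseRunner, toolChest, and segmentInterval are assumed to be supplied by the caller):

QueryRunner<Result<TopNResultValue>> optimizingRunner = new PerSegmentOptimizingQueryRunner<>(
    toolChest.preMergeQueryDecoration(baseRunner),
    new PerSegmentQueryOptimizationContext(new SegmentDescriptor(segmentInterval, "0", 0))
);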

View File

@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query;
/**
* Holds information about a single segment that Query objects can use to optimize themselves
* when they are run on that single segment.
*
* @see PerSegmentOptimizingQueryRunner
*/
public class PerSegmentQueryOptimizationContext
{
private final SegmentDescriptor segmentDescriptor;
public PerSegmentQueryOptimizationContext(
SegmentDescriptor segmentDescriptor
)
{
this.segmentDescriptor = segmentDescriptor;
}
public SegmentDescriptor getSegmentDescriptor()
{
return segmentDescriptor;
}
}

View File

@@ -109,4 +109,9 @@ public interface Query<T>
String getId();
Query<T> withDataSource(DataSource dataSource);
default Query<T> optimizeForSegment(PerSegmentQueryOptimizationContext optimizationContext)
{
return this;
}
}

View File

@@ -144,4 +144,9 @@ public final class QueryPlus<T>
{
return query.getRunner(walker).run(this, context);
}
public QueryPlus<T> optimizeForSegment(PerSegmentQueryOptimizationContext optimizationContext)
{
return new QueryPlus<>(query.optimizeForSegment(optimizationContext), queryMetrics, identity);
}
}

View File

@@ -23,6 +23,7 @@ import io.druid.guice.annotations.ExtensionPoint;
import io.druid.java.util.common.Cacheable;
import io.druid.java.util.common.UOE;
import io.druid.java.util.common.logger.Logger;
import io.druid.query.PerSegmentQueryOptimizationContext;
import io.druid.segment.ColumnSelectorFactory;
import javax.annotation.Nullable;
@@ -141,6 +142,14 @@ public abstract class AggregatorFactory implements Cacheable
*/
public abstract int getMaxIntermediateSize();
/**
* Return a potentially optimized form of this AggregatorFactory for per-segment queries.
*/
public AggregatorFactory optimizeForSegment(PerSegmentQueryOptimizationContext optimizationContext)
{
return this;
}
/**
* Merges the list of AggregatorFactory[] (presumably from metadata of some segments being merged) and
* returns merged AggregatorFactory[] (for the metadata for merged segment).

View File

@@ -98,6 +98,9 @@ public class AggregatorUtil
public static final byte STRING_FIRST_CACHE_TYPE_ID = 0x2B;
public static final byte STRING_LAST_CACHE_TYPE_ID = 0x2C;
// Suppressed aggregator
public static final byte SUPPRESSED_AGG_CACHE_TYPE_ID = 0x2D;
/**
* returns the list of dependent postAggregators that should be calculated in order to calculate the given postAgg
*

View File

@@ -21,12 +21,17 @@ package io.druid.query.aggregation;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import io.druid.query.PerSegmentQueryOptimizationContext;
import io.druid.query.filter.DimFilter;
import io.druid.query.filter.IntervalDimFilter;
import io.druid.query.filter.ValueMatcher;
import io.druid.segment.ColumnSelectorFactory;
import io.druid.segment.column.Column;
import io.druid.segment.filter.Filters;
import org.joda.time.Interval;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
@@ -140,6 +145,66 @@ public class FilteredAggregatorFactory extends AggregatorFactory
return delegate.getMaxIntermediateSize();
}
@Override
public AggregatorFactory optimizeForSegment(PerSegmentQueryOptimizationContext optimizationContext)
{
if (filter instanceof IntervalDimFilter) {
IntervalDimFilter intervalDimFilter = ((IntervalDimFilter) filter);
if (intervalDimFilter.getExtractionFn() != null) {
// no support for extraction functions right now
return this;
}
if (!intervalDimFilter.getDimension().equals(Column.TIME_COLUMN_NAME)) {
// segment time boundary optimization only applies when we filter on __time
return this;
}
Interval segmentInterval = optimizationContext.getSegmentDescriptor().getInterval();
List<Interval> filterIntervals = intervalDimFilter.getIntervals();
List<Interval> excludedFilterIntervals = new ArrayList<>();
List<Interval> effectiveFilterIntervals = new ArrayList<>();
boolean segmentIsCovered = false;
for (Interval filterInterval : filterIntervals) {
Interval overlap = filterInterval.overlap(segmentInterval);
if (overlap == null) {
excludedFilterIntervals.add(filterInterval);
continue;
}
if (overlap.equals(segmentInterval)) {
segmentIsCovered = true;
break;
} else {
// clip the overlapping interval to the segment time boundaries
effectiveFilterIntervals.add(overlap);
}
}
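// Worked example, mirroring the values in PerSegmentQueryOptimizeTest below, with a filter on [1000, 2000) over __time:
//   segment [2000, 3000): no overlap      -> delegate wrapped in SuppressedAggregatorFactory (never aggregates)
//   segment [1500, 1600): fully covered   -> delegate returned directly, the filter is dropped
//   segment [1500, 2500): partial overlap -> filter clipped to the overlap [1500, 2000)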
// we can skip applying this filter, everything in the segment will match
if (segmentIsCovered) {
return delegate;
}
// we can skip this filter, nothing in the segment would match
if (excludedFilterIntervals.size() == filterIntervals.size()) {
return new SuppressedAggregatorFactory(delegate);
}
return new FilteredAggregatorFactory(
delegate,
new IntervalDimFilter(
intervalDimFilter.getDimension(),
effectiveFilterIntervals,
intervalDimFilter.getExtractionFn()
)
);
} else {
return this;
}
}
@JsonProperty
public AggregatorFactory getAggregator()
{

View File

@@ -0,0 +1,374 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query.aggregation;
import io.druid.query.PerSegmentQueryOptimizationContext;
import io.druid.query.cache.CacheKeyBuilder;
import io.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import io.druid.segment.ColumnSelectorFactory;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
/**
* This AggregatorFactory is meant for wrapping delegate aggregators for optimization purposes.
*
* The wrapper suppresses the aggregate() method for the underlying delegate, while leaving
* the behavior of other calls unchanged.
*
* This wrapper is meant to be used when an optimization decides that an aggregator can be entirely skipped
* (e.g., a FilteredAggregatorFactory where the filter condition will never match).
*/
public class SuppressedAggregatorFactory extends AggregatorFactory
{
private final AggregatorFactory delegate;
public SuppressedAggregatorFactory(
AggregatorFactory delegate
)
{
this.delegate = delegate;
}
@Override
public Aggregator factorize(ColumnSelectorFactory metricFactory)
{
return new SuppressedAggregator(delegate.factorize(metricFactory));
}
@Override
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
{
return new SuppressedBufferAggregator(delegate.factorizeBuffered(metricFactory));
}
@Override
public Comparator getComparator()
{
return delegate.getComparator();
}
@Override
public Object combine(Object lhs, Object rhs)
{
return delegate.combine(lhs, rhs);
}
@Override
public AggregateCombiner makeAggregateCombiner()
{
return delegate.makeAggregateCombiner();
}
@Override
public AggregatorFactory getCombiningFactory()
{
return delegate.getCombiningFactory();
}
@Override
public AggregatorFactory getMergingFactory(AggregatorFactory other) throws AggregatorFactoryNotMergeableException
{
return delegate.getMergingFactory(other);
}
@Override
public List<AggregatorFactory> getRequiredColumns()
{
return delegate.getRequiredColumns();
}
@Override
public Object deserialize(Object object)
{
return delegate.deserialize(object);
}
@Override
public Object finalizeComputation(Object object)
{
return delegate.finalizeComputation(object);
}
@Override
public String getName()
{
return delegate.getName();
}
@Override
public List<String> requiredFields()
{
return delegate.requiredFields();
}
@Override
public String getTypeName()
{
return delegate.getTypeName();
}
@Override
public int getMaxIntermediateSize()
{
return delegate.getMaxIntermediateSize();
}
@Override
public AggregatorFactory optimizeForSegment(PerSegmentQueryOptimizationContext optimizationContext)
{
// we are already the result of an optimizeForSegment() call
return this;
}
@Override
public byte[] getCacheKey()
{
CacheKeyBuilder cacheKeyBuilder = new CacheKeyBuilder(AggregatorUtil.SUPPRESSED_AGG_CACHE_TYPE_ID);
cacheKeyBuilder.appendCacheable(delegate);
return cacheKeyBuilder.build();
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SuppressedAggregatorFactory that = (SuppressedAggregatorFactory) o;
return Objects.equals(getDelegate(), that.getDelegate());
}
@Override
public int hashCode()
{
return Objects.hash(getDelegate());
}
@Override
public String toString()
{
return "SuppressedAggregatorFactory{" +
"delegate=" + delegate +
'}';
}
public AggregatorFactory getDelegate()
{
return delegate;
}
public static class SuppressedAggregator implements Aggregator
{
private final Aggregator delegate;
public SuppressedAggregator(
Aggregator delegate
)
{
this.delegate = delegate;
}
@Override
public void aggregate()
{
//no-op
}
@Nullable
@Override
public Object get()
{
return delegate.get();
}
@Override
public float getFloat()
{
return delegate.getFloat();
}
@Override
public long getLong()
{
return delegate.getLong();
}
@Override
public double getDouble()
{
return delegate.getDouble();
}
@Override
public boolean isNull()
{
return delegate.isNull();
}
@Override
public void close()
{
delegate.close();
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SuppressedAggregator that = (SuppressedAggregator) o;
return Objects.equals(getDelegate(), that.getDelegate());
}
@Override
public int hashCode()
{
return Objects.hash(getDelegate());
}
@Override
public String toString()
{
return "SuppressedAggregator{" +
"delegate=" + delegate +
'}';
}
public Aggregator getDelegate()
{
return delegate;
}
}
public static class SuppressedBufferAggregator implements BufferAggregator
{
private final BufferAggregator delegate;
public SuppressedBufferAggregator(
BufferAggregator delegate
)
{
this.delegate = delegate;
}
@Override
public void init(ByteBuffer buf, int position)
{
delegate.init(buf, position);
}
@Override
public void aggregate(ByteBuffer buf, int position)
{
// no-op
}
@Override
public Object get(ByteBuffer buf, int position)
{
return delegate.get(buf, position);
}
@Override
public float getFloat(ByteBuffer buf, int position)
{
return delegate.getFloat(buf, position);
}
@Override
public long getLong(ByteBuffer buf, int position)
{
return delegate.getLong(buf, position);
}
@Override
public double getDouble(ByteBuffer buf, int position)
{
return delegate.getDouble(buf, position);
}
@Override
public void close()
{
delegate.close();
}
@Override
public void inspectRuntimeShape(RuntimeShapeInspector inspector)
{
delegate.inspectRuntimeShape(inspector);
}
@Override
public void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer)
{
delegate.relocate(oldPosition, newPosition, oldBuffer, newBuffer);
}
@Override
public boolean isNull(ByteBuffer buf, int position)
{
return delegate.isNull(buf, position);
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SuppressedBufferAggregator that = (SuppressedBufferAggregator) o;
return Objects.equals(getDelegate(), that.getDelegate());
}
@Override
public int hashCode()
{
return Objects.hash(getDelegate());
}
@Override
public String toString()
{
return "SuppressedBufferAggregator{" +
"delegate=" + delegate +
'}';
}
public BufferAggregator getDelegate()
{
return delegate;
}
}
}
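The practical effect of the suppression: the delegate aggregator is still created and its results are still read, combined, and finalized normally, but no row is ever aggregated into it. A minimal sketch (columnSelectorFactory is assumed to be supplied by the query engine; a long sum left untouched reports its initial value of 0):

AggregatorFactory suppressed = new SuppressedAggregatorFactory(
    new LongSumAggregatorFactory("sumLongSequential", "sumLongSequential")
);
Aggregator agg = suppressed.factorize(columnSelectorFactory);
agg.aggregate();             // no-op: the delegate never sees the row
long result = agg.getLong(); // still the delegate's initial value, 0 for a long sum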

View File

@@ -28,6 +28,7 @@ import io.druid.java.util.common.granularity.Granularity;
import io.druid.query.BaseQuery;
import io.druid.query.DataSource;
import io.druid.query.Druids;
import io.druid.query.PerSegmentQueryOptimizationContext;
import io.druid.query.Queries;
import io.druid.query.Query;
import io.druid.query.Result;
@@ -37,6 +38,7 @@ import io.druid.query.filter.DimFilter;
import io.druid.query.spec.QuerySegmentSpec;
import io.druid.segment.VirtualColumns;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -153,6 +155,12 @@ public class TimeseriesQuery extends BaseQuery<Result<TimeseriesResultValue>>
return Druids.TimeseriesQueryBuilder.copy(this).dataSource(dataSource).build();
}
@Override
public Query<Result<TimeseriesResultValue>> optimizeForSegment(PerSegmentQueryOptimizationContext optimizationContext)
{
return Druids.TimeseriesQueryBuilder.copy(this).aggregators(optimizeAggs(optimizationContext)).build();
}
@Override
public TimeseriesQuery withOverriddenContext(Map<String, Object> contextOverrides)
{
@@ -170,6 +178,15 @@ public class TimeseriesQuery extends BaseQuery<Result<TimeseriesResultValue>>
return Druids.TimeseriesQueryBuilder.copy(this).postAggregators(postAggregatorSpecs).build();
}
private List<AggregatorFactory> optimizeAggs(PerSegmentQueryOptimizationContext optimizationContext)
{
List<AggregatorFactory> optimizedAggs = new ArrayList<>();
for (AggregatorFactory aggregatorFactory : aggregatorSpecs) {
optimizedAggs.add(aggregatorFactory.optimizeForSegment(optimizationContext));
}
return optimizedAggs;
}
@Override
public String toString()
{

View File

@@ -26,6 +26,7 @@ import com.google.common.collect.ImmutableList;
import io.druid.java.util.common.granularity.Granularity;
import io.druid.query.BaseQuery;
import io.druid.query.DataSource;
import io.druid.query.PerSegmentQueryOptimizationContext;
import io.druid.query.Queries;
import io.druid.query.Query;
import io.druid.query.Result;
@@ -36,6 +37,7 @@ import io.druid.query.filter.DimFilter;
import io.druid.query.spec.QuerySegmentSpec;
import io.druid.segment.VirtualColumns;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -183,6 +185,12 @@ public class TopNQuery extends BaseQuery<Result<TopNResultValue>>
return new TopNQueryBuilder(this).dataSource(dataSource).build();
}
@Override
public Query<Result<TopNResultValue>> optimizeForSegment(PerSegmentQueryOptimizationContext optimizationContext)
{
return new TopNQueryBuilder(this).aggregators(optimizeAggs(optimizationContext)).build();
}
public TopNQuery withThreshold(int threshold)
{
return new TopNQueryBuilder(this).threshold(threshold).build();
@@ -252,4 +260,13 @@ public class TopNQuery extends BaseQuery<Result<TopNResultValue>>
postAggregatorSpecs
);
}
private List<AggregatorFactory> optimizeAggs(PerSegmentQueryOptimizationContext optimizationContext)
{
List<AggregatorFactory> optimizedAggs = new ArrayList<>();
for (AggregatorFactory aggregatorFactory : aggregatorSpecs) {
optimizedAggs.add(aggregatorFactory.optimizeForSegment(optimizationContext));
}
return optimizedAggs;
}
}

View File

@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query.topn;
import io.druid.java.util.common.Intervals;
import io.druid.query.PerSegmentQueryOptimizationContext;
import io.druid.query.SegmentDescriptor;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.FilteredAggregatorFactory;
import io.druid.query.aggregation.LongSumAggregatorFactory;
import io.druid.query.aggregation.SuppressedAggregatorFactory;
import io.druid.query.filter.IntervalDimFilter;
import io.druid.segment.column.Column;
import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Test;
import java.util.Collections;
public class PerSegmentQueryOptimizeTest
{
@Test
public void testFilteredAggregatorOptimize()
{
LongSumAggregatorFactory longSumAggregatorFactory = new LongSumAggregatorFactory("test", "test");
FilteredAggregatorFactory aggregatorFactory = new FilteredAggregatorFactory(
longSumAggregatorFactory,
new IntervalDimFilter(
Column.TIME_COLUMN_NAME,
Collections.singletonList(Intervals.utc(1000, 2000)),
null
)
);
Interval exclude = Intervals.utc(2000, 3000);
Interval include = Intervals.utc(1500, 1600);
Interval partial = Intervals.utc(1500, 2500);
AggregatorFactory excludedAgg = aggregatorFactory.optimizeForSegment(getOptimizationContext(exclude));
AggregatorFactory expectedSuppressedAgg = new SuppressedAggregatorFactory(longSumAggregatorFactory);
Assert.assertEquals(expectedSuppressedAgg, excludedAgg);
AggregatorFactory includedAgg = aggregatorFactory.optimizeForSegment(getOptimizationContext(include));
Assert.assertEquals(longSumAggregatorFactory, includedAgg);
AggregatorFactory partialAgg = aggregatorFactory.optimizeForSegment(getOptimizationContext(partial));
AggregatorFactory expectedPartialFilteredAgg = new FilteredAggregatorFactory(
longSumAggregatorFactory,
new IntervalDimFilter(
Column.TIME_COLUMN_NAME,
Collections.singletonList(Intervals.utc(1500, 2000)),
null
)
);
Assert.assertEquals(expectedPartialFilteredAgg, partialAgg);
}
@Test
public void testFilteredAggregatorDontOptimizeOnNonTimeColumn()
{
// Filter is not on __time, so no optimizations should be made.
LongSumAggregatorFactory longSumAggregatorFactory = new LongSumAggregatorFactory("test", "test");
FilteredAggregatorFactory aggregatorFactory = new FilteredAggregatorFactory(
longSumAggregatorFactory,
new IntervalDimFilter(
"not_time",
Collections.singletonList(Intervals.utc(1000, 2000)),
null
)
);
Interval exclude = Intervals.utc(2000, 3000);
Interval include = Intervals.utc(1500, 1600);
Interval partial = Intervals.utc(1500, 2500);
AggregatorFactory excludedAgg = aggregatorFactory.optimizeForSegment(getOptimizationContext(exclude));
Assert.assertEquals(aggregatorFactory, excludedAgg);
AggregatorFactory includedAgg = aggregatorFactory.optimizeForSegment(getOptimizationContext(include));
Assert.assertEquals(aggregatorFactory, includedAgg);
AggregatorFactory partialAgg = aggregatorFactory.optimizeForSegment(getOptimizationContext(partial));
Assert.assertEquals(aggregatorFactory, partialAgg);
}
private PerSegmentQueryOptimizationContext getOptimizationContext(Interval segmentInterval)
{
return new PerSegmentQueryOptimizationContext(
new SegmentDescriptor(segmentInterval, "0", 0)
);
}
}

View File

@@ -39,6 +39,8 @@ import io.druid.query.DataSource;
import io.druid.query.FinalizeResultsQueryRunner;
import io.druid.query.MetricsEmittingQueryRunner;
import io.druid.query.NoopQueryRunner;
import io.druid.query.PerSegmentOptimizingQueryRunner;
import io.druid.query.PerSegmentQueryOptimizationContext;
import io.druid.query.Query;
import io.druid.query.QueryMetrics;
import io.druid.query.QueryRunner;
@@ -280,38 +282,54 @@ public class ServerManager implements QuerySegmentWalker
{
SpecificSegmentSpec segmentSpec = new SpecificSegmentSpec(segmentDescriptor);
String segmentId = adapter.getIdentifier();
MetricsEmittingQueryRunner metricsEmittingQueryRunnerInner = new MetricsEmittingQueryRunner<>(
emitter,
toolChest,
new ReferenceCountingSegmentQueryRunner<>(factory, adapter, segmentDescriptor),
QueryMetrics::reportSegmentTime,
queryMetrics -> queryMetrics.segment(segmentId)
);
CachingQueryRunner cachingQueryRunner = new CachingQueryRunner<>(
segmentId,
segmentDescriptor,
objectMapper,
cache,
toolChest,
metricsEmittingQueryRunnerInner,
cachingExec,
cacheConfig
);
BySegmentQueryRunner bySegmentQueryRunner = new BySegmentQueryRunner<>(
segmentId,
adapter.getDataInterval().getStart(),
cachingQueryRunner
);
MetricsEmittingQueryRunner metricsEmittingQueryRunnerOuter = new MetricsEmittingQueryRunner<>(
emitter,
toolChest,
bySegmentQueryRunner,
QueryMetrics::reportSegmentAndCacheTime,
queryMetrics -> queryMetrics.segment(segmentId)
).withWaitMeasuredFromNow();
SpecificSegmentQueryRunner specificSegmentQueryRunner = new SpecificSegmentQueryRunner<>(
metricsEmittingQueryRunnerOuter,
segmentSpec
);
PerSegmentOptimizingQueryRunner perSegmentOptimizingQueryRunner = new PerSegmentOptimizingQueryRunner<>(
specificSegmentQueryRunner,
new PerSegmentQueryOptimizationContext(segmentDescriptor)
);
return new SetAndVerifyContextQueryRunner<>(
serverConfig,
CPUTimeMetricQueryRunner.safeBuild(
new SpecificSegmentQueryRunner<>(
new MetricsEmittingQueryRunner<>(
emitter,
toolChest,
new BySegmentQueryRunner<>(
segmentId,
adapter.getDataInterval().getStart(),
new CachingQueryRunner<>(
segmentId,
segmentDescriptor,
objectMapper,
cache,
toolChest,
new MetricsEmittingQueryRunner<>(
emitter,
toolChest,
new ReferenceCountingSegmentQueryRunner<>(factory, adapter, segmentDescriptor),
QueryMetrics::reportSegmentTime,
queryMetrics -> queryMetrics.segment(segmentId)
),
cachingExec,
cacheConfig
)
),
QueryMetrics::reportSegmentAndCacheTime,
queryMetrics -> queryMetrics.segment(segmentId)
).withWaitMeasuredFromNow(),
segmentSpec
),
perSegmentOptimizingQueryRunner,
toolChest,
emitter,
cpuTimeAccumulator,