fix dependent PostAggregators

This commit is contained in:
nishantmonu51 2014-04-15 13:37:31 +05:30
parent ecfa6bd1b1
commit d66ae5ac10
15 changed files with 282 additions and 48 deletions

View File

@ -0,0 +1,53 @@
/*
* Druid - a distributed column store.
* Copyright (C) 2012, 2013, 2014 Metamarkets Group Inc.
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package io.druid.query.aggregation;
import com.google.common.collect.Lists;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
public class AggregatorUtil
{
/** returns the list of dependent postAggregators that should be calculated in order to calculate given postAgg
* @param postAggregatorList List of postAggregator, there is a restriction that the postAgg should be in order that it can
* @param postAggName name of the postAgg on which dependency is to be calculated
*/
public static List<PostAggregator> pruneDependentPostAgg(List<PostAggregator> postAggregatorList, String postAggName)
{
LinkedList<PostAggregator> rv = Lists.newLinkedList();
Set<String> deps = new HashSet<>();
deps.add(postAggName);
// Iterate backwards to calculate deps
for (int i = postAggregatorList.size() - 1; i >= 0; i--) {
PostAggregator agg = postAggregatorList.get(i);
if (deps.contains(agg.getName())) {
rv.addFirst(agg); // add to the beginning of List
deps.remove(agg.getName());
deps.addAll(agg.getDependentFields());
}
}
return rv;
}
}

View File

@ -56,7 +56,6 @@ public class AggregateTopNMetricFirstAlgorithm implements TopNAlgorithm<int[], T
this.bufferPool = bufferPool;
}
@Override
public TopNParams makeInitParams(
DimensionSelector dimSelector, Cursor cursor
@ -69,7 +68,12 @@ public class AggregateTopNMetricFirstAlgorithm implements TopNAlgorithm<int[], T
public TopNResultBuilder makeResultBuilder(TopNParams params)
{
return query.getTopNMetricSpec().getResultBuilder(
params.getCursor().getTime(), query.getDimensionSpec(), query.getThreshold(), comparator
params.getCursor().getTime(),
query.getDimensionSpec(),
query.getThreshold(),
comparator,
query.getAggregatorSpecs(),
query.getPostAggregatorSpecs()
);
}

View File

@ -60,7 +60,12 @@ public class DimExtractionTopNAlgorithm extends BaseTopNAlgorithm<Aggregator[][]
public TopNResultBuilder makeResultBuilder(TopNParams params)
{
return query.getTopNMetricSpec().getResultBuilder(
params.getCursor().getTime(), query.getDimensionSpec(), query.getThreshold(), comparator
params.getCursor().getTime(),
query.getDimensionSpec(),
query.getThreshold(),
comparator,
query.getAggregatorSpecs(),
query.getPostAggregatorSpecs()
);
}
@ -144,9 +149,7 @@ public class DimExtractionTopNAlgorithm extends BaseTopNAlgorithm<Aggregator[][]
resultBuilder.addEntry(
entry.getKey(),
entry.getKey(),
vals,
query.getAggregatorSpecs(),
query.getPostAggregatorSpecs()
vals
);
}
}

View File

@ -75,10 +75,12 @@ public class InvertedTopNMetricSpec implements TopNMetricSpec
DateTime timestamp,
DimensionSpec dimSpec,
int threshold,
Comparator comparator
Comparator comparator,
List<AggregatorFactory> aggFactories,
List<PostAggregator> postAggs
)
{
return delegate.getResultBuilder(timestamp, dimSpec, threshold, comparator);
return delegate.getResultBuilder(timestamp, dimSpec, threshold, comparator, aggFactories, postAggs);
}
@Override

View File

@ -80,10 +80,12 @@ public class LexicographicTopNMetricSpec implements TopNMetricSpec
DateTime timestamp,
DimensionSpec dimSpec,
int threshold,
Comparator comparator
Comparator comparator,
List<AggregatorFactory> aggFactories,
List<PostAggregator> postAggs
)
{
return new TopNLexicographicResultBuilder(timestamp, dimSpec, threshold, previousStop, comparator);
return new TopNLexicographicResultBuilder(timestamp, dimSpec, threshold, previousStop, comparator, aggFactories);
}
@Override

View File

@ -121,10 +121,12 @@ public class NumericTopNMetricSpec implements TopNMetricSpec
DateTime timestamp,
DimensionSpec dimSpec,
int threshold,
Comparator comparator
Comparator comparator,
List<AggregatorFactory> aggFactories,
List<PostAggregator> postAggs
)
{
return new TopNNumericResultBuilder(timestamp, dimSpec, metric, threshold, comparator);
return new TopNNumericResultBuilder(timestamp, dimSpec, metric, threshold, comparator, aggFactories, postAggs);
}
@Override

View File

@ -35,7 +35,8 @@ import java.util.Comparator;
/**
*/
public class PooledTopNAlgorithm extends BaseTopNAlgorithm<int[], BufferAggregator[], PooledTopNAlgorithm.PooledTopNParams>
public class PooledTopNAlgorithm
extends BaseTopNAlgorithm<int[], BufferAggregator[], PooledTopNAlgorithm.PooledTopNParams>
{
private final Capabilities capabilities;
private final TopNQuery query;
@ -117,7 +118,12 @@ public class PooledTopNAlgorithm extends BaseTopNAlgorithm<int[], BufferAggregat
public TopNResultBuilder makeResultBuilder(PooledTopNParams params)
{
return query.getTopNMetricSpec().getResultBuilder(
params.getCursor().getTime(), query.getDimensionSpec(), query.getThreshold(), comparator
params.getCursor().getTime(),
query.getDimensionSpec(),
query.getThreshold(),
comparator,
query.getAggregatorSpecs(),
query.getPostAggregatorSpecs()
);
}
@ -217,9 +223,7 @@ public class PooledTopNAlgorithm extends BaseTopNAlgorithm<int[], BufferAggregat
resultBuilder.addEntry(
dimSelector.lookupName(i),
i,
vals,
query.getAggregatorSpecs(),
query.getPostAggregatorSpecs()
vals
);
}
}
@ -228,7 +232,7 @@ public class PooledTopNAlgorithm extends BaseTopNAlgorithm<int[], BufferAggregat
@Override
protected void closeAggregators(BufferAggregator[] bufferAggregators)
{
for(BufferAggregator agg : bufferAggregators) {
for (BufferAggregator agg : bufferAggregators) {
agg.close();
}
}
@ -246,11 +250,6 @@ public class PooledTopNAlgorithm extends BaseTopNAlgorithm<int[], BufferAggregat
public static class PooledTopNParams extends TopNParams
{
public static Builder builder()
{
return new Builder();
}
private final ResourceHolder<ByteBuffer> resultsBufHolder;
private final ByteBuffer resultsBuf;
private final int[] aggregatorSizes;
@ -278,6 +277,11 @@ public class PooledTopNAlgorithm extends BaseTopNAlgorithm<int[], BufferAggregat
this.arrayProvider = arrayProvider;
}
public static Builder builder()
{
return new Builder();
}
public ResourceHolder<ByteBuffer> getResultsBufHolder()
{
return resultsBufHolder;

View File

@ -24,6 +24,7 @@ import io.druid.granularity.AllGranularity;
import io.druid.granularity.QueryGranularity;
import io.druid.query.Result;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.AggregatorUtil;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.dimension.DimensionSpec;
import org.joda.time.DateTime;
@ -63,9 +64,13 @@ public class TopNBinaryFn implements BinaryFn<Result<TopNResultValue>, Result<To
this.topNMetricSpec = topNMetricSpec;
this.threshold = threshold;
this.aggregations = aggregatorSpecs;
this.postAggregations = postAggregatorSpecs;
this.dimensionSpec = dimSpec;
this.postAggregations = AggregatorUtil.pruneDependentPostAgg(
postAggregatorSpecs,
this.topNMetricSpec.getMetricName(this.dimensionSpec)
);
this.comparator = topNMetricSpec.getComparator(aggregatorSpecs, postAggregatorSpecs);
}
@ -94,7 +99,7 @@ public class TopNBinaryFn implements BinaryFn<Result<TopNResultValue>, Result<To
DimensionAndMetricValueExtractor arg1Val = retVals.get(dimensionValue);
if (arg1Val != null) {
// size of map = aggregattor + topNDim + postAgg (If sorting is done on post agg field)
// size of map = aggregator + topNDim + postAgg (If sorting is done on post agg field)
Map<String, Object> retVal = new LinkedHashMap<String, Object>(aggregations.size() + 2);
retVal.put(dimension, dimensionValue);
@ -103,10 +108,8 @@ public class TopNBinaryFn implements BinaryFn<Result<TopNResultValue>, Result<To
retVal.put(metricName, factory.combine(arg1Val.getMetric(metricName), arg2Val.getMetric(metricName)));
}
for (PostAggregator postAgg : postAggregations) {
if (postAgg.getName().equalsIgnoreCase(topNMetricName)) {
retVal.put(postAgg.getName(), postAgg.compute(retVal));
}
}
retVals.put(dimensionValue, new DimensionAndMetricValueExtractor(retVal));
} else {
@ -121,7 +124,14 @@ public class TopNBinaryFn implements BinaryFn<Result<TopNResultValue>, Result<To
timestamp = gran.toDateTime(gran.truncate(arg1.getTimestamp().getMillis()));
}
TopNResultBuilder bob = topNMetricSpec.getResultBuilder(timestamp, dimSpec, threshold, comparator);
TopNResultBuilder bob = topNMetricSpec.getResultBuilder(
timestamp,
dimSpec,
threshold,
comparator,
aggregations,
postAggregations
);
for (DimensionAndMetricValueExtractor extractor : retVals.values()) {
bob.addEntry(extractor);
}

View File

@ -40,7 +40,7 @@ public class TopNLexicographicResultBuilder implements TopNResultBuilder
private final DateTime timestamp;
private final DimensionSpec dimSpec;
private final String previousStop;
private final List<AggregatorFactory> aggFactories;
private MinMaxPriorityQueue<DimValHolder> pQueue = null;
public TopNLexicographicResultBuilder(
@ -48,12 +48,14 @@ public class TopNLexicographicResultBuilder implements TopNResultBuilder
DimensionSpec dimSpec,
int threshold,
String previousStop,
final Comparator comparator
final Comparator comparator,
List<AggregatorFactory> aggFactories
)
{
this.timestamp = timestamp;
this.dimSpec = dimSpec;
this.previousStop = previousStop;
this.aggFactories = aggFactories;
instantiatePQueue(threshold, comparator);
}
@ -62,9 +64,7 @@ public class TopNLexicographicResultBuilder implements TopNResultBuilder
public TopNResultBuilder addEntry(
String dimName,
Object dimValIndex,
Object[] metricVals,
List<AggregatorFactory> aggFactories,
List<PostAggregator> postAggs
Object[] metricVals
)
{
Map<String, Object> metricValues = Maps.newLinkedHashMap();

View File

@ -47,7 +47,9 @@ public interface TopNMetricSpec
DateTime timestamp,
DimensionSpec dimSpec,
int threshold,
Comparator comparator
Comparator comparator,
List<AggregatorFactory> aggFactories,
List<PostAggregator> postAggs
);
public byte[] getCacheKey();

View File

@ -23,6 +23,7 @@ import com.google.common.collect.Maps;
import com.google.common.collect.MinMaxPriorityQueue;
import io.druid.query.Result;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.AggregatorUtil;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.dimension.DimensionSpec;
import org.joda.time.DateTime;
@ -40,6 +41,8 @@ public class TopNNumericResultBuilder implements TopNResultBuilder
private final DateTime timestamp;
private final DimensionSpec dimSpec;
private final String metricName;
private final List<AggregatorFactory> aggFactories;
private final List<PostAggregator> postAggs;
private MinMaxPriorityQueue<DimValHolder> pQueue = null;
public TopNNumericResultBuilder(
@ -47,12 +50,16 @@ public class TopNNumericResultBuilder implements TopNResultBuilder
DimensionSpec dimSpec,
String metricName,
int threshold,
final Comparator comparator
final Comparator comparator,
List<AggregatorFactory> aggFactories,
List<PostAggregator> postAggs
)
{
this.timestamp = timestamp;
this.dimSpec = dimSpec;
this.metricName = metricName;
this.aggFactories = aggFactories;
this.postAggs = AggregatorUtil.pruneDependentPostAgg(postAggs, this.metricName);
instantiatePQueue(threshold, comparator);
}
@ -61,9 +68,7 @@ public class TopNNumericResultBuilder implements TopNResultBuilder
public TopNResultBuilder addEntry(
String dimName,
Object dimValIndex,
Object[] metricVals,
List<AggregatorFactory> aggFactories,
List<PostAggregator> postAggs
Object[] metricVals
)
{
Map<String, Object> metricValues = Maps.newLinkedHashMap();
@ -76,10 +81,7 @@ public class TopNNumericResultBuilder implements TopNResultBuilder
}
for (PostAggregator postAgg : postAggs) {
if (postAgg.getName().equalsIgnoreCase(metricName)) {
metricValues.put(postAgg.getName(), postAgg.compute(metricValues));
break;
}
}
Object topNMetricVal = metricValues.get(metricName);

View File

@ -33,9 +33,7 @@ public interface TopNResultBuilder
public TopNResultBuilder addEntry(
String dimName,
Object dimValIndex,
Object[] metricVals,
List<AggregatorFactory> aggFactories,
List<PostAggregator> postAggs
Object[] metricVals
);
public TopNResultBuilder addEntry(

View File

@ -60,6 +60,7 @@ public class QueryRunnerTestHelper
public static final String indexMetric = "index";
public static final String uniqueMetric = "uniques";
public static final String addRowsIndexConstantMetric = "addRowsIndexConstant";
public static String dependentPostAggMetric = "dependentPostAgg";
public static final CountAggregatorFactory rowsCount = new CountAggregatorFactory("rows");
public static final LongSumAggregatorFactory indexLongSum = new LongSumAggregatorFactory("index", "index");
public static final DoubleSumAggregatorFactory indexDoubleSum = new DoubleSumAggregatorFactory("index", "index");
@ -72,8 +73,19 @@ public class QueryRunnerTestHelper
public static final FieldAccessPostAggregator indexPostAgg = new FieldAccessPostAggregator("index", "index");
public static final ArithmeticPostAggregator addRowsIndexConstant =
new ArithmeticPostAggregator(
"addRowsIndexConstant", "+", Lists.newArrayList(constant, rowsPostAgg, indexPostAgg)
addRowsIndexConstantMetric, "+", Lists.newArrayList(constant, rowsPostAgg, indexPostAgg)
);
// dependent on AddRowsIndexContact postAgg
public static final ArithmeticPostAggregator dependentPostAgg = new ArithmeticPostAggregator(
dependentPostAggMetric,
"+",
Lists.newArrayList(
constant,
new FieldAccessPostAggregator(addRowsIndexConstantMetric, addRowsIndexConstantMetric)
)
);
public static final List<AggregatorFactory> commonAggregators = Arrays.asList(
rowsCount,
indexDoubleSum,

View File

@ -0,0 +1,69 @@
/*
* Druid - a distributed column store.
* Copyright (C) 2012, 2013, 2014 Metamarkets Group Inc.
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package io.druid.query.aggregation;
import com.google.common.collect.Lists;
import io.druid.query.aggregation.post.ArithmeticPostAggregator;
import io.druid.query.aggregation.post.ConstantPostAggregator;
import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.junit.Assert;
import org.junit.Test;
import java.util.List;
public class AggregatorUtilTest
{
@Test
public void testPruneDependentPostAgg()
{
PostAggregator agg1 = new ArithmeticPostAggregator(
"abc", "+", Lists.<PostAggregator>newArrayList(
new ConstantPostAggregator("1", 1L, 1L), new ConstantPostAggregator("2", 2L, 2L)
)
);
PostAggregator dependency1 = new ArithmeticPostAggregator(
"dep1", "+", Lists.<PostAggregator>newArrayList(
new ConstantPostAggregator("1", 1L, 1L), new ConstantPostAggregator("4", 4L, 4L)
)
);
PostAggregator agg2 = new FieldAccessPostAggregator("def", "def");
PostAggregator dependency2 = new FieldAccessPostAggregator("dep2", "dep2");
PostAggregator aggregator = new ArithmeticPostAggregator(
"finalAgg",
"+",
Lists.<PostAggregator>newArrayList(
new FieldAccessPostAggregator("dep1", "dep1"),
new FieldAccessPostAggregator("dep2", "dep2")
)
);
List<PostAggregator> prunedAgg = AggregatorUtil.pruneDependentPostAgg(
Lists.newArrayList(
agg1,
dependency1,
agg2,
dependency2,
aggregator
), aggregator.getName()
);
Assert.assertEquals(Lists.newArrayList(dependency1, dependency2, aggregator), prunedAgg);
}
}

View File

@ -1176,4 +1176,75 @@ public class TopNQueryRunnerTest
TestHelper.assertExpectedResults(expectedResults, runner.run(query));
}
@Test
public void testTopNDependentPostAgg() {
TopNQuery query = new TopNQueryBuilder()
.dataSource(QueryRunnerTestHelper.dataSource)
.granularity(QueryRunnerTestHelper.allGran)
.dimension(providerDimension)
.metric(QueryRunnerTestHelper.dependentPostAggMetric)
.threshold(4)
.intervals(QueryRunnerTestHelper.fullOnInterval)
.aggregators(
Lists.<AggregatorFactory>newArrayList(
Iterables.concat(
QueryRunnerTestHelper.commonAggregators,
Lists.newArrayList(
new MaxAggregatorFactory("maxIndex", "index"),
new MinAggregatorFactory("minIndex", "index")
)
)
)
)
.postAggregators(
Arrays.<PostAggregator>asList(
QueryRunnerTestHelper.addRowsIndexConstant,
QueryRunnerTestHelper.dependentPostAgg
)
)
.build();
List<Result<TopNResultValue>> expectedResults = Arrays.asList(
new Result<TopNResultValue>(
new DateTime("2011-01-12T00:00:00.000Z"),
new TopNResultValue(
Arrays.<Map<String, Object>>asList(
ImmutableMap.<String, Object>builder()
.put(providerDimension, "total_market")
.put("rows", 186L)
.put("index", 215679.82879638672D)
.put("addRowsIndexConstant", 215866.82879638672D)
.put(QueryRunnerTestHelper.dependentPostAggMetric, 215867.82879638672D)
.put("uniques", QueryRunnerTestHelper.UNIQUES_2)
.put("maxIndex", 1743.9217529296875D)
.put("minIndex", 792.3260498046875D)
.build(),
ImmutableMap.<String, Object>builder()
.put(providerDimension, "upfront")
.put("rows", 186L)
.put("index", 192046.1060180664D)
.put("addRowsIndexConstant", 192233.1060180664D)
.put(QueryRunnerTestHelper.dependentPostAggMetric, 192234.1060180664D)
.put("uniques", QueryRunnerTestHelper.UNIQUES_2)
.put("maxIndex", 1870.06103515625D)
.put("minIndex", 545.9906005859375D)
.build(),
ImmutableMap.<String, Object>builder()
.put(providerDimension, "spot")
.put("rows", 837L)
.put("index", 95606.57232284546D)
.put("addRowsIndexConstant", 96444.57232284546D)
.put(QueryRunnerTestHelper.dependentPostAggMetric, 96445.57232284546D)
.put("uniques", QueryRunnerTestHelper.UNIQUES_9)
.put("maxIndex", 277.2735290527344D)
.put("minIndex", 59.02102279663086D)
.build()
)
)
)
);
TestHelper.assertExpectedResults(expectedResults, runner.run(query));
}
}