Merge pull request #483 from metamx/fix-aggregatefirst-topn-algo

Fix NPE in AggregateTopNMetricFirstAlgorithm
This commit is contained in:
fjy 2014-04-16 12:20:39 -06:00
commit 2183c7ab08
10 changed files with 109 additions and 86 deletions

View File

@ -20,6 +20,8 @@
package io.druid.query.aggregation; package io.druid.query.aggregation;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.metamx.common.ISE;
import com.metamx.common.Pair;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
@ -53,4 +55,31 @@ public class AggregatorUtil
return rv; return rv;
} }
/**
 * Reduces the given aggregators and post-aggregators to those needed to compute a single metric.
 *
 * <p>The post-aggregator list is first narrowed by {@link #pruneDependentPostAgg} for {@code metric}.
 * The aggregators kept are those whose name is either {@code metric} itself or one of the dependent
 * fields reported by the retained post-aggregators. Aggregators are returned in the order in which
 * they appear in {@code aggList}.
 *
 * @param aggList     all aggregator factories of the query
 * @param postAggList all post-aggregators of the query
 * @param metric      name of the metric (aggregator or post-aggregator) to compute
 *
 * @return a pair whose {@code lhs} is the condensed aggregator list and whose {@code rhs} is the
 * condensed post-aggregator list; both may be empty when nothing matches {@code metric}
 */
public static Pair<List<AggregatorFactory>, List<PostAggregator>> condensedAggregators(
    List<AggregatorFactory> aggList,
    List<PostAggregator> postAggList,
    String metric
)
{
  List<PostAggregator> condensedPostAggs = AggregatorUtil.pruneDependentPostAgg(
      postAggList,
      metric
  );

  // calculate dependent aggregators for these postAgg; the metric itself is included so that
  // a plain aggregator with that name is retained as well
  Set<String> dependencySet = new HashSet<>();
  dependencySet.add(metric);
  for (PostAggregator postAggregator : condensedPostAggs) {
    dependencySet.addAll(postAggregator.getDependentFields());
  }

  List<AggregatorFactory> condensedAggs = Lists.newArrayList();
  for (AggregatorFactory aggregatorSpec : aggList) {
    if (dependencySet.contains(aggregatorSpec.getName())) {
      condensedAggs.add(aggregatorSpec);
    }
  }

  // diamond instead of the raw Pair type, so the result matches the declared return type
  // without an unchecked conversion
  return new Pair<>(condensedAggs, condensedPostAggs);
}
} }

View File

@ -19,10 +19,11 @@
package io.druid.query.topn; package io.druid.query.topn;
import com.google.common.collect.Lists;
import com.metamx.common.ISE; import com.metamx.common.ISE;
import com.metamx.common.Pair;
import io.druid.collections.StupidPool; import io.druid.collections.StupidPool;
import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.AggregatorUtil;
import io.druid.query.aggregation.PostAggregator; import io.druid.query.aggregation.PostAggregator;
import io.druid.segment.Capabilities; import io.druid.segment.Capabilities;
import io.druid.segment.Cursor; import io.druid.segment.Cursor;
@ -64,70 +65,27 @@ public class AggregateTopNMetricFirstAlgorithm implements TopNAlgorithm<int[], T
return new TopNParams(dimSelector, cursor, dimSelector.getValueCardinality(), Integer.MAX_VALUE); return new TopNParams(dimSelector, cursor, dimSelector.getValueCardinality(), Integer.MAX_VALUE);
} }
/**
 * Creates the result builder for this algorithm's query by delegating to the query's
 * TopN metric spec.
 *
 * @param params parameters whose cursor supplies the timestamp handed to the builder
 *
 * @return the builder produced by the metric spec for the query's dimension spec, threshold,
 * aggregators and post-aggregators, ordered by this algorithm's {@code comparator}
 */
@Override
public TopNResultBuilder makeResultBuilder(TopNParams params)
{
  final TopNResultBuilder builder = query.getTopNMetricSpec().getResultBuilder(
      params.getCursor().getTime(),
      query.getDimensionSpec(),
      query.getThreshold(),
      comparator,
      query.getAggregatorSpecs(),
      query.getPostAggregatorSpecs()
  );
  return builder;
}
@Override @Override
public void run( public void run(
TopNParams params, TopNResultBuilder resultBuilder, int[] ints TopNParams params, TopNResultBuilder resultBuilder, int[] ints
) )
{ {
final TopNResultBuilder singleMetricResultBuilder = makeResultBuilder(params); final String metric = query.getTopNMetricSpec().getMetricName(query.getDimensionSpec());
final String metric; Pair<List<AggregatorFactory>, List<PostAggregator>> condensedAggPostAggPair = AggregatorUtil.condensedAggregators(
// ugly query.getAggregatorSpecs(),
TopNMetricSpec spec = query.getTopNMetricSpec(); query.getPostAggregatorSpecs(),
if (spec instanceof InvertedTopNMetricSpec metric
&& ((InvertedTopNMetricSpec) spec).getDelegate() instanceof NumericTopNMetricSpec) { );
metric = ((NumericTopNMetricSpec) ((InvertedTopNMetricSpec) spec).getDelegate()).getMetric();
} else if (spec instanceof NumericTopNMetricSpec) {
metric = ((NumericTopNMetricSpec) query.getTopNMetricSpec()).getMetric();
} else {
throw new ISE("WTF?! We are in AggregateTopNMetricFirstAlgorithm with a [%s] spec", spec.getClass().getName());
}
// Find either the aggregator or post aggregator to do the topN over if (condensedAggPostAggPair.lhs.isEmpty() && condensedAggPostAggPair.rhs.isEmpty()) {
List<AggregatorFactory> condensedAggs = Lists.newArrayList();
for (AggregatorFactory aggregatorSpec : query.getAggregatorSpecs()) {
if (aggregatorSpec.getName().equalsIgnoreCase(metric)) {
condensedAggs.add(aggregatorSpec);
break;
}
}
List<PostAggregator> condensedPostAggs = Lists.newArrayList();
if (condensedAggs.isEmpty()) {
for (PostAggregator postAggregator : query.getPostAggregatorSpecs()) {
if (postAggregator.getName().equalsIgnoreCase(metric)) {
condensedPostAggs.add(postAggregator);
// Add all dependent metrics
for (AggregatorFactory aggregatorSpec : query.getAggregatorSpecs()) {
if (postAggregator.getDependentFields().contains(aggregatorSpec.getName())) {
condensedAggs.add(aggregatorSpec);
}
}
break;
}
}
}
if (condensedAggs.isEmpty() && condensedPostAggs.isEmpty()) {
throw new ISE("WTF! Can't find the metric to do topN over?"); throw new ISE("WTF! Can't find the metric to do topN over?");
} }
// Run topN for only a single metric // Run topN for only a single metric
TopNQuery singleMetricQuery = new TopNQueryBuilder().copy(query) TopNQuery singleMetricQuery = new TopNQueryBuilder().copy(query)
.aggregators(condensedAggs) .aggregators(condensedAggPostAggPair.lhs)
.postAggregators(condensedPostAggs) .postAggregators(condensedAggPostAggPair.rhs)
.build(); .build();
final TopNResultBuilder singleMetricResultBuilder = BaseTopNAlgorithm.makeResultBuilder(params, singleMetricQuery);
PooledTopNAlgorithm singleMetricAlgo = new PooledTopNAlgorithm(capabilities, singleMetricQuery, bufferPool); PooledTopNAlgorithm singleMetricAlgo = new PooledTopNAlgorithm(capabilities, singleMetricQuery, bufferPool);
PooledTopNAlgorithm.PooledTopNParams singleMetricParam = null; PooledTopNAlgorithm.PooledTopNParams singleMetricParam = null;

View File

@ -28,6 +28,7 @@ import io.druid.segment.Cursor;
import io.druid.segment.DimensionSelector; import io.druid.segment.DimensionSelector;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator;
import java.util.List; import java.util.List;
/** /**
@ -230,4 +231,18 @@ public abstract class BaseTopNAlgorithm<DimValSelector, DimValAggregateStore, Pa
return Pair.of(startIndex, endIndex); return Pair.of(startIndex, endIndex);
} }
} }
/**
 * Creates a result builder for the given query.
 *
 * <p>The comparator is obtained from the query's TopN metric spec using the query's own
 * aggregator and post-aggregator specs, and is then passed to the same metric spec's
 * {@code getResultBuilder} together with the cursor time, dimension spec and threshold.
 *
 * @param params parameters whose cursor supplies the timestamp handed to the builder
 * @param query  query that supplies the metric spec, dimension spec, threshold, aggregators
 *               and post-aggregators
 *
 * @return the result builder produced by the query's TopN metric spec
 */
public static TopNResultBuilder makeResultBuilder(TopNParams params, TopNQuery query)
{
  final Comparator metricComparator = query.getTopNMetricSpec().getComparator(
      query.getAggregatorSpecs(),
      query.getPostAggregatorSpecs()
  );

  final TopNResultBuilder builder = query.getTopNMetricSpec().getResultBuilder(
      params.getCursor().getTime(),
      query.getDimensionSpec(),
      query.getThreshold(),
      metricComparator,
      query.getAggregatorSpecs(),
      query.getPostAggregatorSpecs()
  );
  return builder;
}
} }

View File

@ -56,19 +56,6 @@ public class DimExtractionTopNAlgorithm extends BaseTopNAlgorithm<Aggregator[][]
return new TopNParams(dimSelector, cursor, dimSelector.getValueCardinality(), Integer.MAX_VALUE); return new TopNParams(dimSelector, cursor, dimSelector.getValueCardinality(), Integer.MAX_VALUE);
} }
/**
 * Creates the result builder for this algorithm's query by delegating to the query's
 * TopN metric spec.
 *
 * @param params parameters whose cursor supplies the timestamp handed to the builder
 *
 * @return the builder produced by the metric spec for the query's dimension spec, threshold,
 * aggregators and post-aggregators, ordered by this algorithm's {@code comparator}
 */
@Override
public TopNResultBuilder makeResultBuilder(TopNParams params)
{
  final TopNResultBuilder builder = query.getTopNMetricSpec().getResultBuilder(
      params.getCursor().getTime(),
      query.getDimensionSpec(),
      query.getThreshold(),
      comparator,
      query.getAggregatorSpecs(),
      query.getPostAggregatorSpecs()
  );
  return builder;
}
@Override @Override
protected Aggregator[][] makeDimValSelector(TopNParams params, int numProcessed, int numToProcess) protected Aggregator[][] makeDimValSelector(TopNParams params, int numProcessed, int numToProcess)
{ {

View File

@ -114,18 +114,7 @@ public class PooledTopNAlgorithm
.build(); .build();
} }
/**
 * Creates the result builder for this algorithm's query by delegating to the query's
 * TopN metric spec.
 *
 * @param params pooled parameters whose cursor supplies the timestamp handed to the builder
 *
 * @return the builder produced by the metric spec for the query's dimension spec, threshold,
 * aggregators and post-aggregators, ordered by this algorithm's {@code comparator}
 */
@Override
public TopNResultBuilder makeResultBuilder(PooledTopNParams params)
{
  final TopNResultBuilder builder = query.getTopNMetricSpec().getResultBuilder(
      params.getCursor().getTime(),
      query.getDimensionSpec(),
      query.getThreshold(),
      comparator,
      query.getAggregatorSpecs(),
      query.getPostAggregatorSpecs()
  );
  return builder;
}
@Override @Override
protected int[] makeDimValSelector(PooledTopNParams params, int numProcessed, int numToProcess) protected int[] makeDimValSelector(PooledTopNParams params, int numProcessed, int numToProcess)

View File

@ -33,8 +33,6 @@ public interface TopNAlgorithm<DimValSelector, Parameters extends TopNParams>
public TopNParams makeInitParams(DimensionSelector dimSelector, Cursor cursor); public TopNParams makeInitParams(DimensionSelector dimSelector, Cursor cursor);
public TopNResultBuilder makeResultBuilder(Parameters params);
public void run( public void run(
Parameters params, Parameters params,
TopNResultBuilder resultBuilder, TopNResultBuilder resultBuilder,

View File

@ -24,6 +24,8 @@ import io.druid.query.Result;
import io.druid.segment.Cursor; import io.druid.segment.Cursor;
import io.druid.segment.DimensionSelector; import io.druid.segment.DimensionSelector;
import java.util.Comparator;
public class TopNMapFn implements Function<Cursor, Result<TopNResultValue>> public class TopNMapFn implements Function<Cursor, Result<TopNResultValue>>
{ {
private final TopNQuery query; private final TopNQuery query;
@ -52,7 +54,7 @@ public class TopNMapFn implements Function<Cursor, Result<TopNResultValue>>
try { try {
params = topNAlgorithm.makeInitParams(dimSelector, cursor); params = topNAlgorithm.makeInitParams(dimSelector, cursor);
TopNResultBuilder resultBuilder = topNAlgorithm.makeResultBuilder(params); TopNResultBuilder resultBuilder = BaseTopNAlgorithm.makeResultBuilder(params, query);
topNAlgorithm.run(params, resultBuilder, null); topNAlgorithm.run(params, resultBuilder, null);

View File

@ -81,7 +81,8 @@ public class QueryRunnerTestHelper
"+", "+",
Lists.newArrayList( Lists.newArrayList(
constant, constant,
new FieldAccessPostAggregator(addRowsIndexConstantMetric, addRowsIndexConstantMetric) new FieldAccessPostAggregator(addRowsIndexConstantMetric, addRowsIndexConstantMetric),
new FieldAccessPostAggregator("rows", "rows")
) )
); );

View File

@ -19,15 +19,22 @@
package io.druid.query.aggregation; package io.druid.query.aggregation;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.metamx.common.Pair;
import io.druid.query.QueryRunnerTestHelper;
import io.druid.query.aggregation.post.ArithmeticPostAggregator; import io.druid.query.aggregation.post.ArithmeticPostAggregator;
import io.druid.query.aggregation.post.ConstantPostAggregator; import io.druid.query.aggregation.post.ConstantPostAggregator;
import io.druid.query.aggregation.post.FieldAccessPostAggregator; import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import static io.druid.query.QueryRunnerTestHelper.dependentPostAggMetric;
public class AggregatorUtilTest public class AggregatorUtilTest
{ {
@ -101,4 +108,41 @@ public class AggregatorUtilTest
Assert.assertEquals(Lists.newArrayList(dependency1, aggregator), prunedAgg); Assert.assertEquals(Lists.newArrayList(dependency1, aggregator), prunedAgg);
} }
/**
 * Condensing the common aggregators (plus max/min over "index") for the dependent
 * post-aggregator metric must keep only the rows-count and index-sum aggregators, and both
 * post-aggregators in their original order.
 */
@Test
public void testCondenseAggregators()
{
  // input aggregators: the shared test aggregators followed by max/min over "index"
  List<AggregatorFactory> allAggs = Lists.<AggregatorFactory>newArrayList(
      Iterables.concat(
          QueryRunnerTestHelper.commonAggregators,
          Lists.newArrayList(
              new MaxAggregatorFactory("maxIndex", "index"),
              new MinAggregatorFactory("minIndex", "index")
          )
      )
  );

  List<PostAggregator> allPostAggs = Arrays.<PostAggregator>asList(
      QueryRunnerTestHelper.addRowsIndexConstant,
      QueryRunnerTestHelper.dependentPostAgg
  );

  Pair<List<AggregatorFactory>, List<PostAggregator>> condensed = AggregatorUtil.condensedAggregators(
      allAggs,
      allPostAggs,
      dependentPostAggMetric
  );

  // verify aggregators
  List<AggregatorFactory> expectedAggs = Lists.<AggregatorFactory>newArrayList(
      QueryRunnerTestHelper.rowsCount,
      QueryRunnerTestHelper.indexDoubleSum
  );
  Assert.assertEquals(expectedAggs, condensed.lhs);

  // verify post-aggregators
  List<PostAggregator> expectedPostAggs = Lists.<PostAggregator>newArrayList(
      QueryRunnerTestHelper.addRowsIndexConstant,
      QueryRunnerTestHelper.dependentPostAgg
  );
  Assert.assertEquals(expectedPostAggs, condensed.rhs);
}
} }

View File

@ -1215,7 +1215,7 @@ public class TopNQueryRunnerTest
.put("rows", 186L) .put("rows", 186L)
.put("index", 215679.82879638672D) .put("index", 215679.82879638672D)
.put("addRowsIndexConstant", 215866.82879638672D) .put("addRowsIndexConstant", 215866.82879638672D)
.put(QueryRunnerTestHelper.dependentPostAggMetric, 215867.82879638672D) .put(QueryRunnerTestHelper.dependentPostAggMetric, 216053.82879638672D)
.put("uniques", QueryRunnerTestHelper.UNIQUES_2) .put("uniques", QueryRunnerTestHelper.UNIQUES_2)
.put("maxIndex", 1743.9217529296875D) .put("maxIndex", 1743.9217529296875D)
.put("minIndex", 792.3260498046875D) .put("minIndex", 792.3260498046875D)
@ -1225,7 +1225,7 @@ public class TopNQueryRunnerTest
.put("rows", 186L) .put("rows", 186L)
.put("index", 192046.1060180664D) .put("index", 192046.1060180664D)
.put("addRowsIndexConstant", 192233.1060180664D) .put("addRowsIndexConstant", 192233.1060180664D)
.put(QueryRunnerTestHelper.dependentPostAggMetric, 192234.1060180664D) .put(QueryRunnerTestHelper.dependentPostAggMetric, 192420.1060180664D)
.put("uniques", QueryRunnerTestHelper.UNIQUES_2) .put("uniques", QueryRunnerTestHelper.UNIQUES_2)
.put("maxIndex", 1870.06103515625D) .put("maxIndex", 1870.06103515625D)
.put("minIndex", 545.9906005859375D) .put("minIndex", 545.9906005859375D)
@ -1235,7 +1235,7 @@ public class TopNQueryRunnerTest
.put("rows", 837L) .put("rows", 837L)
.put("index", 95606.57232284546D) .put("index", 95606.57232284546D)
.put("addRowsIndexConstant", 96444.57232284546D) .put("addRowsIndexConstant", 96444.57232284546D)
.put(QueryRunnerTestHelper.dependentPostAggMetric, 96445.57232284546D) .put(QueryRunnerTestHelper.dependentPostAggMetric, 97282.57232284546D)
.put("uniques", QueryRunnerTestHelper.UNIQUES_9) .put("uniques", QueryRunnerTestHelper.UNIQUES_9)
.put("maxIndex", 277.2735290527344D) .put("maxIndex", 277.2735290527344D)
.put("minIndex", 59.02102279663086D) .put("minIndex", 59.02102279663086D)