Aggregations: Full path validation for pipeline aggregations

Previously only the first aggregation in a buckets_path was checked to make sure the aggregation existed. Now the whole path is checked to ensure an aggregation exists at each element in the buckets_path.
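
As an illustration (adapted from the new test added in this commit; the index name, field name and interval are only placeholders), a buckets_path such as "filters>get>sum" is now resolved element by element. Since "filters" has no sub-aggregation named "get", a request like the one below is rejected with an IllegalArgumentException, whereas previously only the first element ("filters") was verified to exist:

    // "filters>get>sum" fails validation: "filters" exists, but it has no
    // sub-aggregation named "get", so the new validation loop rejects the path
    client().prepareSearch("idx")
            .addAggregation(
                    histogram("histo").field("value").interval(1)
                            .subAggregation(
                                    filters("filters").filter(QueryBuilders.termQuery("tag", "foo"))
                                            .subAggregation(sum("sum").field("value")))
                            .subAggregation(derivative("deriv").setBucketsPaths("filters>get>sum")))
            .execute().actionGet();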

Closes #12360
Colin Goodheart-Smithe 2015-08-03 09:06:14 +01:00
parent 20ed7c1724
commit ade3881152
3 changed files with 102 additions and 18 deletions

AggregatorFactories.java

@@ -22,6 +22,7 @@ import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorFactory;
import org.elasticsearch.search.aggregations.support.AggregationContext;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.search.aggregations.support.AggregationPath.PathElement;
import java.io.IOException;
import java.util.ArrayList;
@@ -162,40 +163,79 @@ public class AggregatorFactories {
for (PipelineAggregatorFactory factory : pipelineAggregatorFactories) {
pipelineAggregatorFactoriesMap.put(factory.getName(), factory);
}
Set<String> aggFactoryNames = new HashSet<>();
Map<String, AggregatorFactory> aggFactoriesMap = new HashMap<>();
for (AggregatorFactory aggFactory : aggFactories) {
aggFactoryNames.add(aggFactory.name);
aggFactoriesMap.put(aggFactory.name, aggFactory);
}
List<PipelineAggregatorFactory> orderedPipelineAggregatorrs = new LinkedList<>();
List<PipelineAggregatorFactory> unmarkedFactories = new ArrayList<PipelineAggregatorFactory>(pipelineAggregatorFactories);
Set<PipelineAggregatorFactory> temporarilyMarked = new HashSet<PipelineAggregatorFactory>();
while (!unmarkedFactories.isEmpty()) {
PipelineAggregatorFactory factory = unmarkedFactories.get(0);
resolvePipelineAggregatorOrder(aggFactoryNames, pipelineAggregatorFactoriesMap, orderedPipelineAggregatorrs, unmarkedFactories, temporarilyMarked, factory);
resolvePipelineAggregatorOrder(aggFactoriesMap, pipelineAggregatorFactoriesMap, orderedPipelineAggregatorrs,
unmarkedFactories, temporarilyMarked, factory);
}
return orderedPipelineAggregatorrs;
}
private void resolvePipelineAggregatorOrder(Set<String> aggFactoryNames, Map<String, PipelineAggregatorFactory> pipelineAggregatorFactoriesMap,
private void resolvePipelineAggregatorOrder(Map<String, AggregatorFactory> aggFactoriesMap,
Map<String, PipelineAggregatorFactory> pipelineAggregatorFactoriesMap,
List<PipelineAggregatorFactory> orderedPipelineAggregators, List<PipelineAggregatorFactory> unmarkedFactories, Set<PipelineAggregatorFactory> temporarilyMarked,
PipelineAggregatorFactory factory) {
if (temporarilyMarked.contains(factory)) {
throw new IllegalStateException("Cyclical dependancy found with pipeline aggregator [" + factory.getName() + "]");
throw new IllegalArgumentException("Cyclical dependancy found with pipeline aggregator [" + factory.getName() + "]");
} else if (unmarkedFactories.contains(factory)) {
temporarilyMarked.add(factory);
String[] bucketsPaths = factory.getBucketsPaths();
for (String bucketsPath : bucketsPaths) {
List<String> bucketsPathElements = AggregationPath.parse(bucketsPath).getPathElementsAsStringList();
String firstAggName = bucketsPathElements.get(0);
if (bucketsPath.equals("_count") || bucketsPath.equals("_key") || aggFactoryNames.contains(firstAggName)) {
List<AggregationPath.PathElement> bucketsPathElements = AggregationPath.parse(bucketsPath).getPathElements();
String firstAggName = bucketsPathElements.get(0).name;
if (bucketsPath.equals("_count") || bucketsPath.equals("_key")) {
continue;
} else if (aggFactoriesMap.containsKey(firstAggName)) {
AggregatorFactory aggFactory = aggFactoriesMap.get(firstAggName);
for (int i = 1; i < bucketsPathElements.size(); i++) {
PathElement pathElement = bucketsPathElements.get(i);
String aggName = pathElement.name;
if ((i == bucketsPathElements.size() - 1) && (aggName.equalsIgnoreCase("_key") || aggName.equals("_count"))) {
break;
} else {
// Check the non-pipeline sub-aggregator
// factories
AggregatorFactory[] subFactories = aggFactory.factories.factories;
boolean foundSubFactory = false;
for (AggregatorFactory subFactory : subFactories) {
if (aggName.equals(subFactory.name)) {
aggFactory = subFactory;
foundSubFactory = true;
break;
}
}
// Check the pipeline sub-aggregator factories
if (!foundSubFactory && (i == bucketsPathElements.size() - 1)) {
List<PipelineAggregatorFactory> subPipelineFactories = aggFactory.factories.pipelineAggregatorFactories;
for (PipelineAggregatorFactory subFactory : subPipelineFactories) {
if (aggName.equals(subFactory.name())) {
foundSubFactory = true;
break;
}
}
}
if (!foundSubFactory) {
throw new IllegalArgumentException("No aggregation [" + aggName + "] found for path [" + bucketsPath
+ "]");
}
}
}
continue;
} else {
PipelineAggregatorFactory matchingFactory = pipelineAggregatorFactoriesMap.get(firstAggName);
if (matchingFactory != null) {
resolvePipelineAggregatorOrder(aggFactoryNames, pipelineAggregatorFactoriesMap, orderedPipelineAggregators, unmarkedFactories,
resolvePipelineAggregatorOrder(aggFactoriesMap, pipelineAggregatorFactoriesMap, orderedPipelineAggregators,
unmarkedFactories,
temporarilyMarked, matchingFactory);
} else {
throw new IllegalStateException("No aggregation found for path [" + bucketsPath + "]");
throw new IllegalArgumentException("No aggregation found for path [" + bucketsPath + "]");
}
}
}

PipelineAggregatorFactory.java

@@ -49,6 +49,10 @@ public abstract class PipelineAggregatorFactory {
this.bucketsPaths = bucketsPaths;
}
public String name() {
return name;
}
/**
* Validates the state of this factory (makes sure the factory is properly
* configured)

DerivativeTests.java

@@ -19,15 +19,18 @@
package org.elasticsearch.search.aggregations.pipeline;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram;
import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram.Bucket;
import org.elasticsearch.search.aggregations.metrics.stats.Stats;
import org.elasticsearch.search.aggregations.metrics.sum.Sum;
import org.elasticsearch.search.aggregations.pipeline.SimpleValue;
import org.elasticsearch.search.aggregations.pipeline.BucketHelpers.GapPolicy;
import org.elasticsearch.search.aggregations.pipeline.derivative.Derivative;
import org.elasticsearch.search.aggregations.support.AggregationPath;
@@ -39,12 +42,13 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.derivative;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.elasticsearch.search.aggregations.AggregationBuilders.filters;
import static org.elasticsearch.search.aggregations.AggregationBuilders.histogram;
import static org.elasticsearch.search.aggregations.AggregationBuilders.stats;
import static org.elasticsearch.search.aggregations.AggregationBuilders.sum;
import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.derivative;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.closeTo;
@@ -228,7 +232,7 @@ public class DerivativeTests extends ElasticsearchIntegrationTest {
Derivative docCountDeriv = bucket.getAggregations().get("deriv");
if (i > 0) {
assertThat(docCountDeriv, notNullValue());
assertThat(docCountDeriv.value(), closeTo((double) (firstDerivValueCounts[i - 1]), 0.00001));
assertThat(docCountDeriv.value(), closeTo((firstDerivValueCounts[i - 1]), 0.00001));
assertThat(docCountDeriv.normalizedValue(), closeTo((double) (firstDerivValueCounts[i - 1]) / 5, 0.00001));
} else {
assertThat(docCountDeriv, nullValue());
@@ -236,7 +240,7 @@ public class DerivativeTests extends ElasticsearchIntegrationTest {
Derivative docCount2ndDeriv = bucket.getAggregations().get("2nd_deriv");
if (i > 1) {
assertThat(docCount2ndDeriv, notNullValue());
assertThat(docCount2ndDeriv.value(), closeTo((double) (secondDerivValueCounts[i - 2]), 0.00001));
assertThat(docCount2ndDeriv.value(), closeTo((secondDerivValueCounts[i - 2]), 0.00001));
assertThat(docCount2ndDeriv.normalizedValue(), closeTo((double) (secondDerivValueCounts[i - 2]) * 2, 0.00001));
} else {
assertThat(docCount2ndDeriv, nullValue());
@@ -596,6 +600,42 @@ public class DerivativeTests extends ElasticsearchIntegrationTest {
}
}
@Test
public void singleValueAggDerivative_invalidPath() throws Exception {
try {
client().prepareSearch("idx")
.addAggregation(
histogram("histo")
.field(SINGLE_VALUED_FIELD_NAME)
.interval(interval)
.subAggregation(
filters("filters").filter(QueryBuilders.termQuery("tag", "foo")).subAggregation(
sum("sum").field(SINGLE_VALUED_FIELD_NAME)))
.subAggregation(derivative("deriv").setBucketsPaths("filters>get>sum"))).execute().actionGet();
fail("Expected an Exception but didn't get one");
} catch (Exception e) {
Throwable cause = ExceptionsHelper.unwrapCause(e);
if (cause == null) {
throw e;
} else if (cause instanceof SearchPhaseExecutionException) {
ElasticsearchException[] rootCauses = ((SearchPhaseExecutionException) cause).guessRootCauses();
// If there is more than one root cause then something
// unexpected happened and we should re-throw the original
// exception
if (rootCauses.length > 1) {
throw e;
}
ElasticsearchException rootCauseWrapper = rootCauses[0];
Throwable rootCause = rootCauseWrapper.getCause();
if (rootCause == null || !(rootCause instanceof IllegalArgumentException)) {
throw e;
}
} else {
throw e;
}
}
}
private void checkBucketKeyAndDocCount(final String msg, final Histogram.Bucket bucket, final long expectedKey,
final long expectedDocCount) {
assertThat(msg, bucket, notNullValue());