diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTerms.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTerms.java index 0ff9a98f7f9..59957ef17d8 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTerms.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTerms.java @@ -111,11 +111,13 @@ public class DoubleTerms extends InternalTerms { public InternalTerms reduce(ReduceContext reduceContext) { List aggregations = reduceContext.aggregations(); if (aggregations.size() == 1) { - return (InternalTerms) aggregations.get(0); + InternalTerms terms = (InternalTerms) aggregations.get(0); + terms.trimExcessEntries(); + return terms; } InternalTerms reduced = null; - Recycler.V>> buckets = reduceContext.cacheRecycler().doubleObjectMap(-1); + Recycler.V>> buckets = null; for (InternalAggregation aggregation : aggregations) { InternalTerms terms = (InternalTerms) aggregation; if (terms instanceof UnmappedTerms) { @@ -124,8 +126,10 @@ public class DoubleTerms extends InternalTerms { if (reduced == null) { reduced = terms; } + if (buckets == null) { + buckets = reduceContext.cacheRecycler().doubleObjectMap(terms.buckets.size()); + } for (Terms.Bucket bucket : terms.buckets) { - List existingBuckets = buckets.v().get(((Bucket) bucket).term); if (existingBuckets == null) { existingBuckets = new ArrayList(aggregations.size()); diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTermsAggregator.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTermsAggregator.java index 4bde5c74672..0c93c2222ac 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTermsAggregator.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTermsAggregator.java @@ -41,15 +41,17 @@ public class DoubleTermsAggregator extends BucketsAggregator { private final InternalOrder order; private final int requiredSize; + private final int shardSize; private final NumericValuesSource valuesSource; private final LongHash bucketOrds; public DoubleTermsAggregator(String name, AggregatorFactories factories, NumericValuesSource valuesSource, - InternalOrder order, int requiredSize, AggregationContext aggregationContext, Aggregator parent) { + InternalOrder order, int requiredSize, int shardSize, AggregationContext aggregationContext, Aggregator parent) { super(name, BucketAggregationMode.PER_BUCKET, factories, INITIAL_CAPACITY, aggregationContext, parent); this.valuesSource = valuesSource; this.order = order; this.requiredSize = requiredSize; + this.shardSize = shardSize; bucketOrds = new LongHash(INITIAL_CAPACITY); } @@ -89,7 +91,7 @@ public class DoubleTermsAggregator extends BucketsAggregator { @Override public DoubleTerms buildAggregation(long owningBucketOrdinal) { assert owningBucketOrdinal == 0; - final int size = (int) Math.min(bucketOrds.size(), requiredSize); + final int size = (int) Math.min(bucketOrds.size(), shardSize); BucketPriorityQueue ordered = new BucketPriorityQueue(size, order.comparator()); OrdinalBucket spare = null; diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalTerms.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalTerms.java index 91bcea548ee..a1fbe96408d 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalTerms.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalTerms.java @@ -129,13 +129,15 @@ public abstract class InternalTerms extends InternalAggregation implements Terms public InternalTerms reduce(ReduceContext reduceContext) { List aggregations = reduceContext.aggregations(); if (aggregations.size() == 1) { - return (InternalTerms) aggregations.get(0); + InternalTerms terms = (InternalTerms) aggregations.get(0); + terms.trimExcessEntries(); + return terms; } InternalTerms reduced = null; // TODO: would it be better to use a hppc map and then directly work on the backing array instead of using a PQ? - Map> buckets = new HashMap>(requiredSize); + Map> buckets = null; for (InternalAggregation aggregation : aggregations) { InternalTerms terms = (InternalTerms) aggregation; if (terms instanceof UnmappedTerms) { @@ -144,6 +146,9 @@ public abstract class InternalTerms extends InternalAggregation implements Terms if (reduced == null) { reduced = terms; } + if (buckets == null) { + buckets = new HashMap>(terms.buckets.size()); + } for (Bucket bucket : terms.buckets) { List existingBuckets = buckets.get(bucket.getKey()); if (existingBuckets == null) { @@ -173,4 +178,23 @@ public abstract class InternalTerms extends InternalAggregation implements Terms return reduced; } + protected void trimExcessEntries() { + if (requiredSize >= buckets.size()) { + return; + } + + if (buckets instanceof List) { + buckets = ((List) buckets).subList(0, requiredSize); + return; + } + + int i = 0; + for (Iterator iter = buckets.iterator(); iter.hasNext();) { + iter.next(); + if (i++ >= requiredSize) { + iter.remove(); + } + } + } + } diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTerms.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTerms.java index 2c701390527..7f4c10ee9d2 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTerms.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTerms.java @@ -109,11 +109,13 @@ public class LongTerms extends InternalTerms { public InternalTerms reduce(ReduceContext reduceContext) { List aggregations = reduceContext.aggregations(); if (aggregations.size() == 1) { - return (InternalTerms) aggregations.get(0); + InternalTerms terms = (InternalTerms) aggregations.get(0); + terms.trimExcessEntries(); + return terms; } InternalTerms reduced = null; - Recycler.V>> buckets = reduceContext.cacheRecycler().longObjectMap(-1); + Recycler.V>> buckets = null; for (InternalAggregation aggregation : aggregations) { InternalTerms terms = (InternalTerms) aggregation; if (terms instanceof UnmappedTerms) { @@ -122,6 +124,9 @@ public class LongTerms extends InternalTerms { if (reduced == null) { reduced = terms; } + if (buckets == null) { + buckets = reduceContext.cacheRecycler().longObjectMap(terms.buckets.size()); + } for (Terms.Bucket bucket : terms.buckets) { List existingBuckets = buckets.v().get(((Bucket) bucket).term); if (existingBuckets == null) { diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTermsAggregator.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTermsAggregator.java index 9536021e335..c6837f013fc 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTermsAggregator.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTermsAggregator.java @@ -41,15 +41,17 @@ public class LongTermsAggregator extends BucketsAggregator { private final InternalOrder order; private final int requiredSize; + private final int shardSize; private final NumericValuesSource valuesSource; private final LongHash bucketOrds; public LongTermsAggregator(String name, AggregatorFactories factories, NumericValuesSource valuesSource, - InternalOrder order, int requiredSize, AggregationContext aggregationContext, Aggregator parent) { + InternalOrder order, int requiredSize, int shardSize, AggregationContext aggregationContext, Aggregator parent) { super(name, BucketAggregationMode.PER_BUCKET, factories, INITIAL_CAPACITY, aggregationContext, parent); this.valuesSource = valuesSource; this.order = order; this.requiredSize = requiredSize; + this.shardSize = shardSize; bucketOrds = new LongHash(INITIAL_CAPACITY); } @@ -88,7 +90,7 @@ public class LongTermsAggregator extends BucketsAggregator { @Override public LongTerms buildAggregation(long owningBucketOrdinal) { assert owningBucketOrdinal == 0; - final int size = (int) Math.min(bucketOrds.size(), requiredSize); + final int size = (int) Math.min(bucketOrds.size(), shardSize); BucketPriorityQueue ordered = new BucketPriorityQueue(size, order.comparator()); OrdinalBucket spare = null; diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/StringTermsAggregator.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/StringTermsAggregator.java index b539cb24976..8aa3a98a153 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/StringTermsAggregator.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/StringTermsAggregator.java @@ -43,15 +43,17 @@ public class StringTermsAggregator extends BucketsAggregator { private final ValuesSource valuesSource; private final InternalOrder order; private final int requiredSize; + private final int shardSize; private final BytesRefHash bucketOrds; public StringTermsAggregator(String name, AggregatorFactories factories, ValuesSource valuesSource, - InternalOrder order, int requiredSize, AggregationContext aggregationContext, Aggregator parent) { + InternalOrder order, int requiredSize, int shardSize, AggregationContext aggregationContext, Aggregator parent) { super(name, BucketAggregationMode.PER_BUCKET, factories, INITIAL_CAPACITY, aggregationContext, parent); this.valuesSource = valuesSource; this.order = order; this.requiredSize = requiredSize; + this.shardSize = shardSize; bucketOrds = new BytesRefHash(); } @@ -91,7 +93,7 @@ public class StringTermsAggregator extends BucketsAggregator { @Override public StringTerms buildAggregation(long owningBucketOrdinal) { assert owningBucketOrdinal == 0; - final int size = Math.min(bucketOrds.size(), requiredSize); + final int size = Math.min(bucketOrds.size(), shardSize); BucketPriorityQueue ordered = new BucketPriorityQueue(size, order.comparator()); OrdinalBucket spare = null; diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java index db89aa0714f..42f6fb520a3 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java @@ -22,11 +22,11 @@ package org.elasticsearch.search.aggregations.bucket.terms; import org.elasticsearch.search.aggregations.AggregationExecutionException; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.search.aggregations.support.ValueSourceAggregatorFactory; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; import org.elasticsearch.search.aggregations.support.bytes.BytesValuesSource; import org.elasticsearch.search.aggregations.support.numeric.NumericValuesSource; -import org.elasticsearch.search.aggregations.support.ValueSourceAggregatorFactory; /** * @@ -35,11 +35,13 @@ public class TermsAggregatorFactory extends ValueSourceAggregatorFactory { private final InternalOrder order; private final int requiredSize; + private final int shardSize; - public TermsAggregatorFactory(String name, ValuesSourceConfig valueSourceConfig, InternalOrder order, int requiredSize) { + public TermsAggregatorFactory(String name, ValuesSourceConfig valueSourceConfig, InternalOrder order, int requiredSize, int shardSize) { super(name, StringTerms.TYPE.name(), valueSourceConfig); this.order = order; this.requiredSize = requiredSize; + this.shardSize = shardSize; } @Override @@ -50,14 +52,14 @@ public class TermsAggregatorFactory extends ValueSourceAggregatorFactory { @Override protected Aggregator create(ValuesSource valuesSource, long expectedBucketsCount, AggregationContext aggregationContext, Aggregator parent) { if (valuesSource instanceof BytesValuesSource) { - return new StringTermsAggregator(name, factories, valuesSource, order, requiredSize, aggregationContext, parent); + return new StringTermsAggregator(name, factories, valuesSource, order, requiredSize, shardSize, aggregationContext, parent); } if (valuesSource instanceof NumericValuesSource) { if (((NumericValuesSource) valuesSource).isFloatingPoint()) { - return new DoubleTermsAggregator(name, factories, (NumericValuesSource) valuesSource, order, requiredSize, aggregationContext, parent); + return new DoubleTermsAggregator(name, factories, (NumericValuesSource) valuesSource, order, requiredSize, shardSize, aggregationContext, parent); } - return new LongTermsAggregator(name, factories, (NumericValuesSource) valuesSource, order, requiredSize, aggregationContext, parent); + return new LongTermsAggregator(name, factories, (NumericValuesSource) valuesSource, order, requiredSize, shardSize, aggregationContext, parent); } throw new AggregationExecutionException("terms aggregation cannot be applied to field [" + valuesSourceConfig.fieldContext().field() + diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsBuilder.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsBuilder.java index 0d500fb3d3a..8aea5c0d6a9 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsBuilder.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsBuilder.java @@ -12,6 +12,7 @@ import java.util.Locale; public class TermsBuilder extends ValuesSourceAggregationBuilder { private int size = -1; + private int shardSize = -1; private Terms.ValueType valueType; private Terms.Order order; @@ -24,6 +25,11 @@ public class TermsBuilder extends ValuesSourceAggregationBuilder { return this; } + public TermsBuilder shardSize(int shardSize) { + this.shardSize = shardSize; + return this; + } + public TermsBuilder valueType(Terms.ValueType valueType) { this.valueType = valueType; return this; @@ -39,6 +45,9 @@ public class TermsBuilder extends ValuesSourceAggregationBuilder { if (size >=0) { builder.field("size", size); } + if (shardSize >= 0) { + builder.field("shard_size", shardSize); + } if (valueType != null) { builder.field("value_type", valueType.name().toLowerCase(Locale.ROOT)); } diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsParser.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsParser.java index 5f446b9514c..1da284415a8 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsParser.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsParser.java @@ -27,6 +27,7 @@ import org.elasticsearch.index.mapper.core.DateFieldMapper; import org.elasticsearch.index.mapper.ip.IpFieldMapper; import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.aggregations.support.FieldContext; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; @@ -34,7 +35,6 @@ import org.elasticsearch.search.aggregations.support.bytes.BytesValuesSource; import org.elasticsearch.search.aggregations.support.numeric.NumericValuesSource; import org.elasticsearch.search.aggregations.support.numeric.ValueFormatter; import org.elasticsearch.search.aggregations.support.numeric.ValueParser; -import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; @@ -62,6 +62,7 @@ public class TermsParser implements Aggregator.Parser { Map scriptParams = null; Terms.ValueType valueType = null; int requiredSize = 10; + int shardSize = -1; String orderKey = "_count"; boolean orderAsc = false; String format = null; @@ -92,6 +93,8 @@ public class TermsParser implements Aggregator.Parser { } else if (token == XContentParser.Token.VALUE_NUMBER) { if ("size".equals(currentFieldName)) { requiredSize = parser.intValue(); + } else if ("shard_size".equals(currentFieldName) || "shardSize".equals(currentFieldName)) { + shardSize = parser.intValue(); } } else if (token == XContentParser.Token.START_OBJECT) { if ("params".equals(currentFieldName)) { @@ -110,6 +113,11 @@ public class TermsParser implements Aggregator.Parser { } } + // shard_size cannot be smaller than size as we need to at least fetch entries from every shards in order to return + if (shardSize < requiredSize) { + shardSize = requiredSize; + } + InternalOrder order = resolveOrder(orderKey, orderAsc); SearchScript searchScript = null; if (script != null) { @@ -131,14 +139,14 @@ public class TermsParser implements Aggregator.Parser { if (!assumeUnique) { config.ensureUnique(true); } - return new TermsAggregatorFactory(aggregationName, config, order, requiredSize); + return new TermsAggregatorFactory(aggregationName, config, order, requiredSize, shardSize); } FieldMapper mapper = context.smartNameFieldMapper(field); if (mapper == null) { ValuesSourceConfig config = new ValuesSourceConfig(BytesValuesSource.class); config.unmapped(true); - return new TermsAggregatorFactory(aggregationName, config, order, requiredSize); + return new TermsAggregatorFactory(aggregationName, config, order, requiredSize, shardSize); } IndexFieldData indexFieldData = context.fieldData().getForField(mapper); @@ -180,7 +188,7 @@ public class TermsParser implements Aggregator.Parser { config.ensureUnique(true); } - return new TermsAggregatorFactory(aggregationName, config, order, requiredSize); + return new TermsAggregatorFactory(aggregationName, config, order, requiredSize, shardSize); } static InternalOrder resolveOrder(String key, boolean asc) { diff --git a/src/main/java/org/elasticsearch/search/facet/terms/TermsFacetParser.java b/src/main/java/org/elasticsearch/search/facet/terms/TermsFacetParser.java index 5fb3bb5d6ab..2194c070fa5 100644 --- a/src/main/java/org/elasticsearch/search/facet/terms/TermsFacetParser.java +++ b/src/main/java/org/elasticsearch/search/facet/terms/TermsFacetParser.java @@ -130,7 +130,7 @@ public class TermsFacetParser extends AbstractComponent implements FacetParser { script = parser.text(); } else if ("size".equals(currentFieldName)) { size = parser.intValue(); - } else if ("shard_size".equals(currentFieldName)) { + } else if ("shard_size".equals(currentFieldName) || "shardSize".equals(currentFieldName)) { shardSize = parser.intValue(); } else if ("all_terms".equals(currentFieldName) || "allTerms".equals(currentFieldName)) { allTerms = parser.booleanValue(); diff --git a/src/test/java/org/elasticsearch/search/aggregations/bucket/ShardSizeTermsTests.java b/src/test/java/org/elasticsearch/search/aggregations/bucket/ShardSizeTermsTests.java new file mode 100644 index 00000000000..11a13976819 --- /dev/null +++ b/src/test/java/org/elasticsearch/search/aggregations/bucket/ShardSizeTermsTests.java @@ -0,0 +1,362 @@ +/* + * Licensed to ElasticSearch and Shay Banon under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. ElasticSearch licenses this + * file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.bucket; + +import com.google.common.collect.ImmutableMap; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.settings.ImmutableSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.test.ElasticsearchIntegrationTest; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery; +import static org.elasticsearch.search.aggregations.AggregationBuilders.terms; +import static org.elasticsearch.test.ElasticsearchIntegrationTest.ClusterScope; +import static org.elasticsearch.test.ElasticsearchIntegrationTest.Scope; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +/** + * + */ +@ClusterScope(scope = Scope.TEST) +public class ShardSizeTermsTests extends ElasticsearchIntegrationTest { + + /** + * to properly test the effect/functionality of shard_size, we need to force having 2 shards and also + * control the routing such that certain documents will end on each shard. Using "djb" routing hash + ignoring the + * doc type when hashing will ensure that docs with routing value "1" will end up in a different shard than docs with + * routing value "2". + */ + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return ImmutableSettings.builder() + .put("index.number_of_shards", 2) + .put("index.number_of_replicas", 0) + .put("cluster.routing.operation.hash.type", "djb") + .put("cluster.routing.operation.use_type", "false") + .build(); + } + + @Test + public void noShardSize_string() throws Exception { + + client().admin().indices().prepareCreate("idx") + .addMapping("type", "key", "type=string,index=not_analyzed") + .execute().actionGet(); + + indexData(); + + SearchResponse response = client().prepareSearch("idx").setTypes("type") + .setQuery(matchAllQuery()) + .addAggregation(terms("keys").field("key").size(3).order(Terms.Order.COUNT_DESC)) + .execute().actionGet(); + + Terms terms = response.getAggregations().get("keys"); + Collection buckets = terms.buckets(); + assertThat(buckets.size(), equalTo(3)); + Map expected = ImmutableMap.builder() + .put("1", 8l) + .put("3", 8l) + .put("2", 4l) + .build(); + for (Terms.Bucket bucket : buckets) { + assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKey().string()))); + } + } + + @Test + public void withShardSize_string() throws Exception { + + client().admin().indices().prepareCreate("idx") + .addMapping("type", "key", "type=string,index=not_analyzed") + .execute().actionGet(); + + indexData(); + + SearchResponse response = client().prepareSearch("idx").setTypes("type") + .setQuery(matchAllQuery()) + .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC)) + .execute().actionGet(); + + Terms terms = response.getAggregations().get("keys"); + Collection buckets = terms.buckets(); + assertThat(buckets.size(), equalTo(3)); // we still only return 3 entries (based on the 'size' param) + Map expected = ImmutableMap.builder() + .put("1", 8l) + .put("3", 8l) + .put("2", 5l) // <-- count is now fixed + .build(); + for (Terms.Bucket bucket : buckets) { + assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKey().string()))); + } + } + + @Test + public void withShardSize_string_singleShard() throws Exception { + + client().admin().indices().prepareCreate("idx") + .addMapping("type", "key", "type=string,index=not_analyzed") + .execute().actionGet(); + + indexData(); + + SearchResponse response = client().prepareSearch("idx").setTypes("type").setRouting("1") + .setQuery(matchAllQuery()) + .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC)) + .execute().actionGet(); + + Terms terms = response.getAggregations().get("keys"); + Collection buckets = terms.buckets(); + assertThat(buckets.size(), equalTo(3)); // we still only return 3 entries (based on the 'size' param) + Map expected = ImmutableMap.builder() + .put("1", 5l) + .put("2", 4l) + .put("3", 3l) // <-- count is now fixed + .build(); + for (Terms.Bucket bucket: buckets) { + assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKey().string()))); + } + } + + @Test + public void noShardSize_long() throws Exception { + + client().admin().indices().prepareCreate("idx") + .addMapping("type", "key", "type=long") + .execute().actionGet(); + + indexData(); + + SearchResponse response = client().prepareSearch("idx").setTypes("type") + .setQuery(matchAllQuery()) + .addAggregation(terms("keys").field("key").size(3).order(Terms.Order.COUNT_DESC)) + .execute().actionGet(); + + Terms terms = response.getAggregations().get("keys"); + Collection buckets = terms.buckets(); + assertThat(buckets.size(), equalTo(3)); + Map expected = ImmutableMap.builder() + .put(1, 8l) + .put(3, 8l) + .put(2, 4l) + .build(); + for (Terms.Bucket bucket : buckets) { + assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue()))); + } + } + + @Test + public void withShardSize_long() throws Exception { + + client().admin().indices().prepareCreate("idx") + .addMapping("type", "key", "type=long") + .execute().actionGet(); + + indexData(); + + SearchResponse response = client().prepareSearch("idx").setTypes("type") + .setQuery(matchAllQuery()) + .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC)) + .execute().actionGet(); + + Terms terms = response.getAggregations().get("keys"); + Collection buckets = terms.buckets(); + assertThat(buckets.size(), equalTo(3)); // we still only return 3 entries (based on the 'size' param) + Map expected = ImmutableMap.builder() + .put(1, 8l) + .put(3, 8l) + .put(2, 5l) // <-- count is now fixed + .build(); + for (Terms.Bucket bucket : buckets) { + assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue()))); + } + } + + @Test + public void withShardSize_long_singleShard() throws Exception { + + client().admin().indices().prepareCreate("idx") + .addMapping("type", "key", "type=long") + .execute().actionGet(); + + indexData(); + + SearchResponse response = client().prepareSearch("idx").setTypes("type").setRouting("1") + .setQuery(matchAllQuery()) + .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC)) + .execute().actionGet(); + + Terms terms = response.getAggregations().get("keys"); + Collection buckets = terms.buckets(); + assertThat(buckets.size(), equalTo(3)); // we still only return 3 entries (based on the 'size' param) + Map expected = ImmutableMap.builder() + .put(1, 5l) + .put(2, 4l) + .put(3, 3l) + .build(); + for (Terms.Bucket bucket : buckets) { + assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue()))); + } + } + + @Test + public void noShardSize_double() throws Exception { + + client().admin().indices().prepareCreate("idx") + .addMapping("type", "key", "type=double") + .execute().actionGet(); + + indexData(); + + SearchResponse response = client().prepareSearch("idx").setTypes("type") + .setQuery(matchAllQuery()) + .addAggregation(terms("keys").field("key").size(3).order(Terms.Order.COUNT_DESC)) + .execute().actionGet(); + + Terms terms = response.getAggregations().get("keys"); + Collection buckets = terms.buckets(); + assertThat(buckets.size(), equalTo(3)); + Map expected = ImmutableMap.builder() + .put(1, 8l) + .put(3, 8l) + .put(2, 4l) + .build(); + for (Terms.Bucket bucket : buckets) { + assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue()))); + } + } + + @Test + public void withShardSize_double() throws Exception { + + client().admin().indices().prepareCreate("idx") + .addMapping("type", "key", "type=double") + .execute().actionGet(); + + indexData(); + + SearchResponse response = client().prepareSearch("idx").setTypes("type") + .setQuery(matchAllQuery()) + .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC)) + .execute().actionGet(); + + Terms terms = response.getAggregations().get("keys"); + Collection buckets = terms.buckets(); + assertThat(buckets.size(), equalTo(3)); + Map expected = ImmutableMap.builder() + .put(1, 8l) + .put(3, 8l) + .put(2, 5l) // <-- count is now fixed + .build(); + for (Terms.Bucket bucket : buckets) { + assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue()))); + } + } + + @Test + public void withShardSize_double_singleShard() throws Exception { + + client().admin().indices().prepareCreate("idx") + .addMapping("type", "key", "type=double") + .execute().actionGet(); + + indexData(); + + SearchResponse response = client().prepareSearch("idx").setTypes("type").setRouting("1") + .setQuery(matchAllQuery()) + .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC)) + .execute().actionGet(); + + Terms terms = response.getAggregations().get("keys"); + Collection buckets = terms.buckets(); + assertThat(buckets.size(), equalTo(3)); + Map expected = ImmutableMap.builder() + .put(1, 5l) + .put(2, 4l) + .put(3, 3l) + .build(); + for (Terms.Bucket bucket : buckets) { + assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue()))); + } + } + + private void indexData() throws Exception { + + /* + + + || || size = 3, shard_size = 5 || shard_size = size = 3 || + ||==========||==================================================||===============================================|| + || shard 1: || "1" - 5 | "2" - 4 | "3" - 3 | "4" - 2 | "5" - 1 || "1" - 5 | "3" - 3 | "2" - 4 || + ||----------||--------------------------------------------------||-----------------------------------------------|| + || shard 2: || "1" - 3 | "2" - 1 | "3" - 5 | "4" - 2 | "5" - 1 || "1" - 3 | "3" - 5 | "4" - 2 || + ||----------||--------------------------------------------------||-----------------------------------------------|| + || reduced: || "1" - 8 | "2" - 5 | "3" - 8 | "4" - 4 | "5" - 2 || || + || || || "1" - 8, "3" - 8, "2" - 4 <= WRONG || + || || "1" - 8 | "3" - 8 | "2" - 5 <= CORRECT || || + + + */ + + List indexOps = new ArrayList(); + + indexDoc("1", "1", 5, indexOps); + indexDoc("1", "2", 4, indexOps); + indexDoc("1", "3", 3, indexOps); + indexDoc("1", "4", 2, indexOps); + indexDoc("1", "5", 1, indexOps); + + // total docs in shard "1" = 15 + + indexDoc("2", "1", 3, indexOps); + indexDoc("2", "2", 1, indexOps); + indexDoc("2", "3", 5, indexOps); + indexDoc("2", "4", 2, indexOps); + indexDoc("2", "5", 1, indexOps); + + // total docs in shard "2" = 12 + + indexRandom(true, indexOps); + + long totalOnOne = client().prepareSearch("idx").setTypes("type").setRouting("1").setQuery(matchAllQuery()).execute().actionGet().getHits().getTotalHits(); + assertThat(totalOnOne, is(15l)); + long totalOnTwo = client().prepareSearch("idx").setTypes("type").setRouting("2").setQuery(matchAllQuery()).execute().actionGet().getHits().getTotalHits(); + assertThat(totalOnTwo, is(12l)); + } + + private void indexDoc(String shard, String key, int times, List indexOps) throws Exception { + for (int i = 0; i < times; i++) { + indexOps.add(client().prepareIndex("idx", "type").setRouting(shard).setCreate(true).setSource(jsonBuilder() + .startObject() + .field("key", key) + .endObject())); + } + } + +}