diff --git a/src/main/java/org/elasticsearch/common/util/Comparators.java b/src/main/java/org/elasticsearch/common/util/Comparators.java
new file mode 100644
index 00000000000..d9943d12a05
--- /dev/null
+++ b/src/main/java/org/elasticsearch/common/util/Comparators.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.common.util;
+
+import java.util.Comparator;
+
+/**
+ * {@link Comparator}-related utility methods.
+ */
+public enum Comparators {
+    ;
+
+    /**
+     * Compare {@code d1} against {@code d2}, pushing {@link Double#NaN} at the bottom.
+     * <p>
+     * A NaN sorts after every other value whatever the value of {@code asc}, and two NaNs
+     * compare as equal. Other values are ordered with {@link Double#compare(double, double)},
+     * ascending if {@code asc} is {@code true} and descending otherwise.
+     *
+     * @param d1  the first value to compare
+     * @param d2  the second value to compare
+     * @param asc whether non-NaN values are sorted in ascending order
+     * @return a negative integer, zero or a positive integer as {@code d1} sorts before,
+     *         equal to or after {@code d2}
+     */
+    public static int compareDiscardNaN(double d1, double d2, boolean asc) {
+        if (Double.isNaN(d1)) {
+            return Double.isNaN(d2) ? 0 : 1;
+        } else if (Double.isNaN(d2)) {
+            return -1;
+        } else {
+            return asc ? Double.compare(d1, d2) : Double.compare(d2, d1);
+        }
+    }
+
+}
diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/MultiBucketsAggregation.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/MultiBucketsAggregation.java
index ffcae1a8181..4c7408cc5eb 100644
--- a/src/main/java/org/elasticsearch/search/aggregations/bucket/MultiBucketsAggregation.java
+++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/MultiBucketsAggregation.java
@@ -21,6 +21,7 @@ package org.elasticsearch.search.aggregations.bucket;
 
 import org.elasticsearch.ElasticsearchIllegalArgumentException;
 import org.elasticsearch.common.text.Text;
+import org.elasticsearch.common.util.Comparators;
 import org.elasticsearch.search.aggregations.Aggregation;
 import org.elasticsearch.search.aggregations.Aggregations;
 import org.elasticsearch.search.aggregations.metrics.MetricsAggregation;
@@ -99,7 +100,7 @@ public interface MultiBucketsAggregation extends Aggregation {
             public int compare(B b1, B b2) {
                 double v1 = value(b1);
                 double v2 = value(b2);
-                return asc ? Double.compare(v1, v2) : Double.compare(v2, v1);
+                return Comparators.compareDiscardNaN(v1, v2, asc);
             }
 
             private double value(B bucket) {
diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalOrder.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalOrder.java
index cfa4e9f68cd..73adfe91b26 100644
--- a/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalOrder.java
+++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalOrder.java
@@ -21,6 +21,7 @@ package org.elasticsearch.search.aggregations.bucket.terms;
 import com.google.common.primitives.Longs;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.util.Comparators;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.search.aggregations.AggregationExecutionException;
 import org.elasticsearch.search.aggregations.Aggregator;
@@ -212,13 +213,7 @@ class InternalOrder extends Terms.Order {
                     double v2 = ((MetricsAggregator.MultiValue) aggregator).metric(valueName, ((InternalTerms.Bucket) o2).bucketOrd);
                     // some metrics may return NaN (eg. avg, variance, etc...) in which case we'd like to push all of those to
                     // the bottom
-                    if (Double.isNaN(v1)) {
-                        return Double.isNaN(v2) ? 0 : 1;
-                    }
-                    if (Double.isNaN(v2)) {
-                        return -1;
-                    }
-                    return asc ? Double.compare(v1, v2) : Double.compare(v2, v1);
+                    return Comparators.compareDiscardNaN(v1, v2, asc);
                 }
             };
         }
@@ -230,13 +225,7 @@ class InternalOrder extends Terms.Order {
                     double v2 = ((MetricsAggregator.SingleValue) aggregator).metric(((InternalTerms.Bucket) o2).bucketOrd);
                     // some metrics may return NaN (eg. avg, variance, etc...) in which case we'd like to push all of those to
                     // the bottom
-                    if (Double.isNaN(v1)) {
-                        return Double.isNaN(v2) ? 0 : 1;
-                    }
-                    if (Double.isNaN(v2)) {
-                        return -1;
-                    }
-                    return asc ? Double.compare(v1, v2) : Double.compare(v2, v1);
+                    return Comparators.compareDiscardNaN(v1, v2, asc);
                 }
             };
         }
diff --git a/src/test/java/org/elasticsearch/search/aggregations/bucket/NaNSortingTests.java b/src/test/java/org/elasticsearch/search/aggregations/bucket/NaNSortingTests.java
new file mode 100644
index 00000000000..789b8144119
--- /dev/null
+++ b/src/test/java/org/elasticsearch/search/aggregations/bucket/NaNSortingTests.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations.bucket;
+
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.common.settings.ImmutableSettings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.Comparators;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.search.aggregations.Aggregation;
+import org.elasticsearch.search.aggregations.bucket.histogram.Histogram;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.metrics.MetricsAggregationBuilder;
+import org.elasticsearch.search.aggregations.metrics.avg.Avg;
+import org.elasticsearch.search.aggregations.metrics.stats.extended.ExtendedStats;
+import org.elasticsearch.test.ElasticsearchIntegrationTest;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
+import static org.elasticsearch.search.aggregations.AggregationBuilders.*;
+import static org.hamcrest.core.IsNull.notNullValue;
+
+/**
+ * Checks that terms and histogram buckets ordered by a metric sub-aggregation come back sorted
+ * according to {@link Comparators#compareDiscardNaN(double, double, boolean)}, NaN metrics last.
+ */
+public class NaNSortingTests extends ElasticsearchIntegrationTest {
+
+    @Override
+    public Settings indexSettings() {
+        return ImmutableSettings.builder()
+                .put("index.number_of_shards", between(1, 5))
+                .put("index.number_of_replicas", between(0, 1))
+                .build();
+    }
+
+    private enum SubAggregation {
+        AVG("avg") {
+            @Override
+            public MetricsAggregationBuilder builder() {
+                return avg(name).field("numeric_field");
+            }
+            @Override
+            public double getValue(Aggregation aggregation) {
+                return ((Avg) aggregation).getValue();
+            }
+        },
+        VARIANCE("variance") {
+            @Override
+            public MetricsAggregationBuilder builder() {
+                return extendedStats(name).field("numeric_field");
+            }
+            @Override
+            public String sortKey() {
+                return name + ".variance";
+            }
+            @Override
+            public double getValue(Aggregation aggregation) {
+                return ((ExtendedStats) aggregation).getVariance();
+            }
+        },
+        STD_DEVIATION("std_deviation"){
+            @Override
+            public MetricsAggregationBuilder builder() {
+                return extendedStats(name).field("numeric_field");
+            }
+            @Override
+            public String sortKey() {
+                return name + ".std_deviation";
+            }
+            @Override
+            public double getValue(Aggregation aggregation) {
+                return ((ExtendedStats) aggregation).getStdDeviation();
+            }
+        };
+
+        SubAggregation(String name) {
+            this.name = name;
+        }
+
+        public String name;
+
+        public abstract MetricsAggregationBuilder builder();
+
+        public String sortKey() {
+            return name;
+        }
+
+        public abstract double getValue(Aggregation aggregation);
+    }
+
+    @Before
+    public void init() throws Exception {
+        createIndex("idx");
+        final int numDocs = randomIntBetween(2, 10);
+        for (int i = 0; i < numDocs; ++i) {
+            final long value = randomInt(5);
+            XContentBuilder source = jsonBuilder().startObject().field("long_value", value).field("double_value", value + 0.05).field("string_value", "str_" + value);
+            // only some documents get a value for the field the sub-aggregations run on, so that
+            // buckets with and without values for that field can both show up in a response
+            if (randomBoolean()) {
+                source.field("numeric_field", randomDouble());
+            }
+            client().prepareIndex("idx", "type").setSource(source.endObject()).execute().actionGet();
+        }
+        refresh();
+        ensureSearchable();
+    }
+
+    private void assertCorrectlySorted(Terms terms, boolean asc, SubAggregation agg) {
+        assertThat(terms, notNullValue());
+        double previousValue = asc ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
+        for (Terms.Bucket bucket : terms.getBuckets()) {
+            Aggregation sub = bucket.getAggregations().get(agg.name);
+            double value = agg.getValue(sub);
+            assertTrue(Comparators.compareDiscardNaN(previousValue, value, asc) <= 0);
+            previousValue = value;
+        }
+    }
+
+    private void assertCorrectlySorted(Histogram histo, boolean asc, SubAggregation agg) {
+        assertThat(histo, notNullValue());
+        double previousValue = asc ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
+        for (Histogram.Bucket bucket : histo.getBuckets()) {
+            Aggregation sub = bucket.getAggregations().get(agg.name);
+            double value = agg.getValue(sub);
+            assertTrue(Comparators.compareDiscardNaN(previousValue, value, asc) <= 0);
+            previousValue = value;
+        }
+    }
+
+    public void testTerms(String fieldName) {
+        final boolean asc = randomBoolean();
+        SubAggregation agg = randomFrom(SubAggregation.values());
+        SearchResponse response = client().prepareSearch("idx")
+                .addAggregation(terms("terms").field(fieldName).subAggregation(agg.builder()).order(Terms.Order.aggregation(agg.sortKey(), asc)))
+                .execute().actionGet();
+
+        final Terms terms = response.getAggregations().get("terms");
+        assertCorrectlySorted(terms, asc, agg);
+    }
+
+    @Test
+    public void stringTerms() {
+        testTerms("string_value");
+    }
+
+    @Test
+    public void longTerms() {
+        testTerms("long_value");
+    }
+
+    @Test
+    public void doubleTerms() {
+        testTerms("double_value");
+    }
+
+    @Test
+    public void longHistogram() {
+        final boolean asc = randomBoolean();
+        SubAggregation agg = randomFrom(SubAggregation.values());
+        SearchResponse response = client().prepareSearch("idx")
+                .addAggregation(histogram("histo")
+                        .field("long_value").interval(randomIntBetween(1, 2)).subAggregation(agg.builder()).order(Histogram.Order.aggregation(agg.sortKey(), asc)))
+                .execute().actionGet();
+
+        final Histogram histo = response.getAggregations().get("histo");
+        assertCorrectlySorted(histo, asc, agg);
+    }
+
+}