Improve accuracy for Geo Centroid Aggregation (#41514)

Keeps the partial results as doubles and uses Kahan summation to help reduce floating point errors.
This commit is contained in:
Ignacio Vera 2019-04-25 15:25:48 +02:00 committed by GitHub
parent cd830b53e2
commit d119abdf96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 37 deletions

View File

@@ -58,8 +58,8 @@ The response for the above aggregation:
"aggregations": {
"centroid": {
"location": {
"lat": 51.009829603135586,
"lon": 3.9662130642682314
"lat": 51.00982965203002,
"lon": 3.9662131341174245
},
"count": 6
}
@@ -111,8 +111,8 @@ The response for the above aggregation:
"doc_count": 3,
"centroid": {
"location": {
"lat": 52.371655642054975,
"lon": 4.9095632415264845
"lat": 52.371655656024814,
"lon": 4.909563297405839
},
"count": 3
}
@@ -123,7 +123,7 @@ The response for the above aggregation:
"centroid": {
"location": {
"lat": 48.86055548675358,
"lon": 2.331694420427084
"lon": 2.3316944623366
},
"count": 2
}

View File

@@ -23,6 +23,7 @@ import org.apache.lucene.index.LeafReaderContext;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.common.util.LongArray;
import org.elasticsearch.index.fielddata.MultiGeoPointValues;
import org.elasticsearch.search.aggregations.Aggregator;
@@ -42,7 +43,7 @@ import java.util.Map;
*/
final class GeoCentroidAggregator extends MetricsAggregator {
private final ValuesSource.GeoPoint valuesSource;
private LongArray centroids;
private DoubleArray lonSum, lonCompensations, latSum, latCompensations;
private LongArray counts;
GeoCentroidAggregator(String name, SearchContext context, Aggregator parent,
@@ -52,7 +53,10 @@ final class GeoCentroidAggregator extends MetricsAggregator {
this.valuesSource = valuesSource;
if (valuesSource != null) {
final BigArrays bigArrays = context.bigArrays();
centroids = bigArrays.newLongArray(1, true);
lonSum = bigArrays.newDoubleArray(1, true);
lonCompensations = bigArrays.newDoubleArray(1, true);
latSum = bigArrays.newDoubleArray(1, true);
latCompensations = bigArrays.newDoubleArray(1, true);
counts = bigArrays.newLongArray(1, true);
}
}
@@ -67,33 +71,41 @@ final class GeoCentroidAggregator extends MetricsAggregator {
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
centroids = bigArrays.grow(centroids, bucket + 1);
latSum = bigArrays.grow(latSum, bucket + 1);
lonSum = bigArrays.grow(lonSum, bucket + 1);
lonCompensations = bigArrays.grow(lonCompensations, bucket + 1);
latCompensations = bigArrays.grow(latCompensations, bucket + 1);
counts = bigArrays.grow(counts, bucket + 1);
if (values.advanceExact(doc)) {
final int valueCount = values.docValueCount();
double[] pt = new double[2];
// get the previously accumulated number of counts
long prevCounts = counts.get(bucket);
// increment by the number of points for this document
counts.increment(bucket, valueCount);
// get the previous GeoPoint if a moving avg was
// computed
if (prevCounts > 0) {
final long mortonCode = centroids.get(bucket);
pt[0] = InternalGeoCentroid.decodeLongitude(mortonCode);
pt[1] = InternalGeoCentroid.decodeLatitude(mortonCode);
}
// update the moving average
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sumLat = latSum.get(bucket);
double compensationLat = latCompensations.get(bucket);
double sumLon = lonSum.get(bucket);
double compensationLon = lonCompensations.get(bucket);
// update the sum
for (int i = 0; i < valueCount; ++i) {
GeoPoint value = values.nextValue();
pt[0] = pt[0] + (value.getLon() - pt[0]) / ++prevCounts;
pt[1] = pt[1] + (value.getLat() - pt[1]) / prevCounts;
//latitude
double correctedLat = value.getLat() - compensationLat;
double newSumLat = sumLat + correctedLat;
compensationLat = (newSumLat - sumLat) - correctedLat;
sumLat = newSumLat;
//longitude
double correctedLon = value.getLon() - compensationLon;
double newSumLon = sumLon + correctedLon;
compensationLon = (newSumLon - sumLon) - correctedLon;
sumLon = newSumLon;
}
// TODO: we do not need to interleave the lat and lon
// bits here
// should we just store them contiguously?
centroids.set(bucket, InternalGeoCentroid.encodeLatLon(pt[1], pt[0]));
lonSum.set(bucket, sumLon);
lonCompensations.set(bucket, compensationLon);
latSum.set(bucket, sumLat);
latCompensations.set(bucket, compensationLat);
}
}
};
@@ -101,14 +113,12 @@ final class GeoCentroidAggregator extends MetricsAggregator {
@Override
public InternalAggregation buildAggregation(long bucket) {
if (valuesSource == null || bucket >= centroids.size()) {
if (valuesSource == null || bucket >= counts.size()) {
return buildEmptyAggregation();
}
final long bucketCount = counts.get(bucket);
final long mortonCode = centroids.get(bucket);
final GeoPoint bucketCentroid = (bucketCount > 0)
? new GeoPoint(InternalGeoCentroid.decodeLatitude(mortonCode),
InternalGeoCentroid.decodeLongitude(mortonCode))
? new GeoPoint(latSum.get(bucket) / bucketCount, lonSum.get(bucket) / bucketCount)
: null;
return new InternalGeoCentroid(name, bucketCentroid , bucketCount, pipelineAggregators(), metaData());
}
@@ -120,6 +130,6 @@ final class GeoCentroidAggregator extends MetricsAggregator {
@Override
public void doClose() {
Releasables.close(centroids, counts);
Releasables.close(latSum, latCompensations, lonSum, lonCompensations, counts);
}
}

View File

@@ -20,6 +20,7 @@
package org.elasticsearch.search.aggregations.metrics;
import org.apache.lucene.geo.GeoEncodingUtils;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.io.stream.StreamInput;
@@ -69,8 +70,13 @@ public class InternalGeoCentroid extends InternalAggregation implements GeoCentr
super(in);
count = in.readVLong();
if (in.readBoolean()) {
final long hash = in.readLong();
centroid = new GeoPoint(decodeLatitude(hash), decodeLongitude(hash));
if (in.getVersion().onOrAfter(Version.V_7_1_0)) {
centroid = new GeoPoint(in.readDouble(), in.readDouble());
} else {
final long hash = in.readLong();
centroid = new GeoPoint(decodeLatitude(hash), decodeLongitude(hash));
}
} else {
centroid = null;
}
@@ -81,8 +87,12 @@ public class InternalGeoCentroid extends InternalAggregation implements GeoCentr
out.writeVLong(count);
if (centroid != null) {
out.writeBoolean(true);
// should we just write lat and lon separately?
out.writeLong(encodeLatLon(centroid.lat(), centroid.lon()));
if (out.getVersion().onOrAfter(Version.V_7_1_0)) {
out.writeDouble(centroid.lat());
out.writeDouble(centroid.lon());
} else {
out.writeLong(encodeLatLon(centroid.lat(), centroid.lon()));
}
} else {
out.writeBoolean(false);
}

View File

@@ -29,8 +29,6 @@ import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.index.mapper.GeoPointFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.metrics.GeoCentroidAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalGeoCentroid;
import org.elasticsearch.search.aggregations.support.AggregationInspectionHelper;
import org.elasticsearch.test.geo.RandomGeoGenerator;
@@ -38,7 +36,7 @@ import java.io.IOException;
public class GeoCentroidAggregatorTests extends AggregatorTestCase {
private static final double GEOHASH_TOLERANCE = 1E-4D;
private static final double GEOHASH_TOLERANCE = 1E-6D;
public void testEmpty() throws Exception {
try (Directory dir = newDirectory();