Geo Distance / Range Facets might count documents several times for a range entry if the field is multi valued, closes #824.

This commit is contained in:
kimchy 2011-04-04 17:44:38 +03:00
parent 5d6e84f206
commit 105d60ac9c
12 changed files with 254 additions and 129 deletions

View File

@ -142,6 +142,12 @@ public abstract class GeoPointFieldData extends FieldData<GeoPointDocFieldData>
void onValue(double lat, double lon);
}
public abstract void forEachValueInDoc(int docId, ValueInDocProc proc);
public static interface ValueInDocProc {
void onValue(int docId, double lat, double lon);
}
public static GeoPointFieldData load(IndexReader reader, String field) throws IOException {
return FieldDataLoader.load(reader, field, new StringTypeLoader());
}

View File

@ -108,6 +108,15 @@ public class MultiValueGeoPointFieldData extends GeoPointFieldData {
}
}
@Override public void forEachValueInDoc(int docId, ValueInDocProc proc) {
for (int[] ordinal : ordinals) {
int loc = ordinal[docId];
if (loc != 0) {
proc.onValue(docId, lat[loc], lon[loc]);
}
}
}
@Override public void forEachOrdinalInDoc(int docId, OrdinalInDocProc proc) {
for (int[] ordinal : ordinals) {
proc.onOrdinal(docId, ordinal[docId]);

View File

@ -84,6 +84,14 @@ public class SingleValueGeoPointFieldData extends GeoPointFieldData {
proc.onOrdinal(docId, ordinals[docId]);
}
@Override public void forEachValueInDoc(int docId, ValueInDocProc proc) {
int loc = ordinals[docId];
if (loc == 0) {
return;
}
proc.onValue(docId, lat[loc], lon[loc]);
}
@Override public GeoPoint value(int docId) {
int loc = ordinals[docId];
if (loc == 0) {

View File

@ -53,6 +53,11 @@ public interface GeoDistanceFacet extends Facet, Iterable<GeoDistanceFacet.Entry
double total;
/**
* internal field used to see if this entry was already found for a doc
*/
boolean foundInDoc = false;
Entry() {
}

View File

@ -54,6 +54,8 @@ public class GeoDistanceFacetCollector extends AbstractFacetCollector {
protected final GeoDistanceFacet.Entry[] entries;
protected GeoPointFieldData.ValueInDocProc aggregator;
public GeoDistanceFacetCollector(String facetName, String fieldName, double lat, double lon, DistanceUnit unit, GeoDistance geoDistance,
GeoDistanceFacet.Entry[] entries, SearchContext context) {
super(facetName);
@ -78,6 +80,7 @@ public class GeoDistanceFacetCollector extends AbstractFacetCollector {
}
this.indexFieldName = smartMappers.mapper().names().indexName();
this.aggregator = new Aggregator(lat, lon, geoDistance, unit, entries);
}
@Override protected void doSetNextReader(IndexReader reader, int docBase) throws IOException {
@ -85,34 +88,48 @@ public class GeoDistanceFacetCollector extends AbstractFacetCollector {
}
@Override protected void doCollect(int doc) throws IOException {
if (!fieldData.hasValue(doc)) {
return;
for (GeoDistanceFacet.Entry entry : entries) {
entry.foundInDoc = false;
}
fieldData.forEachValueInDoc(doc, aggregator);
}
@Override public Facet facet() {
return new InternalGeoDistanceFacet(facetName, entries);
}
public static class Aggregator implements GeoPointFieldData.ValueInDocProc {
protected final double lat;
protected final double lon;
private final GeoDistance geoDistance;
private final DistanceUnit unit;
private final GeoDistanceFacet.Entry[] entries;
public Aggregator(double lat, double lon, GeoDistance geoDistance, DistanceUnit unit, GeoDistanceFacet.Entry[] entries) {
this.lat = lat;
this.lon = lon;
this.geoDistance = geoDistance;
this.unit = unit;
this.entries = entries;
}
if (fieldData.multiValued()) {
double[] lats = fieldData.latValues(doc);
double[] lons = fieldData.lonValues(doc);
for (int i = 0; i < lats.length; i++) {
double distance = geoDistance.calculate(lat, lon, lats[i], lons[i], unit);
for (GeoDistanceFacet.Entry entry : entries) {
if (distance >= entry.getFrom() && distance < entry.getTo()) {
entry.count++;
entry.total += distance;
}
}
}
} else {
double distance = geoDistance.calculate(lat, lon, fieldData.latValue(doc), fieldData.lonValue(doc), unit);
@Override public void onValue(int docId, double lat, double lon) {
double distance = geoDistance.calculate(this.lat, this.lon, lat, lon, unit);
for (GeoDistanceFacet.Entry entry : entries) {
if (entry.foundInDoc) {
continue;
}
if (distance >= entry.getFrom() && distance < entry.getTo()) {
entry.foundInDoc = true;
entry.count++;
entry.total += distance;
}
}
}
}
@Override public Facet facet() {
return new InternalGeoDistanceFacet(facetName, entries);
}
}

View File

@ -22,7 +22,7 @@ package org.elasticsearch.search.facet.geodistance;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Scorer;
import org.elasticsearch.common.unit.DistanceUnit;
import org.elasticsearch.index.mapper.xcontent.geo.GeoPoint;
import org.elasticsearch.index.mapper.xcontent.geo.GeoPointFieldData;
import org.elasticsearch.index.search.geo.GeoDistance;
import org.elasticsearch.script.SearchScript;
import org.elasticsearch.search.internal.SearchContext;
@ -37,12 +37,16 @@ public class ScriptGeoDistanceFacetCollector extends GeoDistanceFacetCollector {
private final SearchScript script;
private Aggregator scriptAggregator;
public ScriptGeoDistanceFacetCollector(String facetName, String fieldName, double lat, double lon, DistanceUnit unit, GeoDistance geoDistance,
GeoDistanceFacet.Entry[] entries, SearchContext context,
String scriptLang, String script, Map<String, Object> params) {
super(facetName, fieldName, lat, lon, unit, geoDistance, entries, context);
this.script = context.scriptService().search(context.lookup(), scriptLang, script, params);
this.aggregator = new Aggregator(lat, lon, geoDistance, unit, entries);
this.scriptAggregator = (Aggregator) this.aggregator;
}
@Override public void setScorer(Scorer scorer) throws IOException {
@ -55,32 +59,43 @@ public class ScriptGeoDistanceFacetCollector extends GeoDistanceFacetCollector {
}
@Override protected void doCollect(int doc) throws IOException {
if (!fieldData.hasValue(doc)) {
return;
script.setNextDocId(doc);
this.scriptAggregator.scriptValue = script.runAsDouble();
super.doCollect(doc);
}
public static class Aggregator implements GeoPointFieldData.ValueInDocProc {
protected final double lat;
protected final double lon;
private final GeoDistance geoDistance;
private final DistanceUnit unit;
private final GeoDistanceFacet.Entry[] entries;
double scriptValue;
public Aggregator(double lat, double lon, GeoDistance geoDistance, DistanceUnit unit, GeoDistanceFacet.Entry[] entries) {
this.lat = lat;
this.lon = lon;
this.geoDistance = geoDistance;
this.unit = unit;
this.entries = entries;
}
script.setNextDocId(doc);
double value = script.runAsDouble();
if (fieldData.multiValued()) {
GeoPoint[] points = fieldData.values(doc);
for (GeoPoint point : points) {
double distance = geoDistance.calculate(lat, lon, point.lat(), point.lon(), unit);
for (GeoDistanceFacet.Entry entry : entries) {
if (distance >= entry.getFrom() && distance < entry.getTo()) {
entry.count++;
entry.total += value;
}
}
}
} else {
GeoPoint point = fieldData.value(doc);
double distance = geoDistance.calculate(lat, lon, point.lat(), point.lon(), unit);
@Override public void onValue(int docId, double lat, double lon) {
double distance = geoDistance.calculate(this.lat, this.lon, lat, lon, unit);
for (GeoDistanceFacet.Entry entry : entries) {
if (entry.foundInDoc) {
continue;
}
if (distance >= entry.getFrom() && distance < entry.getTo()) {
entry.foundInDoc = true;
entry.count++;
entry.total += value;
entry.total += scriptValue;
}
}
}

View File

@ -24,9 +24,8 @@ import org.elasticsearch.common.unit.DistanceUnit;
import org.elasticsearch.index.field.data.FieldDataType;
import org.elasticsearch.index.field.data.NumericFieldData;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.xcontent.geo.GeoPoint;
import org.elasticsearch.index.mapper.xcontent.geo.GeoPointFieldData;
import org.elasticsearch.index.search.geo.GeoDistance;
import org.elasticsearch.search.facet.Facet;
import org.elasticsearch.search.facet.FacetPhaseExecutionException;
import org.elasticsearch.search.internal.SearchContext;
@ -41,8 +40,6 @@ public class ValueGeoDistanceFacetCollector extends GeoDistanceFacetCollector {
private final FieldDataType valueFieldDataType;
private NumericFieldData valueFieldData;
public ValueGeoDistanceFacetCollector(String facetName, String fieldName, double lat, double lon, DistanceUnit unit, GeoDistance geoDistance,
GeoDistanceFacet.Entry[] entries, SearchContext context, String valueFieldName) {
super(facetName, fieldName, lat, lon, unit, geoDistance, entries, context);
@ -53,56 +50,55 @@ public class ValueGeoDistanceFacetCollector extends GeoDistanceFacetCollector {
}
this.indexValueFieldName = valueFieldName;
this.valueFieldDataType = mapper.fieldDataType();
this.aggregator = new Aggregator(lat, lon, geoDistance, unit, entries);
}
@Override protected void doSetNextReader(IndexReader reader, int docBase) throws IOException {
super.doSetNextReader(reader, docBase);
valueFieldData = (NumericFieldData) fieldDataCache.cache(valueFieldDataType, reader, indexValueFieldName);
((Aggregator) this.aggregator).valueFieldData = (NumericFieldData) fieldDataCache.cache(valueFieldDataType, reader, indexValueFieldName);
}
@Override protected void doCollect(int doc) throws IOException {
if (!fieldData.hasValue(doc)) {
return;
public static class Aggregator implements GeoPointFieldData.ValueInDocProc {
protected final double lat;
protected final double lon;
private final GeoDistance geoDistance;
private final DistanceUnit unit;
private final GeoDistanceFacet.Entry[] entries;
NumericFieldData valueFieldData;
public Aggregator(double lat, double lon, GeoDistance geoDistance, DistanceUnit unit, GeoDistanceFacet.Entry[] entries) {
this.lat = lat;
this.lon = lon;
this.geoDistance = geoDistance;
this.unit = unit;
this.entries = entries;
}
if (fieldData.multiValued()) {
GeoPoint[] points = fieldData.values(doc);
double[] values = valueFieldData.multiValued() ? valueFieldData.doubleValues(doc) : null;
for (int i = 0; i < points.length; i++) {
double distance = geoDistance.calculate(lat, lon, points[i].lat(), points[i].lon(), unit);
for (GeoDistanceFacet.Entry entry : entries) {
if (distance >= entry.getFrom() && distance < entry.getTo()) {
entry.count++;
if (values != null) {
if (i < values.length) {
entry.total += values[i];
}
} else if (valueFieldData.hasValue(doc)) {
entry.total += valueFieldData.doubleValue(doc);
}
}
}
}
} else {
GeoPoint point = fieldData.value(doc);
double distance = geoDistance.calculate(lat, lon, point.lat(), point.lon(), unit);
@Override public void onValue(int docId, double lat, double lon) {
double distance = geoDistance.calculate(this.lat, this.lon, lat, lon, unit);
for (GeoDistanceFacet.Entry entry : entries) {
if (entry.foundInDoc) {
continue;
}
if (distance >= entry.getFrom() && distance < entry.getTo()) {
entry.foundInDoc = true;
entry.count++;
if (valueFieldData.multiValued()) {
double[] values = valueFieldData.doubleValues(doc);
double[] values = valueFieldData.doubleValues(docId);
for (double value : values) {
entry.total += value;
}
} else if (valueFieldData.hasValue(doc)) {
entry.total += valueFieldData.doubleValue(doc);
} else if (valueFieldData.hasValue(docId)) {
entry.total += valueFieldData.doubleValue(docId);
}
}
}
}
}
@Override public Facet facet() {
return new InternalGeoDistanceFacet(facetName, entries);
}
}

View File

@ -51,6 +51,8 @@ public class KeyValueRangeFacetCollector extends AbstractFacetCollector {
private final RangeFacet.Entry[] entries;
private final RangeProc rangeProc;
public KeyValueRangeFacetCollector(String facetName, String keyFieldName, String valueFieldName, RangeFacet.Entry[] entries, SearchContext context) {
super(facetName);
this.entries = entries;
@ -75,66 +77,61 @@ public class KeyValueRangeFacetCollector extends AbstractFacetCollector {
}
valueIndexFieldName = mapper.names().indexName();
valueFieldDataType = mapper.fieldDataType();
this.rangeProc = new RangeProc(entries);
}
@Override protected void doSetNextReader(IndexReader reader, int docBase) throws IOException {
keyFieldData = (NumericFieldData) fieldDataCache.cache(keyFieldDataType, reader, keyIndexFieldName);
valueFieldData = (NumericFieldData) fieldDataCache.cache(valueFieldDataType, reader, valueIndexFieldName);
rangeProc.valueFieldData = (NumericFieldData) fieldDataCache.cache(valueFieldDataType, reader, valueIndexFieldName);
}
@Override protected void doCollect(int doc) throws IOException {
if (keyFieldData.multiValued()) {
if (valueFieldData.multiValued()) {
// both multi valued, intersect based on the minimum size
double[] keys = keyFieldData.doubleValues(doc);
double[] values = valueFieldData.doubleValues(doc);
int size = Math.min(keys.length, values.length);
for (int i = 0; i < size; i++) {
double key = keys[i];
for (RangeFacet.Entry entry : entries) {
if (key >= entry.getFrom() && key < entry.getTo()) {
entry.count++;
entry.total += values[i];
}
}
}
} else {
// key multi valued, value is a single value
double value = valueFieldData.doubleValue(doc);
for (double key : keyFieldData.doubleValues(doc)) {
for (RangeFacet.Entry entry : entries) {
if (key >= entry.getFrom() && key < entry.getTo()) {
entry.count++;
entry.total += value;
}
}
}
}
} else {
double key = keyFieldData.doubleValue(doc);
if (valueFieldData.multiValued()) {
for (RangeFacet.Entry entry : entries) {
if (key >= entry.getFrom() && key < entry.getTo()) {
entry.count++;
for (double value : valueFieldData.doubleValues(doc)) {
entry.total += value;
}
}
}
} else {
// both key and value are not multi valued
double value = valueFieldData.doubleValue(doc);
for (RangeFacet.Entry entry : entries) {
if (key >= entry.getFrom() && key < entry.getTo()) {
entry.count++;
entry.total += value;
}
}
}
for (RangeFacet.Entry entry : entries) {
entry.foundInDoc = false;
}
keyFieldData.forEachValueInDoc(doc, rangeProc);
}
@Override public Facet facet() {
return new InternalRangeFacet(facetName, entries);
}
public static class RangeProc implements NumericFieldData.DoubleValueInDocProc {
private final RangeFacet.Entry[] entries;
private int missing;
NumericFieldData valueFieldData;
public RangeProc(RangeFacet.Entry[] entries) {
this.entries = entries;
}
@Override public void onValue(int docId, double value) {
for (RangeFacet.Entry entry : entries) {
if (entry.foundInDoc) {
continue;
}
if (value >= entry.getFrom() && value < entry.getTo()) {
entry.foundInDoc = true;
entry.count++;
if (valueFieldData.multiValued()) {
double[] valuesValues = valueFieldData.doubleValues(docId);
for (double valueValue : valuesValues) {
entry.total += valueValue;
}
} else {
double valueValue = valueFieldData.doubleValue(docId);
entry.total += valueValue;
}
}
}
}
@Override public void onMissing(int docId) {
missing++;
}
}
}

View File

@ -57,6 +57,11 @@ public interface RangeFacet extends Facet, Iterable<RangeFacet.Entry> {
double total;
/**
* Internal field used in facet collection
*/
boolean foundInDoc;
Entry() {
}

View File

@ -74,6 +74,9 @@ public class RangeFacetCollector extends AbstractFacetCollector {
}
@Override protected void doCollect(int doc) throws IOException {
for (RangeFacet.Entry entry : entries) {
entry.foundInDoc = false;
}
fieldData.forEachValueInDoc(doc, rangeProc);
}
@ -93,7 +96,11 @@ public class RangeFacetCollector extends AbstractFacetCollector {
@Override public void onValue(int docId, double value) {
for (RangeFacet.Entry entry : entries) {
if (entry.foundInDoc) {
continue;
}
if (value >= entry.getFrom() && value < entry.getTo()) {
entry.foundInDoc = true;
entry.count++;
entry.total += value;
}

View File

@ -1147,8 +1147,8 @@ public class SimpleFacetsTests extends AbstractNodesTests {
assertThat(facet.entries().get(0).total(), closeTo(3, 0.000001));
assertThat(facet.entries().get(1).from(), closeTo(10, 0.000001));
assertThat(facet.entries().get(1).to(), closeTo(26, 0.000001));
assertThat(facet.entries().get(1).count(), equalTo(5l));
assertThat(facet.entries().get(1).total(), closeTo(1 * 2 + 2 + 3 * 2, 0.000001));
assertThat(facet.entries().get(1).count(), equalTo(3l));
assertThat(facet.entries().get(1).total(), closeTo(1 + 2 + 3, 0.000001));
assertThat(facet.entries().get(2).from(), closeTo(20, 0.000001));
assertThat(facet.entries().get(2).count(), equalTo(3l));
assertThat(facet.entries().get(2).total(), closeTo(1 + 2 + 3, 0.000001));

View File

@ -29,6 +29,8 @@ import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import java.util.Arrays;
import static org.elasticsearch.common.xcontent.XContentFactory.*;
import static org.elasticsearch.index.query.xcontent.QueryBuilders.*;
import static org.elasticsearch.search.facet.FacetBuilders.*;
@ -217,4 +219,62 @@ public class GeoDistanceFacetTests extends AbstractNodesTests {
assertThat(facet.entries().get(3).count(), equalTo(5l));
assertThat(facet.entries().get(3).total(), closeTo(24, 0.00001));
}
@Test public void multiLocationGeoDistanceTest() throws Exception {
try {
client.admin().indices().prepareDelete("test").execute().actionGet();
} catch (Exception e) {
// ignore
}
String mapping = XContentFactory.jsonBuilder().startObject().startObject("type1")
.startObject("properties").startObject("location").field("type", "geo_point").field("lat_lon", true).endObject().endObject()
.endObject().endObject().string();
client.admin().indices().prepareCreate("test").addMapping("type1", mapping).execute().actionGet();
client.admin().cluster().prepareHealth().setWaitForGreenStatus().execute().actionGet();
client.prepareIndex("test", "type1", "1").setSource(jsonBuilder().startObject()
.field("num", 1)
.startArray("location")
// to NY: 0
.startObject().field("lat", 40.7143528).field("lon", -74.0059731).endObject()
// to NY: 5.286 km
.startObject().field("lat", 40.759011).field("lon", -73.9844722).endObject()
.endArray()
.endObject()).execute().actionGet();
client.prepareIndex("test", "type1", "3").setSource(jsonBuilder().startObject()
.field("num", 3)
.startArray("location")
// to NY: 0.4621 km
.startObject().field("lat", 40.718266).field("lon", -74.007819).endObject()
// to NY: 1.055 km
.startObject().field("lat", 40.7051157).field("lon", -74.0088305).endObject()
.endArray()
.endObject()).execute().actionGet();
client.admin().indices().prepareRefresh().execute().actionGet();
SearchResponse searchResponse = client.prepareSearch() // from NY
.setQuery(matchAllQuery())
.addFacet(geoDistanceFacet("geo1").field("location").point(40.7143528, -74.0059731).unit(DistanceUnit.KILOMETERS)
.addRange(0, 2)
.addRange(2, 10)
)
.execute().actionGet();
assertThat(Arrays.toString(searchResponse.shardFailures()), searchResponse.failedShards(), equalTo(0));
assertThat(searchResponse.hits().totalHits(), equalTo(2l));
GeoDistanceFacet facet = searchResponse.facets().facet("geo1");
assertThat(facet.entries().size(), equalTo(2));
assertThat(facet.entries().get(0).from(), closeTo(0, 0.000001));
assertThat(facet.entries().get(0).to(), closeTo(2, 0.000001));
assertThat(facet.entries().get(0).count(), equalTo(2l));
assertThat(facet.entries().get(1).from(), closeTo(2, 0.000001));
assertThat(facet.entries().get(1).to(), closeTo(10, 0.000001));
assertThat(facet.entries().get(1).count(), equalTo(1l));
}
}