Added sort mode to geo distance sorting. Closes #1846

This commit is contained in:
Martijn van Groningen 2013-03-28 15:17:30 +01:00
parent 9bc50ea609
commit 941aa17a43
5 changed files with 246 additions and 48 deletions

View File

@ -40,13 +40,14 @@ public class GeoDistanceComparator extends FieldComparator<Double> {
protected final DistanceUnit unit;
protected final GeoDistance geoDistance;
protected final GeoDistance.FixedSourceDistance fixedSourceDistance;
protected final SortMode sortMode;
private final double[] values;
private double bottom;
private GeoPointValues readerValues;
private GeoDistanceValues geoDistanceValues;
public GeoDistanceComparator(int numHits, IndexGeoPointFieldData<?> indexFieldData, double lat, double lon, DistanceUnit unit, GeoDistance geoDistance) {
public GeoDistanceComparator(int numHits, IndexGeoPointFieldData<?> indexFieldData, double lat, double lon, DistanceUnit unit, GeoDistance geoDistance, SortMode sortMode) {
this.values = new double[numHits];
this.indexFieldData = indexFieldData;
this.lat = lat;
@ -54,11 +55,17 @@ public class GeoDistanceComparator extends FieldComparator<Double> {
this.unit = unit;
this.geoDistance = geoDistance;
this.fixedSourceDistance = geoDistance.fixedSourceDistance(lat, lon, unit);
this.sortMode = sortMode;
}
@Override
public FieldComparator<Double> setNextReader(AtomicReaderContext context) throws IOException {
this.readerValues = indexFieldData.load(context).getGeoPointValues();
GeoPointValues readerValues = indexFieldData.load(context).getGeoPointValues();
if (readerValues.isMultiValued()) {
geoDistanceValues = new MV(readerValues, fixedSourceDistance, sortMode);
} else {
geoDistanceValues = new SV(readerValues, fixedSourceDistance);
}
return this;
}
@ -77,15 +84,7 @@ public class GeoDistanceComparator extends FieldComparator<Double> {
@Override
public int compareBottom(int doc) {
double distance;
GeoPoint geoPoint = readerValues.getValue(doc);
if (geoPoint == null) {
// is this true? push this to the "end"
distance = Double.MAX_VALUE;
} else {
distance = fixedSourceDistance.calculate(geoPoint.lat(), geoPoint.lon());
}
final double v2 = distance;
final double v2 = geoDistanceValues.computeDistance(doc);
if (bottom > v2) {
return 1;
} else if (bottom < v2) {
@ -97,14 +96,7 @@ public class GeoDistanceComparator extends FieldComparator<Double> {
@Override
public int compareDocToValue(int doc, Double distance2) throws IOException {
double distance1;
GeoPoint geoPoint = readerValues.getValue(doc);
if (geoPoint == null) {
// is this true? push this to the "end"
distance1 = Double.MAX_VALUE;
} else {
distance1 = fixedSourceDistance.calculate(geoPoint.lat(), geoPoint.lon());
}
double distance1 = geoDistanceValues.computeDistance(doc);
if (distance1 < distance2) {
return -1;
} else if (distance1 == distance2) {
@ -116,15 +108,7 @@ public class GeoDistanceComparator extends FieldComparator<Double> {
@Override
public void copy(int slot, int doc) {
double distance;
GeoPoint geoPoint = readerValues.getValue(doc);
if (geoPoint == null) {
// is this true? push this to the "end"
distance = Double.MAX_VALUE;
} else {
distance = fixedSourceDistance.calculate(geoPoint.lat(), geoPoint.lon());
}
values[slot] = distance;
values[slot] = geoDistanceValues.computeDistance(doc);
}
@Override
@ -136,4 +120,81 @@ public class GeoDistanceComparator extends FieldComparator<Double> {
public Double value(int slot) {
return values[slot];
}
// Computes the distance based on geo points.
// Due to this abstractions the geo distance comparator doesn't need to deal with whether fields have one
// or multiple geo points per document.
private static abstract class GeoDistanceValues {
protected final GeoPointValues readerValues;
protected final GeoDistance.FixedSourceDistance fixedSourceDistance;
protected GeoDistanceValues(GeoPointValues readerValues, GeoDistance.FixedSourceDistance fixedSourceDistance) {
this.readerValues = readerValues;
this.fixedSourceDistance = fixedSourceDistance;
}
public abstract double computeDistance(int doc);
}
// Deals with one geo point per document
private static final class SV extends GeoDistanceValues {
SV(GeoPointValues readerValues, GeoDistance.FixedSourceDistance fixedSourceDistance) {
super(readerValues, fixedSourceDistance);
}
@Override
public double computeDistance(int doc) {
GeoPoint geoPoint = readerValues.getValue(doc);
if (geoPoint == null) {
// is this true? push this to the "end"
return Double.MAX_VALUE;
} else {
return fixedSourceDistance.calculate(geoPoint.lat(), geoPoint.lon());
}
}
}
// Deals with more than one geo point per document
private static final class MV extends GeoDistanceValues {
private final SortMode sortMode;
MV(GeoPointValues readerValues, GeoDistance.FixedSourceDistance fixedSourceDistance, SortMode sortMode) {
super(readerValues, fixedSourceDistance);
this.sortMode = sortMode;
}
@Override
public double computeDistance(int doc) {
GeoPointValues.Iter iter = readerValues.getIter(doc);
if (!iter.hasNext()) {
return Double.MAX_VALUE;
}
GeoPoint point = iter.next();
double distance = fixedSourceDistance.calculate(point.lat(), point.lon());
while (iter.hasNext()) {
point = iter.next();
double newDistance = fixedSourceDistance.calculate(point.lat(), point.lon());
switch (sortMode) {
case MIN:
if (distance > newDistance) {
distance = newDistance;
}
break;
case MAX:
if (distance < newDistance) {
distance = newDistance;
}
break;
}
}
return distance;
}
}
}

View File

@ -37,13 +37,15 @@ public class GeoDistanceComparatorSource extends IndexFieldData.XFieldComparator
private final double lon;
private final DistanceUnit unit;
private final GeoDistance geoDistance;
private final SortMode sortMode;
public GeoDistanceComparatorSource(IndexGeoPointFieldData<?> indexFieldData, double lat, double lon, DistanceUnit unit, GeoDistance geoDistance) {
public GeoDistanceComparatorSource(IndexGeoPointFieldData<?> indexFieldData, double lat, double lon, DistanceUnit unit, GeoDistance geoDistance, SortMode sortMode) {
this.indexFieldData = indexFieldData;
this.lat = lat;
this.lon = lon;
this.unit = unit;
this.geoDistance = geoDistance;
this.sortMode = sortMode;
}
@Override
@ -54,7 +56,6 @@ public class GeoDistanceComparatorSource extends IndexFieldData.XFieldComparator
@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, int sortPos, boolean reversed) throws IOException {
assert indexFieldData.getFieldNames().indexName().equals(fieldname);
// TODO support multi value?
return new GeoDistanceComparator(numHits, indexFieldData, lat, lon, unit, geoDistance);
return new GeoDistanceComparator(numHits, indexFieldData, lat, lon, unit, geoDistance, sortMode);
}
}

View File

@ -39,6 +39,7 @@ public class GeoDistanceSortBuilder extends SortBuilder {
private GeoDistance geoDistance;
private DistanceUnit unit;
private SortOrder order;
private String sortMode;
/**
* Constructs a new distance based sort on a geo point like field.
@ -102,6 +103,15 @@ public class GeoDistanceSortBuilder extends SortBuilder {
return this;
}
/**
* Defines which distance to use for sorting in the case a document contains multiple geo points.
* Possible values: min and max
*/
public SortBuilder sortMode(String sortMode) {
this.sortMode = sortMode;
return this;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject("_geo_distance");
@ -121,6 +131,9 @@ public class GeoDistanceSortBuilder extends SortBuilder {
if (order == SortOrder.DESC) {
builder.field("reverse", true);
}
if (sortMode != null) {
builder.field("mode", sortMode);
}
builder.endObject();
return builder;

View File

@ -29,6 +29,7 @@ import org.elasticsearch.common.unit.DistanceUnit;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.fielddata.IndexGeoPointFieldData;
import org.elasticsearch.index.fielddata.fieldcomparator.GeoDistanceComparatorSource;
import org.elasticsearch.index.fielddata.fieldcomparator.SortMode;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.geo.GeoPointFieldMapper;
import org.elasticsearch.search.internal.SearchContext;
@ -50,6 +51,7 @@ public class GeoDistanceSortParser implements SortParser {
DistanceUnit unit = DistanceUnit.KILOMETERS;
GeoDistance geoDistance = GeoDistance.ARC;
boolean reverse = false;
SortMode sortMode = null;
boolean normalizeLon = true;
boolean normalizeLat = true;
@ -96,6 +98,8 @@ public class GeoDistanceSortParser implements SortParser {
} else if ("normalize".equals(currentName)) {
normalizeLat = parser.booleanValue();
normalizeLon = parser.booleanValue();
} else if ("mode".equals(currentName)) {
sortMode = SortMode.fromString(parser.text());
} else {
point.resetFromString(parser.text());
fieldName = currentName;
@ -107,12 +111,16 @@ public class GeoDistanceSortParser implements SortParser {
GeoUtils.normalizePoint(point, normalizeLat, normalizeLon);
}
if (sortMode == null) {
sortMode = reverse ? SortMode.MAX : SortMode.MIN;
}
FieldMapper mapper = context.smartNameFieldMapper(fieldName);
if (mapper == null) {
throw new ElasticSearchIllegalArgumentException("failed to find mapper for [" + fieldName + "] for geo distance based sort");
}
IndexGeoPointFieldData indexFieldData = context.fieldData().getForField(mapper);
return new SortField(fieldName, new GeoDistanceComparatorSource(indexFieldData, point.lat(), point.lon(), unit, geoDistance), reverse);
return new SortField(fieldName, new GeoDistanceComparatorSource(indexFieldData, point.lat(), point.lon(), unit, geoDistance, sortMode), reverse);
}
}

View File

@ -29,19 +29,18 @@ import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.integration.AbstractNodesTests;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import static org.elasticsearch.common.settings.ImmutableSettings.settingsBuilder;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.index.query.FilterBuilders.geoDistanceFilter;
import static org.elasticsearch.index.query.FilterBuilders.geoDistanceRangeFilter;
import static org.elasticsearch.index.query.QueryBuilders.filteredQuery;
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.*;
/**
*/
@ -244,7 +243,123 @@ public class GeoDistanceTests extends AbstractNodesTests {
assertThat(searchResponse.getHits().getAt(1).id(), equalTo("2"));
assertThat(searchResponse.getHits().getAt(0).id(), equalTo("7"));
}
@Test
public void testDistanceSortingMVFields() throws Exception {
client.admin().indices().prepareDelete().execute().actionGet();
String mapping = XContentFactory.jsonBuilder().startObject().startObject("type1")
.startObject("properties").startObject("locations").field("type", "geo_point").field("lat_lon", true).endObject().endObject()
.endObject().endObject().string();
client.admin().indices().prepareCreate("test")
.setSettings(settingsBuilder().put("index.number_of_shards", 1).put("index.number_of_replicas", 0))
.addMapping("type1", mapping)
.execute().actionGet();
client.admin().cluster().prepareHealth("test").setWaitForEvents(Priority.LANGUID).setWaitForGreenStatus().execute().actionGet();
client.prepareIndex("test", "type1", "1").setSource(jsonBuilder().startObject()
.field("names", "New York")
.startObject("locations").field("lat", 40.7143528).field("lon", -74.0059731).endObject()
.endObject()).execute().actionGet();
client.prepareIndex("test", "type1", "2").setSource(jsonBuilder().startObject()
.field("names", "Times Square", "Tribeca")
.startArray("locations")
// to NY: 5.286 km
.startObject().field("lat", 40.759011).field("lon", -73.9844722).endObject()
// to NY: 0.4621 km
.startObject().field("lat", 40.718266).field("lon", -74.007819).endObject()
.endArray()
.endObject()).execute().actionGet();
client.prepareIndex("test", "type1", "3").setSource(jsonBuilder().startObject()
.field("names", "Wall Street", "Soho")
.startArray("locations")
// to NY: 1.055 km
.startObject().field("lat", 40.7051157).field("lon", -74.0088305).endObject()
// to NY: 1.258 km
.startObject().field("lat", 40.7247222).field("lon", -74).endObject()
.endArray()
.endObject()).execute().actionGet();
client.prepareIndex("test", "type1", "4").setSource(jsonBuilder().startObject()
.field("names", "Greenwich Village", "Brooklyn")
.startArray("locations")
// to NY: 2.029 km
.startObject().field("lat", 40.731033).field("lon", -73.9962255).endObject()
// to NY: 8.572 km
.startObject().field("lat", 40.65).field("lon", -73.95).endObject()
.endArray()
.endObject()).execute().actionGet();
client.admin().indices().prepareRefresh().execute().actionGet();
// Order: Asc
SearchResponse searchResponse = client.prepareSearch("test").setQuery(matchAllQuery())
.addSort(SortBuilders.geoDistanceSort("locations").point(40.7143528, -74.0059731).order(SortOrder.ASC))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(4l));
assertThat(searchResponse.getHits().hits().length, equalTo(4));
assertThat(searchResponse.getHits().getAt(0).id(), equalTo("1"));
assertThat(((Number) searchResponse.getHits().getAt(0).sortValues()[0]).doubleValue(), equalTo(0d));
assertThat(searchResponse.getHits().getAt(1).id(), equalTo("2"));
assertThat(((Number) searchResponse.getHits().getAt(1).sortValues()[0]).doubleValue(), closeTo(0.4621d, 0.01d));
assertThat(searchResponse.getHits().getAt(2).id(), equalTo("3"));
assertThat(((Number) searchResponse.getHits().getAt(2).sortValues()[0]).doubleValue(), closeTo(1.055d, 0.01d));
assertThat(searchResponse.getHits().getAt(3).id(), equalTo("4"));
assertThat(((Number) searchResponse.getHits().getAt(3).sortValues()[0]).doubleValue(), closeTo(2.029d, 0.01d));
// Order: Asc, Mode: max
searchResponse = client.prepareSearch("test").setQuery(matchAllQuery())
.addSort(SortBuilders.geoDistanceSort("locations").point(40.7143528, -74.0059731).order(SortOrder.ASC).sortMode("max"))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(4l));
assertThat(searchResponse.getHits().hits().length, equalTo(4));
assertThat(searchResponse.getHits().getAt(0).id(), equalTo("1"));
assertThat(((Number) searchResponse.getHits().getAt(0).sortValues()[0]).doubleValue(), equalTo(0d));
assertThat(searchResponse.getHits().getAt(1).id(), equalTo("3"));
assertThat(((Number) searchResponse.getHits().getAt(1).sortValues()[0]).doubleValue(), closeTo(1.258d, 0.01d));
assertThat(searchResponse.getHits().getAt(2).id(), equalTo("2"));
assertThat(((Number) searchResponse.getHits().getAt(2).sortValues()[0]).doubleValue(), closeTo(5.286d, 0.01d));
assertThat(searchResponse.getHits().getAt(3).id(), equalTo("4"));
assertThat(((Number) searchResponse.getHits().getAt(3).sortValues()[0]).doubleValue(), closeTo(8.572d, 0.01d));
// Order: Desc
searchResponse = client.prepareSearch("test").setQuery(matchAllQuery())
.addSort(SortBuilders.geoDistanceSort("locations").point(40.7143528, -74.0059731).order(SortOrder.DESC))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(4l));
assertThat(searchResponse.getHits().hits().length, equalTo(4));
assertThat(searchResponse.getHits().getAt(0).id(), equalTo("4"));
assertThat(((Number) searchResponse.getHits().getAt(0).sortValues()[0]).doubleValue(), closeTo(8.572d, 0.01d));
assertThat(searchResponse.getHits().getAt(1).id(), equalTo("2"));
assertThat(((Number) searchResponse.getHits().getAt(1).sortValues()[0]).doubleValue(), closeTo(5.286d, 0.01d));
assertThat(searchResponse.getHits().getAt(2).id(), equalTo("3"));
assertThat(((Number) searchResponse.getHits().getAt(2).sortValues()[0]).doubleValue(), closeTo(1.258d, 0.01d));
assertThat(searchResponse.getHits().getAt(3).id(), equalTo("1"));
assertThat(((Number) searchResponse.getHits().getAt(3).sortValues()[0]).doubleValue(), equalTo(0d));
// Order: Desc, Mode: min
searchResponse = client.prepareSearch("test").setQuery(matchAllQuery())
.addSort(SortBuilders.geoDistanceSort("locations").point(40.7143528, -74.0059731).order(SortOrder.DESC).sortMode("min"))
.execute().actionGet();
assertThat(searchResponse.getHits().getTotalHits(), equalTo(4l));
assertThat(searchResponse.getHits().hits().length, equalTo(4));
assertThat(searchResponse.getHits().getAt(0).id(), equalTo("4"));
assertThat(((Number) searchResponse.getHits().getAt(0).sortValues()[0]).doubleValue(), closeTo(2.029d, 0.01d));
assertThat(searchResponse.getHits().getAt(1).id(), equalTo("3"));
assertThat(((Number) searchResponse.getHits().getAt(1).sortValues()[0]).doubleValue(), closeTo(1.055d, 0.01d));
assertThat(searchResponse.getHits().getAt(2).id(), equalTo("2"));
assertThat(((Number) searchResponse.getHits().getAt(2).sortValues()[0]).doubleValue(), closeTo(0.4621d, 0.01d));
assertThat(searchResponse.getHits().getAt(3).id(), equalTo("1"));
assertThat(((Number) searchResponse.getHits().getAt(3).sortValues()[0]).doubleValue(), equalTo(0d));
}
@Test
public void distanceScriptTests() throws Exception {
try {
@ -252,12 +367,12 @@ public class GeoDistanceTests extends AbstractNodesTests {
} catch (Exception e) {
// ignore
}
double source_lat = 32.798;
double source_long = -117.151;
double target_lat = 32.81;
double target_long = -117.21;
String mapping = XContentFactory.jsonBuilder().startObject().startObject("type1")
.startObject("properties").startObject("location").field("type", "geo_point").field("lat_lon", true).endObject().endObject()
.endObject().endObject().string();
@ -271,27 +386,27 @@ public class GeoDistanceTests extends AbstractNodesTests {
client.admin().indices().prepareRefresh().execute().actionGet();
SearchResponse searchResponse1 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].arcDistance("+target_lat+","+target_long+")").execute().actionGet();
SearchResponse searchResponse1 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].arcDistance(" + target_lat + "," + target_long + ")").execute().actionGet();
Double resultDistance1 = searchResponse1.getHits().getHits()[0].getFields().get("distance").getValue();
assertThat(resultDistance1, equalTo(GeoDistance.ARC.calculate(source_lat, source_long, target_lat, target_long, DistanceUnit.MILES)));
SearchResponse searchResponse2 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].distance("+target_lat+","+target_long+")").execute().actionGet();
SearchResponse searchResponse2 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].distance(" + target_lat + "," + target_long + ")").execute().actionGet();
Double resultDistance2 = searchResponse2.getHits().getHits()[0].getFields().get("distance").getValue();
assertThat(resultDistance2, equalTo(GeoDistance.PLANE.calculate(source_lat, source_long, target_lat, target_long, DistanceUnit.MILES)));
SearchResponse searchResponse3 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].arcDistanceInKm("+target_lat+","+target_long+")").execute().actionGet();
SearchResponse searchResponse3 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].arcDistanceInKm(" + target_lat + "," + target_long + ")").execute().actionGet();
Double resultArcDistance3 = searchResponse3.getHits().getHits()[0].getFields().get("distance").getValue();
assertThat(resultArcDistance3, equalTo(GeoDistance.ARC.calculate(source_lat, source_long, target_lat, target_long, DistanceUnit.KILOMETERS)));
SearchResponse searchResponse4 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].distanceInKm("+target_lat+","+target_long+")").execute().actionGet();
SearchResponse searchResponse4 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].distanceInKm(" + target_lat + "," + target_long + ")").execute().actionGet();
Double resultDistance4 = searchResponse4.getHits().getHits()[0].getFields().get("distance").getValue();
assertThat(resultDistance4, equalTo(GeoDistance.PLANE.calculate(source_lat, source_long, target_lat, target_long, DistanceUnit.KILOMETERS)));
SearchResponse searchResponse5 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].arcDistanceInKm("+(target_lat)+","+(target_long+360)+")").execute().actionGet();
SearchResponse searchResponse5 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].arcDistanceInKm(" + (target_lat) + "," + (target_long + 360) + ")").execute().actionGet();
Double resultArcDistance5 = searchResponse5.getHits().getHits()[0].getFields().get("distance").getValue();
assertThat(resultArcDistance5, equalTo(GeoDistance.ARC.calculate(source_lat, source_long, target_lat, target_long, DistanceUnit.KILOMETERS)));
SearchResponse searchResponse6 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].arcDistanceInKm("+(target_lat+360)+","+(target_long)+")").execute().actionGet();
SearchResponse searchResponse6 = client.prepareSearch().addField("_source").addScriptField("distance", "doc['location'].arcDistanceInKm(" + (target_lat + 360) + "," + (target_long) + ")").execute().actionGet();
Double resultArcDistance6 = searchResponse6.getHits().getHits()[0].getFields().get("distance").getValue();
assertThat(resultArcDistance6, equalTo(GeoDistance.ARC.calculate(source_lat, source_long, target_lat, target_long, DistanceUnit.KILOMETERS)));
}