LUCENE-8952: Use a sort key instead of true distance in NearestNeighbor. (#832)

This commit is contained in:
Julie Tibshirani 2019-08-23 02:16:27 -07:00 committed by Ignacio Vera
parent 1cbc5eaf51
commit 152756fcbd
4 changed files with 52 additions and 41 deletions

View File

@ -102,6 +102,8 @@ Improvements
* SOLR-13663: Introduce <SpanPositionRange> into XML Query Parser (Alessandro Benedetti via Mikhail Khludnev)
* LUCENE-8952: Use a sort key instead of true distance in NearestNeighbor (Julie Tibshirani).
Optimizations
* LUCENE-8922: DisjunctionMaxQuery more efficiently leverages impacts to skip

View File

@ -27,6 +27,7 @@ import org.apache.lucene.geo.GeoUtils;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SloppyMath;
import org.apache.lucene.util.bkd.BKDReader;
/**
@ -104,7 +105,8 @@ public class LatLonPointPrototypeQueries {
ScoreDoc[] scoreDocs = new ScoreDoc[hits.length];
for(int i=0;i<hits.length;i++) {
NearestNeighbor.NearestHit hit = hits[i];
scoreDocs[i] = new FieldDoc(hit.docID, 0.0f, new Object[] {Double.valueOf(hit.distanceMeters)});
double hitDistance = SloppyMath.haversinMeters(hit.distanceSortKey);
scoreDocs[i] = new FieldDoc(hit.docID, 0.0f, new Object[] {Double.valueOf(hitDistance)});
}
return new TopFieldDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs, null);
}

View File

@ -28,9 +28,9 @@ import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SloppyMath;
import org.apache.lucene.util.bkd.BKDReader;
import org.apache.lucene.util.bkd.BKDReader.IndexTree;
import org.apache.lucene.util.bkd.BKDReader.IntersectState;
import org.apache.lucene.util.bkd.BKDReader;
import static org.apache.lucene.geo.GeoEncodingUtils.decodeLatitude;
import static org.apache.lucene.geo.GeoEncodingUtils.decodeLongitude;
@ -48,19 +48,23 @@ class NearestNeighbor {
final byte[] maxPacked;
final IndexTree index;
/** The closest possible distance of all points in this cell */
final double distanceMeters;
/**
* The closest distance from a point in this cell to the query point, computed as a sort key through
* {@link SloppyMath#haversinSortKey}. Note that this is an approximation to the closest distance,
* and there could be a point in the cell that is closer.
*/
final double distanceSortKey;
public Cell(IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceMeters) {
public Cell(IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceSortKey) {
this.index = index;
this.readerIndex = readerIndex;
this.minPacked = minPacked.clone();
this.maxPacked = maxPacked.clone();
this.distanceMeters = distanceMeters;
this.distanceSortKey = distanceSortKey;
}
public int compareTo(Cell other) {
return Double.compare(distanceMeters, other.distanceMeters);
return Double.compare(distanceSortKey, other.distanceSortKey);
}
@Override
@ -69,7 +73,7 @@ class NearestNeighbor {
double minLon = decodeLongitude(minPacked, Integer.BYTES);
double maxLat = decodeLatitude(maxPacked, 0);
double maxLon = decodeLongitude(maxPacked, Integer.BYTES);
return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID() + " isLeaf=" + index.isLeafNode() + " lat=" + minLat + " TO " + maxLat + ", lon=" + minLon + " TO " + maxLon + "; distanceMeters=" + distanceMeters + ")";
return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID() + " isLeaf=" + index.isLeafNode() + " lat=" + minLat + " TO " + maxLat + ", lon=" + minLon + " TO " + maxLon + "; distanceSortKey=" + distanceSortKey + ")";
}
}
@ -106,7 +110,8 @@ class NearestNeighbor {
private void maybeUpdateBBox() {
if (setBottomCounter < 1024 || (setBottomCounter & 0x3F) == 0x3F) {
NearestHit hit = hitQueue.peek();
Rectangle box = Rectangle.fromPointDistance(pointLat, pointLon, hit.distanceMeters);
Rectangle box = Rectangle.fromPointDistance(pointLat, pointLon,
SloppyMath.haversinMeters(hit.distanceSortKey));
//System.out.println(" update bbox to " + box);
minLat = box.minLat;
maxLat = box.maxLat;
@ -134,8 +139,6 @@ class NearestNeighbor {
return;
}
// TODO: work in int space, use haversinSortKey
double docLatitude = decodeLatitude(packedValue, 0);
double docLongitude = decodeLongitude(packedValue, Integer.BYTES);
@ -147,21 +150,22 @@ class NearestNeighbor {
return;
}
double distanceMeters = SloppyMath.haversinMeters(pointLat, pointLon, docLatitude, docLongitude);
// Use the haversin sort key when comparing hits, as it is faster to compute than the true distance.
double distanceSortKey = SloppyMath.haversinSortKey(pointLat, pointLon, docLatitude, docLongitude);
//System.out.println(" visit docID=" + docID + " distanceMeters=" + distanceMeters + " docLat=" + docLatitude + " docLon=" + docLongitude);
//System.out.println(" visit docID=" + docID + " distanceSortKey=" + distanceSortKey + " docLat=" + docLatitude + " docLon=" + docLongitude);
int fullDocID = curDocBase + docID;
if (hitQueue.size() == topN) {
// queue already full
NearestHit hit = hitQueue.peek();
//System.out.println(" bottom distanceMeters=" + hit.distanceMeters);
//System.out.println(" bottom distanceSortKey=" + hit.distanceSortKey);
// we don't collect docs in order here, so we must also test the tie-break case ourselves:
if (distanceMeters < hit.distanceMeters || (distanceMeters == hit.distanceMeters && fullDocID < hit.docID)) {
if (distanceSortKey < hit.distanceSortKey || (distanceSortKey == hit.distanceSortKey && fullDocID < hit.docID)) {
hitQueue.poll();
hit.docID = fullDocID;
hit.distanceMeters = distanceMeters;
hit.distanceSortKey = distanceSortKey;
hitQueue.offer(hit);
//System.out.println(" ** keep2, now bottom=" + hit);
maybeUpdateBBox();
@ -170,7 +174,7 @@ class NearestNeighbor {
} else {
NearestHit hit = new NearestHit();
hit.docID = fullDocID;
hit.distanceMeters = distanceMeters;
hit.distanceSortKey = distanceSortKey;
hitQueue.offer(hit);
//System.out.println(" ** keep1, now bottom=" + hit);
}
@ -182,14 +186,18 @@ class NearestNeighbor {
}
}
/** Holds one hit from {@link LatLonPointPrototypeQueries#nearest} */
/** Holds one hit from {@link NearestNeighbor#nearest} */
static class NearestHit {
public int docID;
public double distanceMeters;
/**
* The distance from the hit to the query point, computed as a sort key through {@link SloppyMath#haversinSortKey}.
*/
public double distanceSortKey;
@Override
public String toString() {
return "NearestHit(docID=" + docID + " distanceMeters=" + distanceMeters + ")";
return "NearestHit(docID=" + docID + " distanceSortKey=" + distanceSortKey + ")";
}
}
@ -204,8 +212,8 @@ class NearestNeighbor {
final PriorityQueue<NearestHit> hitQueue = new PriorityQueue<>(n, new Comparator<NearestHit>() {
@Override
public int compare(NearestHit a, NearestHit b) {
// sort by opposite distanceMeters natural order
int cmp = Double.compare(a.distanceMeters, b.distanceMeters);
// sort by opposite distanceSortKey natural order
int cmp = Double.compare(a.distanceSortKey, b.distanceSortKey);
if (cmp != 0) {
return -cmp;
}
@ -319,10 +327,10 @@ class NearestNeighbor {
return 0.0;
}
double d1 = SloppyMath.haversinMeters(pointLat, pointLon, minLat, minLon);
double d2 = SloppyMath.haversinMeters(pointLat, pointLon, minLat, maxLon);
double d3 = SloppyMath.haversinMeters(pointLat, pointLon, maxLat, maxLon);
double d4 = SloppyMath.haversinMeters(pointLat, pointLon, maxLat, minLon);
double d1 = SloppyMath.haversinSortKey(pointLat, pointLon, minLat, minLon);
double d2 = SloppyMath.haversinSortKey(pointLat, pointLon, minLat, maxLon);
double d3 = SloppyMath.haversinSortKey(pointLat, pointLon, maxLat, maxLon);
double d4 = SloppyMath.haversinSortKey(pointLat, pointLon, maxLat, minLon);
return Math.min(Math.min(d1, d2), Math.min(d3, d4));
}

View File

@ -26,7 +26,6 @@ import org.apache.lucene.document.LatLonDocValuesField;
import org.apache.lucene.document.LatLonPoint;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.search.NearestNeighbor.NearestHit;
import org.apache.lucene.geo.GeoEncodingUtils;
import org.apache.lucene.geo.GeoTestUtil;
import org.apache.lucene.index.DirectoryReader;
@ -190,23 +189,22 @@ public class TestNearest extends LuceneTestCase {
double pointLon = GeoTestUtil.nextLongitude();
// dumb brute force search to get the expected result:
NearestHit[] expectedHits = new NearestHit[lats.length];
FieldDoc[] expectedHits = new FieldDoc[lats.length];
for(int id=0;id<lats.length;id++) {
NearestHit hit = new NearestHit();
hit.distanceMeters = SloppyMath.haversinMeters(pointLat, pointLon, lats[id], lons[id]);
hit.docID = id;
double distance = SloppyMath.haversinMeters(pointLat, pointLon, lats[id], lons[id]);
FieldDoc hit = new FieldDoc(id, 0.0f, new Object[] {Double.valueOf(distance)});
expectedHits[id] = hit;
}
Arrays.sort(expectedHits, new Comparator<NearestHit>() {
Arrays.sort(expectedHits, new Comparator<FieldDoc>() {
@Override
public int compare(NearestHit a, NearestHit b) {
int cmp = Double.compare(a.distanceMeters, b.distanceMeters);
public int compare(FieldDoc a, FieldDoc b) {
int cmp = Double.compare(((Double) a.fields[0]).doubleValue(), ((Double) b.fields[0]).doubleValue());
if (cmp != 0) {
return cmp;
}
// tie break by smaller docID:
return a.docID - b.docID;
return a.doc - b.doc;
}
});
@ -221,22 +219,23 @@ public class TestNearest extends LuceneTestCase {
ScoreDoc[] hits = LatLonPointPrototypeQueries.nearest(s, "point", pointLat, pointLon, topN).scoreDocs;
for(int i=0;i<topN;i++) {
NearestHit expected = expectedHits[i];
FieldDoc expected = expectedHits[i];
FieldDoc expected2 = (FieldDoc) fieldDocs.scoreDocs[i];
FieldDoc actual = (FieldDoc) hits[i];
Document actualDoc = r.document(actual.doc);
if (VERBOSE) {
System.out.println("hit " + i);
System.out.println(" expected id=" + expected.docID + " lat=" + lats[expected.docID] + " lon=" + lons[expected.docID] + " distance=" + expected.distanceMeters + " meters");
System.out.println(" expected id=" + expected.doc+ " lat=" + lats[expected.doc] + " lon=" + lons[expected.doc]
+ " distance=" + ((Double) expected.fields[0]).doubleValue() + " meters");
System.out.println(" actual id=" + actualDoc.getField("id") + " distance=" + actual.fields[0] + " meters");
}
assertEquals(expected.docID, actual.doc);
assertEquals(expected.distanceMeters, ((Double) actual.fields[0]).doubleValue(), 0.0);
assertEquals(expected.doc, actual.doc);
assertEquals(((Double) expected.fields[0]).doubleValue(), ((Double) actual.fields[0]).doubleValue(), 0.0);
assertEquals(expected.docID, expected.docID);
assertEquals(((Double) expected2.fields[0]).doubleValue(), expected.distanceMeters, 0.0);
assertEquals(expected2.doc, actual.doc);
assertEquals(((Double) expected2.fields[0]).doubleValue(), ((Double) actual.fields[0]).doubleValue(), 0.0);
}
}