LUCENE-8976: Use exact distance between point and bounding rectangle in FloatPointNearestNeighbor (#874)

This commit is contained in:
Ignacio Vera 2019-09-12 07:48:40 +02:00 committed by GitHub
parent fb5a3e28fe
commit 579fae5f0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 133 deletions

View File

@ -133,6 +133,8 @@ Improvements
* LUCENE-8964: Fix geojson shape parsing on string arrays in properties
(Alexander Reelsen)
* LUCENE-8976: Use exact distance between point and bounding rectangle in FloatPointNearestNeighbor. (Ignacio Vera)
Optimizations
* LUCENE-8922: DisjunctionMaxQuery more efficiently leverages impacts to skip

View File

@ -19,7 +19,6 @@ package org.apache.lucene.document;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
@ -46,7 +45,6 @@ public class FloatPointNearestNeighbor {
final byte[] minPacked;
final byte[] maxPacked;
final BKDReader.IndexTree index;
/** The closest possible distance^2 of all points in this cell */
final double distanceSquared;
@ -75,21 +73,15 @@ public class FloatPointNearestNeighbor {
final int topN;
final PriorityQueue<NearestHit> hitQueue;
final float[] origin;
private int dims;
private int updateMinMaxCounter;
private float[] min;
private float[] max;
final private int dims;
double bottomNearestDistanceSquared = Double.POSITIVE_INFINITY;
int bottomNearestDistanceDoc = Integer.MAX_VALUE;
public NearestVisitor(PriorityQueue<NearestHit> hitQueue, int topN, float[] origin) {
this.hitQueue = hitQueue;
this.topN = topN;
this.origin = origin;
dims = origin.length;
min = new float[dims];
max = new float[dims];
Arrays.fill(min, Float.NEGATIVE_INFINITY);
Arrays.fill(max, Float.POSITIVE_INFINITY);
this.dims = origin.length;
}
@Override
@ -97,110 +89,59 @@ public class FloatPointNearestNeighbor {
throw new AssertionError();
}
private static final int MANTISSA_BITS = 23;
/**
 * Computes the padding that {@link #maybeUpdateMinMax()} applies on top of the search distance.
 *
 * The result is a power of two whose biased exponent is the exponent of {@code distance} lowered
 * by up to 23 (the width of a float mantissa), i.e. the smallest step that still changes
 * {@code distance} when added to it. Adding or subtracting the bare distance to a coordinate of a
 * very different magnitude can round in the wrong direction, which would cause cells and values
 * to be wrongly treated as lying outside the search radius; this extra margin compensates.
 *
 * @param distance the current search distance; treated as positive (sign bit clear)
 * @return {@link Float#MIN_VALUE} when the biased exponent of {@code distance} is zero,
 *         otherwise the power of two described above
 */
private float getMinDelta(float distance) {
  // With the sign bit clear, shifting out the mantissa leaves just the biased exponent.
  final int biasedExponent = Float.floatToIntBits(distance) >> MANTISSA_BITS;
  if (biasedExponent == 0) {
    return Float.MIN_VALUE;
  }
  // Never go below a biased exponent of 1, so the subtraction cannot underflow.
  final int reducedExponent = biasedExponent > MANTISSA_BITS ? biasedExponent - MANTISSA_BITS : 1;
  return Float.intBitsToFloat(reducedExponent << MANTISSA_BITS);
}
/**
 * Recomputes the per-dimension {@code min}/{@code max} bounds from the hit at the head of
 * {@code hitQueue}.
 *
 * The bounds are refreshed on each of the first 1024 invocations, and after that only on
 * invocations where the low six bits of the counter are all set (one in every 64). The counter
 * itself advances on every invocation, whether or not a refresh happened.
 */
private void maybeUpdateMinMax() {
  final boolean warmingUp = updateMinMaxCounter < 1024;
  if (warmingUp || (updateMinMaxCounter & 0x3F) == 0x3F) {
    final NearestHit head = hitQueue.peek();
    final float radius = (float) Math.sqrt(head.distanceSquared);
    // Extra margin so float rounding cannot pull the bounds inside the true radius.
    final float padding = getMinDelta(radius);
    int dim = 0;
    while (dim < dims) {
      final float center = origin[dim];
      min[dim] = (center - radius) - padding;
      max[dim] = (center + radius) + padding;
      dim++;
    }
  }
  updateMinMaxCounter++;
}
@Override
public void visit(int docID, byte[] packedValue) {
// System.out.println("visit docID=" + docID + " liveDocs=" + curLiveDocs);
// System.out.println("visit docID=" + docID + " liveDocs=" + curLiveDocs);;
if (curLiveDocs != null && curLiveDocs.get(docID) == false) {
return;
}
float[] docPoint = new float[dims];
double distanceSquared = 0.0d;
for (int d = 0, offset = 0 ; d < dims ; ++d, offset += Float.BYTES) {
docPoint[d] = FloatPoint.decodeDimension(packedValue, offset);
if (docPoint[d] > max[d] || docPoint[d] < min[d]) {
// if (docPoint[d] > max[d]) {
// System.out.println(" skipped because docPoint[" + d + "] (" + docPoint[d] + ") > max[" + d + "] (" + max[d] + ")");
// } else {
// System.out.println(" skipped because docPoint[" + d + "] (" + docPoint[d] + ") < min[" + d + "] (" + min[d] + ")");
// }
double diff = (double) FloatPoint.decodeDimension(packedValue, offset) - (double) origin[d];
distanceSquared += diff * diff;
if (distanceSquared > bottomNearestDistanceSquared) {
return;
}
}
double distanceSquared = euclideanDistanceSquared(origin, docPoint);
// System.out.println(" visit docID=" + docID + " distanceSquared=" + distanceSquared + " value: " + Arrays.toString(docPoint));
int fullDocID = curDocBase + docID;
if (hitQueue.size() == topN) { // queue already full
NearestHit bottom = hitQueue.peek();
// System.out.println(" bottom distanceSquared=" + bottom.distanceSquared);
if (distanceSquared < bottom.distanceSquared
// we don't collect docs in order here, so we must also test the tie-break case ourselves:
|| (distanceSquared == bottom.distanceSquared && fullDocID < bottom.docID)) {
hitQueue.poll();
bottom.docID = fullDocID;
bottom.distanceSquared = distanceSquared;
hitQueue.offer(bottom);
// System.out.println(" ** keep1, now bottom=" + bottom);
maybeUpdateMinMax();
if (distanceSquared == bottomNearestDistanceSquared && fullDocID > bottomNearestDistanceDoc) {
return;
}
NearestHit bottom = hitQueue.poll();
// System.out.println(" bottom distanceSquared=" + bottom.distanceSquared);
bottom.docID = fullDocID;
bottom.distanceSquared = distanceSquared;
hitQueue.offer(bottom);
updateBottomNearestDistance();
// System.out.println(" ** keep1, now bottom=" + bottom);
} else {
NearestHit hit = new NearestHit();
hit.docID = fullDocID;
hit.distanceSquared = distanceSquared;
hitQueue.offer(hit);
if (hitQueue.size() == topN) {
updateBottomNearestDistance();
}
// System.out.println(" ** keep2, new addition=" + hit);
}
}
/**
 * Copies the distance and doc id of the hit at the head of {@code hitQueue} into
 * {@code bottomNearestDistanceSquared} and {@code bottomNearestDistanceDoc}.
 *
 * NOTE(review): the head is presumably the worst hit currently retained; the queue's comparator
 * is not visible here, so confirm against its construction.
 */
private void updateBottomNearestDistance() {
  final NearestHit head = hitQueue.peek();
  bottomNearestDistanceDoc = head.docID;
  bottomNearestDistanceSquared = head.distanceSquared;
}
@Override
public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
for (int d = 0, offset = 0; d < dims; ++d, offset += Float.BYTES) {
float cellMaxAtDim = FloatPoint.decodeDimension(maxPackedValue, offset);
if (cellMaxAtDim < min[d]) {
// System.out.println(" skipped because cell max at " + d + " (" + cellMaxAtDim + ") < visitor.min[" + d + "] (" + min[d] + ")");
return PointValues.Relation.CELL_OUTSIDE_QUERY;
}
float cellMinAtDim = FloatPoint.decodeDimension(minPackedValue, offset);
if (cellMinAtDim > max[d]) {
// System.out.println(" skipped because cell min at " + d + " (" + cellMinAtDim + ") > visitor.max[" + d + "] (" + max[d] + ")");
return PointValues.Relation.CELL_OUTSIDE_QUERY;
}
if (hitQueue.size() == topN && pointToRectangleDistanceSquared(minPackedValue, maxPackedValue, origin) > bottomNearestDistanceSquared) {
return PointValues.Relation.CELL_OUTSIDE_QUERY;
}
return PointValues.Relation.CELL_CROSSES_QUERY;
}
@ -252,33 +193,31 @@ public class FloatPointNearestNeighbor {
states.add(state);
cellQueue.offer(new Cell(state.index, i, reader.getMinPackedValue(), reader.getMaxPackedValue(),
approxBestDistanceSquared(minPackedValue, maxPackedValue, origin)));
pointToRectangleDistanceSquared(minPackedValue, maxPackedValue, origin)));
}
while (cellQueue.size() > 0) {
Cell cell = cellQueue.poll();
// System.out.println(" visit " + cell);
// TODO: if we replace approxBestDistance with actualBestDistance, we can put an opto here to break once this "best" cell is fully outside of the hitQueue bottom's radius:
BKDReader reader = readers.get(cell.readerIndex);
if (cell.distanceSquared > visitor.bottomNearestDistanceSquared) {
break;
}
BKDReader reader = readers.get(cell.readerIndex);
if (cell.index.isLeafNode()) {
// System.out.println(" leaf");
// Leaf block: visit all points and possibly collect them:
visitor.curDocBase = docBases.get(cell.readerIndex);
visitor.curLiveDocs = liveDocs.get(cell.readerIndex);
reader.visitLeafBlockValues(cell.index, states.get(cell.readerIndex));
//assert hitQueue.peek().distanceSquared >= cell.distanceSquared;
// System.out.println(" now " + hitQueue.size() + " hits");
} else {
// System.out.println(" non-leaf");
// Non-leaf block: split into two cells and put them back into the queue:
if (hitQueue.size() == topN) {
if (visitor.compare(cell.minPacked, cell.maxPacked) == PointValues.Relation.CELL_OUTSIDE_QUERY) {
// this cell is outside our search radius; don't bother exploring any more
continue;
}
}
BytesRef splitValue = BytesRef.deepCopyOf(cell.index.getSplitDimValue());
int splitDim = cell.index.getSplitDim();
@ -288,15 +227,19 @@ public class FloatPointNearestNeighbor {
System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
cell.index.pushLeft();
cellQueue.offer(new Cell(cell.index, cell.readerIndex, cell.minPacked, splitPackedValue,
approxBestDistanceSquared(cell.minPacked, splitPackedValue, origin)));
double distanceLeft = pointToRectangleDistanceSquared(cell.minPacked, splitPackedValue, origin);
if (distanceLeft <= visitor.bottomNearestDistanceSquared) {
cellQueue.offer(new Cell(cell.index, cell.readerIndex, cell.minPacked, splitPackedValue, distanceLeft));
}
splitPackedValue = cell.minPacked.clone();
System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
newIndex.pushRight();
cellQueue.offer(new Cell(newIndex, cell.readerIndex, splitPackedValue, cell.maxPacked,
approxBestDistanceSquared(splitPackedValue, cell.maxPacked, origin)));
double distanceRight = pointToRectangleDistanceSquared(splitPackedValue, cell.maxPacked, origin);
if (distanceRight <= visitor.bottomNearestDistanceSquared) {
cellQueue.offer(new Cell(newIndex, cell.readerIndex, splitPackedValue, cell.maxPacked, distanceRight));
}
}
}
@ -306,45 +249,28 @@ public class FloatPointNearestNeighbor {
hits[downTo] = hitQueue.poll();
downTo--;
}
//System.out.println(visitor.comp);
return hits;
}
private static double approxBestDistanceSquared(byte[] minPackedValue, byte[] maxPackedValue, float[] value) {
boolean insideCell = true;
float[] min = new float[value.length];
float[] max = new float[value.length];
double[] closest = new double[value.length];
for (int i = 0, offset = 0 ; i < value.length ; ++i, offset += Float.BYTES) {
min[i] = FloatPoint.decodeDimension(minPackedValue, offset);
max[i] = FloatPoint.decodeDimension(maxPackedValue, offset);
if (insideCell) {
if (value[i] < min[i] || value[i] > max[i]) {
insideCell = false;
}
}
double minDiff = Math.abs((double)value[i] - (double)min[i]);
double maxDiff = Math.abs((double)value[i] - (double)max[i]);
closest[i] = minDiff < maxDiff ? minDiff : maxDiff;
}
if (insideCell) {
return 0.0f;
}
private static double pointToRectangleDistanceSquared(byte[] minPackedValue, byte[] maxPackedValue, float[] value) {
double sumOfSquaredDiffs = 0.0d;
for (int d = 0 ; d < value.length ; ++d) {
sumOfSquaredDiffs += closest[d] * closest[d];
for (int i = 0, offset = 0 ; i < value.length ; ++i, offset += Float.BYTES) {
double min = FloatPoint.decodeDimension(minPackedValue, offset);
if (value[i] < min) {
double diff = min - (double)value[i];
sumOfSquaredDiffs += diff * diff;
continue;
}
double max = FloatPoint.decodeDimension(maxPackedValue, offset);
if (value[i] > max) {
double diff = max - (double)value[i];
sumOfSquaredDiffs += diff * diff;
}
}
return sumOfSquaredDiffs;
}
/**
 * Returns the squared Euclidean distance between two points.
 *
 * Each coordinate is widened to {@code double} before subtracting, and the squared differences
 * are accumulated in dimension order. Only the first {@code a.length} entries of {@code b} are
 * read.
 *
 * @param a first point; its length determines the number of dimensions compared
 * @param b second point; must have at least {@code a.length} entries
 * @return the sum over all dimensions of {@code (a[d] - b[d])^2}
 */
static double euclideanDistanceSquared(float[] a, float[] b) {
  double total = 0.0d;
  int dim = 0;
  while (dim < a.length) {
    final double delta = (double) a[dim] - (double) b[dim];
    total += delta * delta;
    dim++;
  }
  return total;
}
public static TopFieldDocs nearest(IndexSearcher searcher, String field, int topN, float... origin) throws IOException {
if (topN < 1) {
throw new IllegalArgumentException("topN must be at least 1; got " + topN);

View File

@ -188,7 +188,7 @@ public class TestFloatPointNearestNeighbor extends LuceneTestCase {
FloatPointNearestNeighbor.NearestHit[] expectedHits = new FloatPointNearestNeighbor.NearestHit[numPoints];
for (int id = 0 ; id < numPoints ; ++id) {
FloatPointNearestNeighbor.NearestHit hit = new FloatPointNearestNeighbor.NearestHit();
hit.distanceSquared = FloatPointNearestNeighbor.euclideanDistanceSquared(origin, values[id]);
hit.distanceSquared = euclideanDistanceSquared(origin, values[id]);
hit.docID = id;
expectedHits[id] = hit;
}
@ -232,6 +232,15 @@ public class TestFloatPointNearestNeighbor extends LuceneTestCase {
dir.close();
}
/**
 * Reference squared Euclidean distance between two points.
 *
 * Coordinates are promoted to {@code double} before the subtraction and the squares are summed
 * from the first dimension to the last. The dimension count is taken from {@code a}.
 *
 * @param a first point
 * @param b second point; needs at least as many entries as {@code a}
 * @return sum of the squared per-dimension differences
 */
private static double euclideanDistanceSquared(float[] a, float[] b) {
  final int dimensions = a.length;
  double accumulated = 0.0d;
  for (int i = 0; i != dimensions; i++) {
    final double gap = (double) a[i] - (double) b[i];
    accumulated += gap * gap;
  }
  return accumulated;
}
private IndexWriterConfig getIndexWriterConfig() {
IndexWriterConfig iwc = newIndexWriterConfig();
iwc.setCodec(Codec.forName("Lucene80"));