From 579fae5f0cea6219f474a4fb3c0a32a68cc98826 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Thu, 12 Sep 2019 07:48:40 +0200 Subject: [PATCH] LUCENE-8976: Use exact distance between point and bounding rectangle in FloatPointNearestNeighbor (#874) --- lucene/CHANGES.txt | 2 + .../document/FloatPointNearestNeighbor.java | 190 ++++++------------ .../TestFloatPointNearestNeighbor.java | 11 +- 3 files changed, 70 insertions(+), 133 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index e02db9ae933..c3d17498699 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -133,6 +133,8 @@ Improvements * LUCENE-8964: Fix geojson shape parsing on string arrays in properties (Alexander Reelsen) +* LUCENE-8976: Use exact distance between point and bounding rectangle in FloatPointNearestNeighbor. (Ignacio Vera) + Optimizations * LUCENE-8922: DisjunctionMaxQuery more efficiently leverages impacts to skip diff --git a/lucene/sandbox/src/java/org/apache/lucene/document/FloatPointNearestNeighbor.java b/lucene/sandbox/src/java/org/apache/lucene/document/FloatPointNearestNeighbor.java index eb3db1aa593..789e01a84d5 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/document/FloatPointNearestNeighbor.java +++ b/lucene/sandbox/src/java/org/apache/lucene/document/FloatPointNearestNeighbor.java @@ -19,7 +19,6 @@ package org.apache.lucene.document; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.PriorityQueue; @@ -46,7 +45,6 @@ public class FloatPointNearestNeighbor { final byte[] minPacked; final byte[] maxPacked; final BKDReader.IndexTree index; - /** The closest possible distance^2 of all points in this cell */ final double distanceSquared; @@ -75,21 +73,15 @@ public class FloatPointNearestNeighbor { final int topN; final PriorityQueue hitQueue; final float[] origin; - private int dims; - private int updateMinMaxCounter; - private float[] min; - private float[] max; - + final private int dims; + double bottomNearestDistanceSquared = Double.POSITIVE_INFINITY; + int bottomNearestDistanceDoc = Integer.MAX_VALUE; public NearestVisitor(PriorityQueue hitQueue, int topN, float[] origin) { this.hitQueue = hitQueue; this.topN = topN; this.origin = origin; - dims = origin.length; - min = new float[dims]; - max = new float[dims]; - Arrays.fill(min, Float.NEGATIVE_INFINITY); - Arrays.fill(max, Float.POSITIVE_INFINITY); + this.dims = origin.length; } @Override @@ -97,110 +89,59 @@ public class FloatPointNearestNeighbor { throw new AssertionError(); } - private static final int MANTISSA_BITS = 23; - - /** - * Returns the minimum value that will change the given distance when added to it. - * - * This value is calculated from the distance exponent reduced by (at most) 23, - * the number of bits in a float mantissa. This is necessary when the result of - * subtracting/adding the distance in a single dimension has an exponent that - * differs significantly from that of the distance value. Without this fudge - * factor (i.e. only subtracting/adding the distance), cells and values can be - * inappropriately judged as outside the search radius. - */ - private float getMinDelta(float distance) { - int exponent = Float.floatToIntBits(distance) >> MANTISSA_BITS; // extract biased exponent (distance is positive) - if (exponent == 0) { - return Float.MIN_VALUE; - } else { - exponent = exponent <= MANTISSA_BITS ? 1 : exponent - MANTISSA_BITS; // Avoid underflow - return Float.intBitsToFloat(exponent << MANTISSA_BITS); - } - } - - private void maybeUpdateMinMax() { - if (updateMinMaxCounter < 1024 || (updateMinMaxCounter & 0x3F) == 0x3F) { - NearestHit hit = hitQueue.peek(); - float distance = (float)Math.sqrt(hit.distanceSquared); - float minDelta = getMinDelta(distance); - // String oldMin = Arrays.toString(min); - // String oldMax = Arrays.toString(max); - for (int d = 0 ; d < dims ; ++d) { - min[d] = (origin[d] - distance) - minDelta; - max[d] = (origin[d] + distance) + minDelta; - // System.out.println("origin[" + d + "] (" + origin[d] + ") - distance (" + distance + ") - minDelta (" + minDelta + ") = min[" + d + "] (" + min[d] + ")"); - // System.out.println("origin[" + d + "] (" + origin[d] + ") + distance (" + distance + ") + minDelta (" + minDelta + ") = max[" + d + "] (" + max[d] + ")"); - } - // System.out.println("maybeUpdateMinMax: min: " + oldMin + " -> " + Arrays.toString(min) + " max: " + oldMax + " -> " + Arrays.toString(max)); - } - ++updateMinMaxCounter; - } - @Override public void visit(int docID, byte[] packedValue) { - // System.out.println("visit docID=" + docID + " liveDocs=" + curLiveDocs); - + // System.out.println("visit docID=" + docID + " liveDocs=" + curLiveDocs);; if (curLiveDocs != null && curLiveDocs.get(docID) == false) { return; } - float[] docPoint = new float[dims]; + double distanceSquared = 0.0d; for (int d = 0, offset = 0 ; d < dims ; ++d, offset += Float.BYTES) { - docPoint[d] = FloatPoint.decodeDimension(packedValue, offset); - if (docPoint[d] > max[d] || docPoint[d] < min[d]) { - - // if (docPoint[d] > max[d]) { - // System.out.println(" skipped because docPoint[" + d + "] (" + docPoint[d] + ") > max[" + d + "] (" + max[d] + ")"); - // } else { - // System.out.println(" skipped because docPoint[" + d + "] (" + docPoint[d] + ") < min[" + d + "] (" + min[d] + ")"); - // } - + double diff = (double) FloatPoint.decodeDimension(packedValue, offset) - (double) origin[d]; + distanceSquared += diff * diff; + if (distanceSquared > bottomNearestDistanceSquared) { return; } } - - double distanceSquared = euclideanDistanceSquared(origin, docPoint); // System.out.println(" visit docID=" + docID + " distanceSquared=" + distanceSquared + " value: " + Arrays.toString(docPoint)); int fullDocID = curDocBase + docID; if (hitQueue.size() == topN) { // queue already full - NearestHit bottom = hitQueue.peek(); - // System.out.println(" bottom distanceSquared=" + bottom.distanceSquared); - if (distanceSquared < bottom.distanceSquared - // we don't collect docs in order here, so we must also test the tie-break case ourselves: - || (distanceSquared == bottom.distanceSquared && fullDocID < bottom.docID)) { - hitQueue.poll(); - bottom.docID = fullDocID; - bottom.distanceSquared = distanceSquared; - hitQueue.offer(bottom); - // System.out.println(" ** keep1, now bottom=" + bottom); - maybeUpdateMinMax(); + if (distanceSquared == bottomNearestDistanceSquared && fullDocID > bottomNearestDistanceDoc) { + return; } + NearestHit bottom = hitQueue.poll(); + // System.out.println(" bottom distanceSquared=" + bottom.distanceSquared); + bottom.docID = fullDocID; + bottom.distanceSquared = distanceSquared; + hitQueue.offer(bottom); + updateBottomNearestDistance(); + // System.out.println(" ** keep1, now bottom=" + bottom); } else { NearestHit hit = new NearestHit(); hit.docID = fullDocID; hit.distanceSquared = distanceSquared; hitQueue.offer(hit); + if (hitQueue.size() == topN) { + updateBottomNearestDistance(); + } // System.out.println(" ** keep2, new addition=" + hit); } } + private void updateBottomNearestDistance() { + NearestHit newBottom = hitQueue.peek(); + bottomNearestDistanceSquared = newBottom.distanceSquared; + bottomNearestDistanceDoc = newBottom.docID; + } + @Override public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - for (int d = 0, offset = 0; d < dims; ++d, offset += Float.BYTES) { - float cellMaxAtDim = FloatPoint.decodeDimension(maxPackedValue, offset); - if (cellMaxAtDim < min[d]) { - // System.out.println(" skipped because cell max at " + d + " (" + cellMaxAtDim + ") < visitor.min[" + d + "] (" + min[d] + ")"); - return PointValues.Relation.CELL_OUTSIDE_QUERY; - } - float cellMinAtDim = FloatPoint.decodeDimension(minPackedValue, offset); - if (cellMinAtDim > max[d]) { - // System.out.println(" skipped because cell min at " + d + " (" + cellMinAtDim + ") > visitor.max[" + d + "] (" + max[d] + ")"); - return PointValues.Relation.CELL_OUTSIDE_QUERY; - } + if (hitQueue.size() == topN && pointToRectangleDistanceSquared(minPackedValue, maxPackedValue, origin) > bottomNearestDistanceSquared) { + return PointValues.Relation.CELL_OUTSIDE_QUERY; } return PointValues.Relation.CELL_CROSSES_QUERY; } @@ -252,33 +193,31 @@ public class FloatPointNearestNeighbor { states.add(state); cellQueue.offer(new Cell(state.index, i, reader.getMinPackedValue(), reader.getMaxPackedValue(), - approxBestDistanceSquared(minPackedValue, maxPackedValue, origin))); + pointToRectangleDistanceSquared(minPackedValue, maxPackedValue, origin))); } while (cellQueue.size() > 0) { Cell cell = cellQueue.poll(); // System.out.println(" visit " + cell); - // TODO: if we replace approxBestDistance with actualBestDistance, we can put an opto here to break once this "best" cell is fully outside of the hitQueue bottom's radius: - BKDReader reader = readers.get(cell.readerIndex); + if (cell.distanceSquared > visitor.bottomNearestDistanceSquared) { + break; + } + BKDReader reader = readers.get(cell.readerIndex); if (cell.index.isLeafNode()) { // System.out.println(" leaf"); // Leaf block: visit all points and possibly collect them: visitor.curDocBase = docBases.get(cell.readerIndex); visitor.curLiveDocs = liveDocs.get(cell.readerIndex); reader.visitLeafBlockValues(cell.index, states.get(cell.readerIndex)); + + //assert hitQueue.peek().distanceSquared >= cell.distanceSquared; // System.out.println(" now " + hitQueue.size() + " hits"); } else { // System.out.println(" non-leaf"); // Non-leaf block: split into two cells and put them back into the queue: - if (hitQueue.size() == topN) { - if (visitor.compare(cell.minPacked, cell.maxPacked) == PointValues.Relation.CELL_OUTSIDE_QUERY) { - // this cell is outside our search radius; don't bother exploring any more - continue; - } - } BytesRef splitValue = BytesRef.deepCopyOf(cell.index.getSplitDimValue()); int splitDim = cell.index.getSplitDim(); @@ -288,15 +227,19 @@ public class FloatPointNearestNeighbor { System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim); cell.index.pushLeft(); - cellQueue.offer(new Cell(cell.index, cell.readerIndex, cell.minPacked, splitPackedValue, - approxBestDistanceSquared(cell.minPacked, splitPackedValue, origin))); + double distanceLeft = pointToRectangleDistanceSquared(cell.minPacked, splitPackedValue, origin); + if (distanceLeft <= visitor.bottomNearestDistanceSquared) { + cellQueue.offer(new Cell(cell.index, cell.readerIndex, cell.minPacked, splitPackedValue, distanceLeft)); + } splitPackedValue = cell.minPacked.clone(); System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim); newIndex.pushRight(); - cellQueue.offer(new Cell(newIndex, cell.readerIndex, splitPackedValue, cell.maxPacked, - approxBestDistanceSquared(splitPackedValue, cell.maxPacked, origin))); + double distanceRight = pointToRectangleDistanceSquared(splitPackedValue, cell.maxPacked, origin); + if (distanceRight <= visitor.bottomNearestDistanceSquared) { + cellQueue.offer(new Cell(newIndex, cell.readerIndex, splitPackedValue, cell.maxPacked, distanceRight)); + } } } @@ -306,44 +249,27 @@ public class FloatPointNearestNeighbor { hits[downTo] = hitQueue.poll(); downTo--; } + //System.out.println(visitor.comp); return hits; } - private static double approxBestDistanceSquared(byte[] minPackedValue, byte[] maxPackedValue, float[] value) { - boolean insideCell = true; - float[] min = new float[value.length]; - float[] max = new float[value.length]; - double[] closest = new double[value.length]; - for (int i = 0, offset = 0 ; i < value.length ; ++i, offset += Float.BYTES) { - min[i] = FloatPoint.decodeDimension(minPackedValue, offset); - max[i] = FloatPoint.decodeDimension(maxPackedValue, offset); - if (insideCell) { - if (value[i] < min[i] || value[i] > max[i]) { - insideCell = false; - } - } - double minDiff = Math.abs((double)value[i] - (double)min[i]); - double maxDiff = Math.abs((double)value[i] - (double)max[i]); - closest[i] = minDiff < maxDiff ? minDiff : maxDiff; - } - if (insideCell) { - return 0.0f; - } + private static double pointToRectangleDistanceSquared(byte[] minPackedValue, byte[] maxPackedValue, float[] value) { double sumOfSquaredDiffs = 0.0d; - for (int d = 0 ; d < value.length ; ++d) { - sumOfSquaredDiffs += closest[d] * closest[d]; + for (int i = 0, offset = 0 ; i < value.length ; ++i, offset += Float.BYTES) { + double min = FloatPoint.decodeDimension(minPackedValue, offset); + if (value[i] < min) { + double diff = min - (double)value[i]; + sumOfSquaredDiffs += diff * diff; + continue; + } + double max = FloatPoint.decodeDimension(maxPackedValue, offset); + if (value[i] > max) { + double diff = max - (double)value[i]; + sumOfSquaredDiffs += diff * diff; + } } return sumOfSquaredDiffs; } - - static double euclideanDistanceSquared(float[] a, float[] b) { - double sumOfSquaredDifferences = 0.0d; - for (int d = 0 ; d < a.length ; ++d) { - double diff = (double)a[d] - (double)b[d]; - sumOfSquaredDifferences += diff * diff; - } - return sumOfSquaredDifferences; - } public static TopFieldDocs nearest(IndexSearcher searcher, String field, int topN, float... origin) throws IOException { if (topN < 1) { diff --git a/lucene/sandbox/src/test/org/apache/lucene/document/TestFloatPointNearestNeighbor.java b/lucene/sandbox/src/test/org/apache/lucene/document/TestFloatPointNearestNeighbor.java index 335ad1726ef..8379326cfae 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/document/TestFloatPointNearestNeighbor.java +++ b/lucene/sandbox/src/test/org/apache/lucene/document/TestFloatPointNearestNeighbor.java @@ -188,7 +188,7 @@ public class TestFloatPointNearestNeighbor extends LuceneTestCase { FloatPointNearestNeighbor.NearestHit[] expectedHits = new FloatPointNearestNeighbor.NearestHit[numPoints]; for (int id = 0 ; id < numPoints ; ++id) { FloatPointNearestNeighbor.NearestHit hit = new FloatPointNearestNeighbor.NearestHit(); - hit.distanceSquared = FloatPointNearestNeighbor.euclideanDistanceSquared(origin, values[id]); + hit.distanceSquared = euclideanDistanceSquared(origin, values[id]); hit.docID = id; expectedHits[id] = hit; } @@ -232,6 +232,15 @@ public class TestFloatPointNearestNeighbor extends LuceneTestCase { dir.close(); } + private static double euclideanDistanceSquared(float[] a, float[] b) { + double sumOfSquaredDifferences = 0.0d; + for (int d = 0 ; d < a.length ; ++d) { + double diff = (double)a[d] - (double)b[d]; + sumOfSquaredDifferences += diff * diff; + } + return sumOfSquaredDifferences; + } + private IndexWriterConfig getIndexWriterConfig() { IndexWriterConfig iwc = newIndexWriterConfig(); iwc.setCodec(Codec.forName("Lucene80"));