diff --git a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java index 4403e456c32..bbc9a540ebe 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java @@ -64,6 +64,7 @@ public class PointInSetQuery extends Query { BytesRefBuilder previous = null; BytesRef current; while ((current = packedPoints.next()) != null) { + // nocommit make sure a test tests this: if (current.length != numDims * bytesPerDim) { throw new IllegalArgumentException("packed point length should be " + (numDims * bytesPerDim) + " but got " + current.length + "; field=\"" + field + "\", numDims=" + numDims + " bytesPerDim=" + bytesPerDim); } @@ -145,148 +146,21 @@ public class PointInSetQuery extends Query { DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc()); int[] hitCount = new int[1]; - final TermIterator iterator = sortedPackedPoints.iterator(); - byte[] pointBytes = new byte[bytesPerDim * numDims]; if (numDims == 1) { - final BytesRef scratch = new BytesRef(); - scratch.length = bytesPerDim; + // We optimize this common case, effectively doing a merge sort of the indexed values vs the queried set: + values.intersect(field, new MergePointVisitor(sortedPackedPoints.iterator(), hitCount, result)); - // Optimize this common case, effectively doing a merge sort of the indexed values vs the queried set: - values.intersect(field, - new IntersectVisitor() { - - private BytesRef nextQueryPoint = iterator.next(); - - @Override - public void grow(int count) { - result.grow(count); - } - - @Override - public void visit(int docID) { - hitCount[0]++; - result.add(docID); - } - - @Override - public void visit(int docID, byte[] packedValue) { - scratch.bytes = packedValue; - while (nextQueryPoint != null) { - int cmp = nextQueryPoint.compareTo(scratch); - if (cmp == 0) { - // Query point equals index point, so collect and return - hitCount[0]++; - result.add(docID); - break; - } else if (cmp < 0) { - // Query point is before index point, so we move to next query point - nextQueryPoint = iterator.next(); - } else { - // Query point is after index point, so we don't collect and we return: - break; - } - } - } - - @Override - public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - - while (nextQueryPoint != null) { - scratch.bytes = minPackedValue; - int cmpMin = nextQueryPoint.compareTo(scratch); - if (cmpMin < 0) { - // query point is before the start of this cell - nextQueryPoint = iterator.next(); - continue; - } - scratch.bytes = maxPackedValue; - int cmpMax = nextQueryPoint.compareTo(scratch); - if (cmpMax > 0) { - // query point is after the end of this cell - return Relation.CELL_OUTSIDE_QUERY; - } - - if (cmpMin == 0 && cmpMax == 0) { - // NOTE: we only hit this if we are on a cell whose min and max values are exactly equal to our point, - // which can easily happen if many (> 1024) docs share this one value - return Relation.CELL_INSIDE_QUERY; - } else { - return Relation.CELL_CROSSES_QUERY; - } - } - - // We exhausted all points in the query: - return Relation.CELL_OUTSIDE_QUERY; - } - }); } else { + // NOTE: this is naive implementation, where for each point we re-walk the KD tree to intersect. We could instead do a similar + // optimization as the 1D case, but I think it'd mean building a query-time KD tree so we could efficiently intersect against the + // index, which is probably tricky! + SinglePointVisitor visitor = new SinglePointVisitor(hitCount, result); + TermIterator iterator = sortedPackedPoints.iterator(); for (BytesRef point = iterator.next(); point != null; point = iterator.next()) { - // nocommit make sure a test tests this: - assert point.length == pointBytes.length; - System.arraycopy(point.bytes, point.offset, pointBytes, 0, pointBytes.length); - - final BytesRef finalPoint = point; - - values.intersect(field, - // nocommit don't make new instance of this for each point? - new IntersectVisitor() { - - @Override - public void grow(int count) { - result.grow(count); - } - - @Override - public void visit(int docID) { - hitCount[0]++; - result.add(docID); - } - - @Override - public void visit(int docID, byte[] packedValue) { - assert packedValue.length == finalPoint.length; - if (Arrays.equals(packedValue, pointBytes)) { - // The point for this doc matches the point we are querying on - hitCount[0]++; - result.add(docID); - } - } - - @Override - public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - - boolean crosses = false; - - for(int dim=0;dim 0) { - return Relation.CELL_OUTSIDE_QUERY; - } - - int cmpMax = StringHelper.compare(bytesPerDim, maxPackedValue, offset, pointBytes, offset); - if (cmpMax < 0) { - return Relation.CELL_OUTSIDE_QUERY; - } - - if (cmpMin != 0 || cmpMax != 0) { - crosses = true; - } - } - - if (crosses) { - return Relation.CELL_CROSSES_QUERY; - } else { - // nocommit make sure tests hit this case: - // NOTE: we only hit this if we are on a cell whose min and max values are exactly equal to our point, - // which can easily happen if many docs share this one value - return Relation.CELL_INSIDE_QUERY; - } - } - }); + visitor.setPoint(point); + values.intersect(field, visitor); } } @@ -296,6 +170,162 @@ public class PointInSetQuery extends Query { }; } + /** Essentially does a merge sort, only collecting hits when the indexed point and query point are the same. This is an optimization, + * used in the 1D case. */ + private class MergePointVisitor implements IntersectVisitor { + + private final DocIdSetBuilder result; + private final int[] hitCount; + private final TermIterator iterator; + private BytesRef nextQueryPoint; + private final BytesRef scratch = new BytesRef(); + + public MergePointVisitor(TermIterator iterator, int[] hitCount, DocIdSetBuilder result) throws IOException { + this.hitCount = hitCount; + this.result = result; + this.iterator = iterator; + nextQueryPoint = iterator.next(); + scratch.length = bytesPerDim; + } + + @Override + public void grow(int count) { + result.grow(count); + } + + @Override + public void visit(int docID) { + hitCount[0]++; + result.add(docID); + } + + @Override + public void visit(int docID, byte[] packedValue) { + scratch.bytes = packedValue; + while (nextQueryPoint != null) { + int cmp = nextQueryPoint.compareTo(scratch); + if (cmp == 0) { + // Query point equals index point, so collect and return + hitCount[0]++; + result.add(docID); + break; + } else if (cmp < 0) { + // Query point is before index point, so we move to next query point + nextQueryPoint = iterator.next(); + } else { + // Query point is after index point, so we don't collect and we return: + break; + } + } + } + + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + + while (nextQueryPoint != null) { + scratch.bytes = minPackedValue; + int cmpMin = nextQueryPoint.compareTo(scratch); + if (cmpMin < 0) { + // query point is before the start of this cell + nextQueryPoint = iterator.next(); + continue; + } + scratch.bytes = maxPackedValue; + int cmpMax = nextQueryPoint.compareTo(scratch); + if (cmpMax > 0) { + // query point is after the end of this cell + return Relation.CELL_OUTSIDE_QUERY; + } + + if (cmpMin == 0 && cmpMax == 0) { + // NOTE: we only hit this if we are on a cell whose min and max values are exactly equal to our point, + // which can easily happen if many (> 1024) docs share this one value + return Relation.CELL_INSIDE_QUERY; + } else { + return Relation.CELL_CROSSES_QUERY; + } + } + + // We exhausted all points in the query: + return Relation.CELL_OUTSIDE_QUERY; + } + } + + /** IntersectVisitor that queries against a highly degenerate shape: a single point. This is used in the > 1D case. */ + private class SinglePointVisitor implements IntersectVisitor { + + private final DocIdSetBuilder result; + private final int[] hitCount; + public BytesRef point; + private final byte[] pointBytes; + + public SinglePointVisitor(int[] hitCount, DocIdSetBuilder result) { + this.hitCount = hitCount; + this.result = result; + this.pointBytes = new byte[bytesPerDim * numDims]; + } + + public void setPoint(BytesRef point) { + // we verified this up front in query's ctor: + assert point.length == pointBytes.length; + System.arraycopy(point.bytes, point.offset, pointBytes, 0, pointBytes.length); + } + + @Override + public void grow(int count) { + result.grow(count); + } + + @Override + public void visit(int docID) { + hitCount[0]++; + result.add(docID); + } + + @Override + public void visit(int docID, byte[] packedValue) { + assert packedValue.length == point.length; + if (Arrays.equals(packedValue, pointBytes)) { + // The point for this doc matches the point we are querying on + hitCount[0]++; + result.add(docID); + } + } + + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + + boolean crosses = false; + + for(int dim=0;dim 0) { + return Relation.CELL_OUTSIDE_QUERY; + } + + int cmpMax = StringHelper.compare(bytesPerDim, maxPackedValue, offset, pointBytes, offset); + if (cmpMax < 0) { + return Relation.CELL_OUTSIDE_QUERY; + } + + if (cmpMin != 0 || cmpMax != 0) { + crosses = true; + } + } + + if (crosses) { + return Relation.CELL_CROSSES_QUERY; + } else { + // nocommit make sure tests hit this case: + // NOTE: we only hit this if we are on a cell whose min and max values are exactly equal to our point, + // which can easily happen if many docs share this one value + return Relation.CELL_INSIDE_QUERY; + } + } + } + @Override public int hashCode() { int hash = super.hashCode();