Generalize range query optimization on sorted indexes to descending sorts. (#11972)

This generalizes #687 to indexes that are sorted in descending order. The main
challenge with descending sorts is that they require being able to compute the
last doc ID that matches a value, which would ideally require walking the BKD
tree in reverse order, but the API only support moving forward. This is worked
around by maintaining a stack of `PointTree` clones to perform the search.
This commit is contained in:
Adrien Grand 2022-12-08 08:38:53 +01:00 committed by GitHub
parent d0be9ab57c
commit 95df7e8109
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 299 additions and 93 deletions

View File

@ -205,6 +205,9 @@ Optimizations
* GITHUB#11895: count() in BooleanQuery could be early quit. (Lu Xugang)
* GITHUB#11972: `IndexSortSortedNumericDocValuesRangeQuery` can now also
optimize query execution with points for descending sorts. (Adrien Grand)
Other
---------------------

View File

@ -17,8 +17,9 @@
package org.apache.lucene.sandbox.search;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Objects;
import java.util.function.Predicate;
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DocValues;
@ -27,6 +28,7 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.PointTree;
import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.ConstantScoreScorer;
@ -220,53 +222,55 @@ public class IndexSortSortedNumericDocValuesRangeQuery extends Query {
};
}
/**
* Returns the first document whose packed value is greater than or equal (if allowEqual is true)
* to the provided packed value or -1 if all packed values are smaller than the provided one,
*/
public final int nextDoc(PointValues values, byte[] packedValue, boolean allowEqual)
throws IOException {
assert values.getNumDimensions() == 1;
final int bytesPerDim = values.getBytesPerDimension();
final ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim);
final Predicate<byte[]> biggerThan =
testPackedValue -> {
int cmp = comparator.compare(testPackedValue, 0, packedValue, 0);
return cmp > 0 || (cmp == 0 && allowEqual);
};
return nextDoc(values.getPointTree(), biggerThan);
private static class ValueAndDoc {
byte[] value;
int docID;
boolean done;
}
private int nextDoc(PointValues.PointTree pointTree, Predicate<byte[]> biggerThan)
/**
* Move to the minimum leaf node that has at least one value that is greater than (or equal to if
* {@code allowEqual}) {@code value}, and return the next greater value on this block. Upon
* returning, the {@code pointTree} must be on the leaf node where the value was found.
*/
private static ValueAndDoc findNextValue(
PointTree pointTree,
byte[] value,
boolean allowEqual,
ByteArrayComparator comparator,
boolean lastDoc)
throws IOException {
if (biggerThan.test(pointTree.getMaxPackedValue()) == false) {
// doc is before us
return -1;
} else if (pointTree.moveToChild()) {
// navigate down
do {
final int doc = nextDoc(pointTree, biggerThan);
if (doc != -1) {
return doc;
}
} while (pointTree.moveToSibling());
pointTree.moveToParent();
return -1;
} else {
// doc is in this leaf
final int[] doc = {-1};
int cmp = comparator.compare(pointTree.getMaxPackedValue(), 0, value, 0);
if (cmp < 0 || (cmp == 0 && allowEqual == false)) {
return null;
}
if (pointTree.moveToChild() == false) {
ValueAndDoc vd = new ValueAndDoc();
pointTree.visitDocValues(
new IntersectVisitor() {
@Override
public void visit(int docID) {
throw new AssertionError("Invalid call to visit(docID)");
public void visit(int docID, byte[] packedValue) throws IOException {
if (vd.value == null) {
int cmp = comparator.compare(packedValue, 0, value, 0);
if (cmp > 0 || (cmp == 0 && allowEqual)) {
vd.value = packedValue.clone();
vd.docID = docID;
}
} else if (lastDoc && vd.done == false) {
int cmp = comparator.compare(packedValue, 0, vd.value, 0);
assert cmp >= 0;
if (cmp > 0) {
vd.done = true;
} else {
vd.docID = docID;
}
}
}
@Override
public void visit(int docID, byte[] packedValue) {
if (doc[0] == -1 && biggerThan.test(packedValue)) {
doc[0] = docID;
}
public void visit(int docID) throws IOException {
throw new UnsupportedOperationException();
}
@Override
@ -274,8 +278,130 @@ public class IndexSortSortedNumericDocValuesRangeQuery extends Query {
return Relation.CELL_CROSSES_QUERY;
}
});
return doc[0];
if (vd.value != null) {
return vd;
} else {
return null;
}
}
// Recurse
do {
ValueAndDoc vd = findNextValue(pointTree, value, allowEqual, comparator, lastDoc);
if (vd != null) {
return vd;
}
} while (pointTree.moveToSibling());
boolean moved = pointTree.moveToParent();
assert moved;
return null;
}
/**
* Find the next value that is greater than (or equal to if {@code allowEqual}) and return either
* its first doc ID or last doc ID depending on {@code lastDoc}. This method returns -1 if there
* is no greater value in the dataset.
*/
private static int nextDoc(
PointTree pointTree,
byte[] value,
boolean allowEqual,
ByteArrayComparator comparator,
boolean lastDoc)
throws IOException {
ValueAndDoc vd = findNextValue(pointTree, value, allowEqual, comparator, lastDoc);
if (vd == null) {
return -1;
}
if (lastDoc == false || vd.done) {
return vd.docID;
}
// We found the next value, now we need the last doc ID.
int doc = lastDoc(pointTree, vd.value, comparator);
if (doc == -1) {
// vd.docID was actually the last doc ID
return vd.docID;
} else {
return doc;
}
}
/**
* Compute the last doc ID that matches the given value and is stored on a leaf node that compares
* greater than the current leaf node that the provided {@link PointTree} is positioned on. This
* returns -1 if no other leaf node contains the provided {@code value}.
*/
private static int lastDoc(PointTree pointTree, byte[] value, ByteArrayComparator comparator)
throws IOException {
// Create a stack of nodes that may contain value that we'll use to search for the last leaf
// node that contains `value`.
// While the logic looks a bit complicated due to the fact that the PointTree API doesn't allow
// moving back to previous siblings, this effectively performs a binary search.
Deque<PointTree> stack = new ArrayDeque<>();
outer:
while (true) {
// Move to the next node
while (pointTree.moveToSibling() == false) {
if (pointTree.moveToParent() == false) {
// No next node
break outer;
}
}
int cmp = comparator.compare(pointTree.getMinPackedValue(), 0, value, 0);
if (cmp > 0) {
// This node doesn't have `value`, so next nodes can't either
break;
}
stack.push(pointTree.clone());
}
while (stack.isEmpty() == false) {
PointTree next = stack.pop();
if (next.moveToChild() == false) {
int[] lastDoc = {-1};
next.visitDocValues(
new IntersectVisitor() {
@Override
public void visit(int docID) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public void visit(int docID, byte[] packedValue) throws IOException {
int cmp = comparator.compare(value, 0, packedValue, 0);
if (cmp == 0) {
lastDoc[0] = docID;
}
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return Relation.CELL_CROSSES_QUERY;
}
});
if (lastDoc[0] != -1) {
return lastDoc[0];
}
} else {
do {
int cmp = comparator.compare(next.getMinPackedValue(), 0, value, 0);
if (cmp > 0) {
// This node doesn't have `value`, so next nodes can't either
break;
}
stack.push(next.clone());
} while (next.moveToSibling());
}
}
return -1;
}
private boolean matchNone(PointValues points, byte[] queryLowerPoint, byte[] queryUpperPoint)
@ -311,8 +437,8 @@ public class IndexSortSortedNumericDocValuesRangeQuery extends Query {
Sort indexSort = context.reader().getMetaData().getSort();
if (indexSort != null
&& indexSort.getSort().length > 0
&& indexSort.getSort()[0].getField().equals(field)
&& indexSort.getSort()[0].getReverse() == false) {
&& indexSort.getSort()[0].getField().equals(field)) {
final boolean reverse = indexSort.getSort()[0].getReverse();
PointValues points = context.reader().getPointValues(field);
if (points == null) {
return null;
@ -327,44 +453,58 @@ public class IndexSortSortedNumericDocValuesRangeQuery extends Query {
return null;
}
// Each doc that has points has exactly one point.
if (points.size() == points.getDocCount()) {
if (points.size() != points.getDocCount()) {
return null;
}
byte[] queryLowerPoint;
byte[] queryUpperPoint;
if (points.getBytesPerDimension() == Integer.BYTES) {
queryLowerPoint = IntPoint.pack((int) lowerValue).bytes;
queryUpperPoint = IntPoint.pack((int) upperValue).bytes;
} else {
queryLowerPoint = LongPoint.pack(lowerValue).bytes;
queryUpperPoint = LongPoint.pack(upperValue).bytes;
}
if (lowerValue > upperValue || matchNone(points, queryLowerPoint, queryUpperPoint)) {
return new BoundedDocIdSetIterator(0, 0, null);
}
int minDocId, maxDocId;
if (matchAll(points, queryLowerPoint, queryUpperPoint)) {
minDocId = 0;
maxDocId = context.reader().maxDoc();
} else {
// >=queryLowerPoint
minDocId = nextDoc(points, queryLowerPoint, true);
byte[] queryLowerPoint;
byte[] queryUpperPoint;
if (points.getBytesPerDimension() == Integer.BYTES) {
queryLowerPoint = IntPoint.pack((int) lowerValue).bytes;
queryUpperPoint = IntPoint.pack((int) upperValue).bytes;
} else {
queryLowerPoint = LongPoint.pack(lowerValue).bytes;
queryUpperPoint = LongPoint.pack(upperValue).bytes;
}
if (lowerValue > upperValue || matchNone(points, queryLowerPoint, queryUpperPoint)) {
return new BoundedDocIdSetIterator(0, 0, null);
}
int minDocId, maxDocId;
if (matchAll(points, queryLowerPoint, queryUpperPoint)) {
minDocId = 0;
maxDocId = context.reader().maxDoc();
} else {
final ByteArrayComparator comparator =
ArrayUtil.getUnsignedComparator(points.getBytesPerDimension());
if (reverse) {
minDocId = nextDoc(points.getPointTree(), queryUpperPoint, false, comparator, true) + 1;
} else {
minDocId = nextDoc(points.getPointTree(), queryLowerPoint, true, comparator, false);
if (minDocId == -1) {
// No matches
return new BoundedDocIdSetIterator(0, 0, null);
}
// >queryUpperPoint,
maxDocId = nextDoc(points, queryUpperPoint, false);
}
if (reverse) {
maxDocId = nextDoc(points.getPointTree(), queryLowerPoint, true, comparator, true) + 1;
if (maxDocId == 0) {
// No matches
return new BoundedDocIdSetIterator(0, 0, null);
}
} else {
maxDocId = nextDoc(points.getPointTree(), queryUpperPoint, false, comparator, false);
if (maxDocId == -1) {
maxDocId = context.reader().maxDoc();
}
}
}
if ((points.getDocCount() == context.reader().maxDoc())) {
return new BoundedDocIdSetIterator(minDocId, maxDocId, null);
} else {
return new BoundedDocIdSetIterator(minDocId, maxDocId, delegate);
}
if ((points.getDocCount() == context.reader().maxDoc())) {
return new BoundedDocIdSetIterator(minDocId, maxDocId, null);
} else {
return new BoundedDocIdSetIterator(minDocId, maxDocId, delegate);
}
}
return null;

View File

@ -643,58 +643,112 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas
field, lowerValue, upperValue, fallbackQuery);
}
public void testCountWithBkd() throws IOException {
public void testCountWithBkdAsc() throws Exception {
doTestCountWithBkd(false);
}
public void testCountWithBkdDesc() throws Exception {
doTestCountWithBkd(true);
}
public void doTestCountWithBkd(boolean reverse) throws Exception {
String filedName = "field";
Directory dir = newDirectory();
IndexWriterConfig iwc = new IndexWriterConfig(new MockAnalyzer(random()));
Sort indexSort = new Sort(new SortedNumericSortField(filedName, SortField.Type.LONG, false));
Sort indexSort = new Sort(new SortedNumericSortField(filedName, SortField.Type.LONG, reverse));
iwc.setIndexSort(indexSort);
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc);
addDocWithBkd(writer, filedName, 6, 500);
addDocWithBkd(writer, filedName, 5, 500);
addDocWithBkd(writer, filedName, 8, 500);
addDocWithBkd(writer, filedName, 9, 500);
addDocWithBkd(writer, filedName, 7, 500);
addDocWithBkd(writer, filedName, 5, 600);
addDocWithBkd(writer, filedName, 11, 700);
addDocWithBkd(writer, filedName, 13, 800);
addDocWithBkd(writer, filedName, 9, 900);
writer.flush();
writer.forceMerge(1);
IndexReader reader = writer.getReader();
IndexSearcher searcher = newSearcher(reader);
Query fallbackQuery = LongPoint.newRangeQuery(filedName, 6, 8);
Query query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 6, 8, fallbackQuery);
// Both bounds exist in the dataset
Query fallbackQuery = LongPoint.newRangeQuery(filedName, 7, 9);
Query query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 7, 9, fallbackQuery);
Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(1500, weight.count(context));
assertEquals(1400, weight.count(context));
}
// Both bounds do not exist in the dataset
fallbackQuery = LongPoint.newRangeQuery(filedName, 6, 10);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 6, 10, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(2000, weight.count(context));
assertEquals(1400, weight.count(context));
}
fallbackQuery = LongPoint.newRangeQuery(filedName, 4, 6);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 4, 6, fallbackQuery);
// Min bound exists in the dataset, not the max
fallbackQuery = LongPoint.newRangeQuery(filedName, 7, 10);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 7, 10, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(1000, weight.count(context));
assertEquals(1400, weight.count(context));
}
fallbackQuery = LongPoint.newRangeQuery(filedName, 2, 10);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 2, 10, fallbackQuery);
// Min bound doesn't exist in the dataset, max does
fallbackQuery = LongPoint.newRangeQuery(filedName, 6, 9);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 7, 10, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(2500, weight.count(context));
assertEquals(1400, weight.count(context));
}
fallbackQuery = LongPoint.newRangeQuery(filedName, 5, 9);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 2, 10, fallbackQuery);
// Min bound is the min value of the dataset
fallbackQuery = LongPoint.newRangeQuery(filedName, 5, 8);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 4, 8, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(2500, weight.count(context));
assertEquals(1100, weight.count(context));
}
// Min bound is less than min value of the dataset
fallbackQuery = LongPoint.newRangeQuery(filedName, 4, 8);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 4, 8, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(1100, weight.count(context));
}
// Max bound is the max value of the dataset
fallbackQuery = LongPoint.newRangeQuery(filedName, 10, 13);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 10, 13, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(1500, weight.count(context));
}
// Max bound is greater than max value of the dataset
fallbackQuery = LongPoint.newRangeQuery(filedName, 10, 14);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 10, 14, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(1500, weight.count(context));
}
// Everything matches
fallbackQuery = LongPoint.newRangeQuery(filedName, 2, 14);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 2, 14, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(3500, weight.count(context));
}
// Bounds equal to min/max values of the dataset, everything matches
fallbackQuery = LongPoint.newRangeQuery(filedName, 2, 14);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 2, 14, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(3500, weight.count(context));
}
// Bounds are less than the min value of the dataset
fallbackQuery = LongPoint.newRangeQuery(filedName, 2, 3);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 2, 3, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
@ -702,8 +756,9 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas
assertEquals(0, weight.count(context));
}
fallbackQuery = LongPoint.newRangeQuery(filedName, 10, 11);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 10, 11, fallbackQuery);
// Bounds are greater than the max value of the dataset
fallbackQuery = LongPoint.newRangeQuery(filedName, 14, 15);
query = new IndexSortSortedNumericDocValuesRangeQuery(filedName, 14, 15, fallbackQuery);
weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
for (LeafReaderContext context : searcher.getLeafContexts()) {
assertEquals(0, weight.count(context));
@ -714,11 +769,19 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas
dir.close();
}
public void testRandomCountWithBkd() throws IOException {
public void testRandomCountWithBkdAsc() throws Exception {
doTestRandomCountWithBkd(false);
}
public void testRandomCountWithBkdDesc() throws Exception {
doTestRandomCountWithBkd(true);
}
private void doTestRandomCountWithBkd(boolean reverse) throws Exception {
String filedName = "field";
Directory dir = newDirectory();
IndexWriterConfig iwc = new IndexWriterConfig(new MockAnalyzer(random()));
Sort indexSort = new Sort(new SortedNumericSortField(filedName, SortField.Type.LONG, false));
Sort indexSort = new Sort(new SortedNumericSortField(filedName, SortField.Type.LONG, reverse));
iwc.setIndexSort(indexSort);
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc);
Random random = random();