LUCENE-10275: Speed up MultiRangeQuery by using an interval tree

This commit is contained in:
Ignacio Vera 2021-12-02 09:53:23 +01:00 committed by iverase
parent 072b775199
commit a580e29539
2 changed files with 325 additions and 74 deletions

View File

@ -43,6 +43,8 @@ Improvements
added in LUCENE-9820 (Ignacio Vera)
* LUCENE-9538: Detect polygon self-intersections in the Tessellator. (Ignacio Vera)
* LUCENE-10275: Speed up MultiRangeQuery by using an interval tree. (Ignacio Vera)
Optimizations
---------------------

View File

@ -20,6 +20,7 @@ package org.apache.lucene.sandbox.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.LeafReader;
@ -156,7 +157,8 @@ public abstract class MultiRangeQuery extends Query {
return new ConstantScoreWeight(this, boost) {
private PointValues.IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) {
private PointValues.IntersectVisitor getIntersectVisitor(
DocIdSetBuilder result, Range range) {
return new PointValues.IntersectVisitor() {
DocIdSetBuilder.BulkAdder adder;
@ -173,72 +175,14 @@ public abstract class MultiRangeQuery extends Query {
@Override
public void visit(int docID, byte[] packedValue) {
// If a single OR clause has the value in range, the entire query accepts the value
continueRange:
for (RangeClause rangeClause : rangeClauses) {
for (int dim = 0; dim < numDims; dim++) {
int offset = dim * bytesPerDim;
if (comparator.compare(packedValue, offset, rangeClause.lowerValue, offset) < 0) {
// Doc value is too low in this dim:
continue continueRange;
}
if (comparator.compare(packedValue, offset, rangeClause.upperValue, offset) > 0) {
// Doc value is too high in this dim:
continue continueRange;
}
}
// Doc matched on all dimensions:
if (range.matches(packedValue)) {
adder.add(docID);
}
}
@Override
public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
boolean crosses = false;
/**
* CROSSES and INSIDE take priority over OUTSIDE. How we calculate the position is: 1)
* If any range sees the point as inside, return INSIDE. 2) If no range sees the point
* as inside and at least one range sees the point as CROSSES, return CROSSES 3) If none
* of the above, return OUTSIDE
*/
continueRange:
for (RangeClause rangeClause : rangeClauses) {
boolean rangeCrosses = false;
for (int dim = 0; dim < numDims; dim++) {
int offset = dim * bytesPerDim;
if (comparator.compare(minPackedValue, offset, rangeClause.upperValue, offset) > 0
|| comparator.compare(maxPackedValue, offset, rangeClause.lowerValue, offset)
< 0) {
continue continueRange;
}
rangeCrosses |=
comparator.compare(minPackedValue, offset, rangeClause.lowerValue, offset) < 0
|| comparator.compare(
maxPackedValue, offset, rangeClause.upperValue, offset)
> 0;
}
if (rangeCrosses == false) {
// At this point we know that the cell is fully inside the range clause, so we
// return early:
return PointValues.Relation.CELL_INSIDE_QUERY;
} else {
// This range clause crosses the cell, but we'll keep checking more ranges to see if
// one fully contains the cell:
crosses = true;
}
}
if (crosses) {
return PointValues.Relation.CELL_CROSSES_QUERY;
} else {
return PointValues.Relation.CELL_OUTSIDE_QUERY;
}
return range.relate(minPackedValue, maxPackedValue);
}
};
}
@ -272,22 +216,14 @@ public abstract class MultiRangeQuery extends Query {
+ bytesPerDim);
}
Range range = create(rangeClauses, numDims, bytesPerDim, comparator);
boolean allDocsMatch;
if (values.getDocCount() == reader.maxDoc()) {
final byte[] fieldPackedLower = values.getMinPackedValue();
final byte[] fieldPackedUpper = values.getMaxPackedValue();
allDocsMatch = true;
for (RangeClause rangeClause : rangeClauses) {
for (int i = 0; i < numDims; ++i) {
int offset = i * bytesPerDim;
if (comparator.compare(rangeClause.lowerValue, offset, fieldPackedLower, offset) > 0
|| comparator.compare(rangeClause.upperValue, offset, fieldPackedUpper, offset)
< 0) {
allDocsMatch = false;
break;
}
}
}
allDocsMatch =
range.relate(fieldPackedLower, fieldPackedUpper)
== PointValues.Relation.CELL_INSIDE_QUERY;
} else {
allDocsMatch = false;
}
@ -311,7 +247,7 @@ public abstract class MultiRangeQuery extends Query {
return new ScorerSupplier() {
final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
final PointValues.IntersectVisitor visitor = getIntersectVisitor(result);
final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, range);
long cost = -1;
@Override
@ -439,4 +375,317 @@ public abstract class MultiRangeQuery extends Query {
* @return human readable value for debugging
*/
protected abstract String toString(int dimension, byte[] value);
/**
 * A range represents anything with a min/max value that can compute its relation with another
 * range and can compute if a point is inside it.
 *
 * <p>All values are exchanged as packed {@code byte[]} arrays. The implementations in this file
 * read one dimension at a time, at offset {@code dim * bytesPerDim}, using the byte-array
 * comparator they were created with.
 */
private interface Range {
/**
 * Min value of this range.
 *
 * @return the packed lower bound; implementations in this file return their backing array
 *     directly (no copy), so callers should treat it as read-only unless they own the range
 */
byte[] getMinPackedValue();
/**
 * Max value of this range.
 *
 * @return the packed upper bound; implementations in this file return their backing array
 *     directly (no copy), so callers should treat it as read-only unless they own the range
 */
byte[] getMaxPackedValue();
/**
 * Return true if the provided point is inside the range.
 *
 * @param packedValue the packed point to test
 * @return {@code true} if the point is accepted by this range
 */
boolean matches(byte[] packedValue);
/**
 * Return the relation between this range and the provided range.
 *
 * @param minPackedValue packed lower bound of the cell being tested
 * @param maxPackedValue packed upper bound of the cell being tested
 * @return how the cell relates to this range, expressed as a {@link PointValues.Relation}
 */
PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue);
}
/**
 * An interval tree of Ranges for speeding up computations.
 *
 * <p>Every node owns one {@link Range} (its {@code component}) plus optional left and right
 * subtrees. {@code minPackedValue} / {@code maxPackedValue} start as copies of the component's
 * bounds and are widened in place by the tree-building code so that they cover the children as
 * well. They are therefore a bounding box, not the component's own bounds.
 */
private static class RangeTree implements Range {
  /** minimum value of this Range and its children */
  private final byte[] minPackedValue;
  /** maximum value of this Range and its children */
  private final byte[] maxPackedValue;
  /** Left child, it can be null */
  private RangeTree left;
  /** Right child, it can be null */
  private RangeTree right;
  /** which dimension was this node split on */
  private final int split;
  /** Range of this tree node */
  private final Range component;
  // Utility variables for computing relationships
  private final ArrayUtil.ByteArrayComparator comparator;
  private final int numIndexDim;
  private final int bytesPerDim;

  /**
   * Creates a childless node around {@code component}.
   *
   * @param component the range stored at this node
   * @param split the dimension this node's slice was partitioned on
   * @param comparator compares one dimension of two packed values
   * @param numIndexDim number of dimensions in a packed value
   * @param bytesPerDim number of bytes used by each dimension
   */
  private RangeTree(
      Range component,
      int split,
      ArrayUtil.ByteArrayComparator comparator,
      int numIndexDim,
      int bytesPerDim) {
    // Cloned because the bounding box is widened in place while the tree is built; the
    // component's own arrays must stay untouched.
    this.minPackedValue = component.getMinPackedValue().clone();
    this.maxPackedValue = component.getMaxPackedValue().clone();
    this.component = component;
    this.split = split;
    this.comparator = comparator;
    this.numIndexDim = numIndexDim;
    this.bytesPerDim = bytesPerDim;
  }

  @Override
  public byte[] getMinPackedValue() {
    return minPackedValue;
  }

  @Override
  public byte[] getMaxPackedValue() {
    return maxPackedValue;
  }

  /**
   * Returns true if the point is accepted by this node's range or by any range stored below it.
   *
   * @param packedValue the packed point to test
   */
  @Override
  public boolean matches(byte[] packedValue) {
    // A point above the subtree's upper bound in any dimension cannot match anything below here.
    for (int dim = 0; dim < numIndexDim; dim++) {
      final int offset = bytesPerDim * dim;
      if (comparator.compare(packedValue, offset, maxPackedValue, offset) > 0) {
        return false;
      }
    }
    if (component.matches(packedValue)) {
      return true;
    }
    if (left != null && left.matches(packedValue)) {
      return true;
    }
    return right != null && canReachRightSubtree(packedValue) && right.matches(packedValue);
  }

  /**
   * Returns the first relation other than {@code CELL_OUTSIDE_QUERY} found while visiting this
   * node's range, then the left subtree, then the right subtree; {@code CELL_OUTSIDE_QUERY} if
   * none is found.
   *
   * @param minPackedValue packed lower bound of the cell
   * @param maxPackedValue packed upper bound of the cell
   */
  @Override
  public PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue) {
    // A cell starting above the subtree's upper bound in any dimension is outside everything here.
    for (int dim = 0; dim < numIndexDim; dim++) {
      final int offset = bytesPerDim * dim;
      if (comparator.compare(minPackedValue, offset, this.maxPackedValue, offset) > 0) {
        return PointValues.Relation.CELL_OUTSIDE_QUERY;
      }
    }
    PointValues.Relation relation = component.relate(minPackedValue, maxPackedValue);
    if (relation != PointValues.Relation.CELL_OUTSIDE_QUERY) {
      return relation;
    }
    if (left != null) {
      relation = left.relate(minPackedValue, maxPackedValue);
      if (relation != PointValues.Relation.CELL_OUTSIDE_QUERY) {
        return relation;
      }
    }
    if (right != null && canReachRightSubtree(maxPackedValue)) {
      return right.relate(minPackedValue, maxPackedValue);
    }
    return PointValues.Relation.CELL_OUTSIDE_QUERY;
  }

  /**
   * Tells whether a point (or the upper corner of a cell) can still intersect a range stored in
   * the right subtree.
   *
   * <p>The tree builder partitions each slice on the split dimension by lower bound, so ranges
   * placed to the right of this node start at or after this node's own range on that dimension.
   * The comparison deliberately uses the component's lower bound rather than {@code
   * minPackedValue}: the latter is widened to cover the children, which would make this check
   * succeed for almost every input and defeat the pruning.
   */
  private boolean canReachRightSubtree(byte[] upperPackedValue) {
    final int splitOffset = split * bytesPerDim;
    return comparator.compare(
            upperPackedValue, splitOffset, component.getMinPackedValue(), splitOffset)
        >= 0;
  }
}
/**
 * Builds the {@link Range} that evaluates all the provided clauses.
 *
 * <p>A single clause is wrapped directly. Several clauses are wrapped one by one and arranged
 * into a {@link RangeTree}, whose root bounding box is then widened to cover every clause.
 *
 * @param clauses the range clauses of the query
 * @param numIndexDim number of dimensions in a packed value
 * @param bytesPerDim number of bytes used by each dimension
 * @param comparator compares one dimension of two packed values
 * @return a plain range for one clause, otherwise the root of the tree
 */
static Range create(
    List<RangeClause> clauses,
    int numIndexDim,
    int bytesPerDim,
    ArrayUtil.ByteArrayComparator comparator) {
  final int clauseCount = clauses.size();
  if (clauseCount == 1) {
    // no tree needed around a lone clause
    return getRange(clauses.get(0), numIndexDim, bytesPerDim, comparator);
  }

  final Range[] ranges = new Range[clauseCount];
  for (int idx = 0; idx < clauseCount; idx++) {
    ranges[idx] = getRange(clauses.get(idx), numIndexDim, bytesPerDim, comparator);
  }

  final RangeTree root =
      createTree(ranges, 0, clauseCount - 1, 0, numIndexDim, bytesPerDim, comparator);

  // Widen both the lower and the upper bound of the root, one dimension at a time, so that the
  // root carries a bounding box consistent with every clause.
  for (int dim = 0; dim < numIndexDim; dim++) {
    final int offset = dim * bytesPerDim;
    for (Range candidate : ranges) {
      final byte[] candidateMin = candidate.getMinPackedValue();
      if (comparator.compare(root.minPackedValue, offset, candidateMin, offset) > 0) {
        System.arraycopy(candidateMin, offset, root.minPackedValue, offset, bytesPerDim);
      }
      final byte[] candidateMax = candidate.getMaxPackedValue();
      if (comparator.compare(root.maxPackedValue, offset, candidateMax, offset) < 0) {
        System.arraycopy(candidateMax, offset, root.maxPackedValue, offset, bytesPerDim);
      }
    }
  }
  return root;
}
/**
 * Recursively builds a tree over {@code components[low..high]} (both ends inclusive).
 *
 * <p>The slice is reordered in place with {@code ArrayUtil.select} around its middle index, using
 * the lower bound on the {@code split} dimension as key and the upper bound as tie-breaker. The
 * middle element becomes the node; the two halves become its children, built on the next
 * dimension (wrapping back to 0 after the last one).
 *
 * @param components the ranges to arrange; reordered by this call
 * @param low first index of the slice, inclusive
 * @param high last index of the slice, inclusive
 * @param split dimension used to partition this level
 * @param numIndexDim number of dimensions in a packed value
 * @param bytesPerDim number of bytes used by each dimension
 * @param comparator compares one dimension of two packed values
 * @return the node for this slice, or {@code null} when the slice is empty
 */
private static RangeTree createTree(
    Range[] components,
    int low,
    int high,
    int split,
    int numIndexDim,
    int bytesPerDim,
    ArrayUtil.ByteArrayComparator comparator) {
  if (low > high) {
    return null;
  }
  final int middle = (low + high) >>> 1;
  if (low < high) {
    final int splitOffset = split * bytesPerDim;
    final Comparator<Range> bySplitDimension =
        (a, b) -> {
          final int byLowerBound =
              comparator.compare(
                  a.getMinPackedValue(), splitOffset, b.getMinPackedValue(), splitOffset);
          return byLowerBound != 0
              ? byLowerBound
              : comparator.compare(
                  a.getMaxPackedValue(), splitOffset, b.getMaxPackedValue(), splitOffset);
        };
    ArrayUtil.select(components, low, high + 1, middle, bySplitDimension);
  }

  final RangeTree node =
      new RangeTree(components[middle], split, comparator, numIndexDim, bytesPerDim);

  // children are partitioned on the next dimension, cycling through all of them
  final int childSplit = (split + 1 == numIndexDim) ? 0 : split + 1;
  node.left =
      createTree(components, low, middle - 1, childSplit, numIndexDim, bytesPerDim, comparator);
  node.right =
      createTree(components, middle + 1, high, childSplit, numIndexDim, bytesPerDim, comparator);

  // pull up the children's min and max values to this node
  if (node.left != null) {
    pullUpBounds(node, node.left, numIndexDim, bytesPerDim, comparator);
  }
  if (node.right != null) {
    pullUpBounds(node, node.right, numIndexDim, bytesPerDim, comparator);
  }
  return node;
}

/** Widens {@code node}'s bounds, per dimension, so that they also cover {@code child}'s bounds. */
private static void pullUpBounds(
    RangeTree node,
    RangeTree child,
    int numIndexDim,
    int bytesPerDim,
    ArrayUtil.ByteArrayComparator comparator) {
  final byte[] childMin = child.getMinPackedValue();
  final byte[] childMax = child.getMaxPackedValue();
  for (int dim = 0; dim < numIndexDim; dim++) {
    final int offset = dim * bytesPerDim;
    if (comparator.compare(node.minPackedValue, offset, childMin, offset) > 0) {
      System.arraycopy(childMin, offset, node.minPackedValue, offset, bytesPerDim);
    }
    if (comparator.compare(node.maxPackedValue, offset, childMax, offset) < 0) {
      System.arraycopy(childMax, offset, node.maxPackedValue, offset, bytesPerDim);
    }
  }
}
/**
 * Adapts a single range clause to the {@link Range} interface.
 *
 * <p>The returned object exposes the clause's {@code lowerValue} / {@code upperValue} arrays
 * directly as its min and max, and treats both bounds as inclusive in every dimension.
 *
 * @param clause the clause to wrap
 * @param numIndexDim number of dimensions in a packed value
 * @param bytesPerDim number of bytes used by each dimension
 * @param comparator compares one dimension of two packed values
 */
private static Range getRange(
    RangeClause clause,
    int numIndexDim,
    int bytesPerDim,
    ArrayUtil.ByteArrayComparator comparator) {
  return new Range() {
    @Override
    public byte[] getMinPackedValue() {
      return clause.lowerValue;
    }

    @Override
    public byte[] getMaxPackedValue() {
      return clause.upperValue;
    }

    @Override
    public boolean matches(byte[] packedValue) {
      for (int dim = 0; dim < numIndexDim; dim++) {
        final int offset = dim * bytesPerDim;
        // reject as soon as one dimension falls below the lower or above the upper bound
        if (comparator.compare(packedValue, offset, clause.lowerValue, offset) < 0
            || comparator.compare(packedValue, offset, clause.upperValue, offset) > 0) {
          return false;
        }
      }
      return true;
    }

    @Override
    public PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue) {
      boolean contained = true;
      for (int dim = 0; dim < numIndexDim; dim++) {
        final int offset = dim * bytesPerDim;
        final boolean cellAboveClause =
            comparator.compare(minPackedValue, offset, clause.upperValue, offset) > 0;
        if (cellAboveClause
            || comparator.compare(maxPackedValue, offset, clause.lowerValue, offset) < 0) {
          // no overlap in this dimension, so no overlap at all
          return PointValues.Relation.CELL_OUTSIDE_QUERY;
        }
        // the cell stays contained only while it sits within the clause in every dimension
        contained &=
            comparator.compare(minPackedValue, offset, clause.lowerValue, offset) >= 0
                && comparator.compare(maxPackedValue, offset, clause.upperValue, offset) <= 0;
      }
      return contained
          ? PointValues.Relation.CELL_INSIDE_QUERY
          : PointValues.Relation.CELL_CROSSES_QUERY;
    }
  };
}
}