From a580e29539ff3922901661145980e5b27892e553 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Thu, 2 Dec 2021 09:53:23 +0100 Subject: [PATCH] LUCENE-10275: Speed up MultiRangeQuery by using an interval tree --- lucene/CHANGES.txt | 2 + .../sandbox/search/MultiRangeQuery.java | 397 ++++++++++++++---- 2 files changed, 325 insertions(+), 74 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 57cb762dc04..65b5ec7bedf 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -43,6 +43,8 @@ Improvements added in LUCENE-9820 (Ignacio Vera) * LUCENE-9538: Detect polygon self-intersections in the Tessellator. (Ignacio Vera) + +* LUCENE-10275: Speed up MultiRangeQuery by using an interval tree. (Ignacio Vera) Optimizations --------------------- diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java index f0583bd2761..2d71d3257ad 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java @@ -20,6 +20,7 @@ package org.apache.lucene.sandbox.search; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; import java.util.List; import java.util.Objects; import org.apache.lucene.index.LeafReader; @@ -156,7 +157,8 @@ public abstract class MultiRangeQuery extends Query { return new ConstantScoreWeight(this, boost) { - private PointValues.IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) { + private PointValues.IntersectVisitor getIntersectVisitor( + DocIdSetBuilder result, Range range) { return new PointValues.IntersectVisitor() { DocIdSetBuilder.BulkAdder adder; @@ -173,72 +175,14 @@ public abstract class MultiRangeQuery extends Query { @Override public void visit(int docID, byte[] packedValue) { - // If a single OR clause has the value in range, the entire query accepts the value - continueRange: - for (RangeClause rangeClause : rangeClauses) { - for (int dim = 0; dim < numDims; dim++) { - int offset = dim * bytesPerDim; - if (comparator.compare(packedValue, offset, rangeClause.lowerValue, offset) < 0) { - // Doc value is too low in this dim: - continue continueRange; - } - if (comparator.compare(packedValue, offset, rangeClause.upperValue, offset) > 0) { - // Doc value is too high in this dim: - continue continueRange; - } - } - // Doc matched on all dimensions: + if (range.matches(packedValue)) { adder.add(docID); } } @Override public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - - boolean crosses = false; - - /** - * CROSSES and INSIDE take priority over OUTSIDE. How we calculate the position is: 1) - * If any range sees the point as inside, return INSIDE. 2) If no range sees the point - * as inside and atleast one range sees the point as CROSSES, return CROSSES 3) If none - * of the above, return OUTSIDE - */ - continueRange: - for (RangeClause rangeClause : rangeClauses) { - boolean rangeCrosses = false; - - for (int dim = 0; dim < numDims; dim++) { - int offset = dim * bytesPerDim; - - if (comparator.compare(minPackedValue, offset, rangeClause.upperValue, offset) > 0 - || comparator.compare(maxPackedValue, offset, rangeClause.lowerValue, offset) - < 0) { - continue continueRange; - } - - rangeCrosses |= - comparator.compare(minPackedValue, offset, rangeClause.lowerValue, offset) < 0 - || comparator.compare( - maxPackedValue, offset, rangeClause.upperValue, offset) - > 0; - } - - if (rangeCrosses == false) { - // At this point we know that the cell is fully inside the range clause, so we - // return early: - return PointValues.Relation.CELL_INSIDE_QUERY; - } else { - // This range clause crosses the cell, but we'll keep checking more ranges to see if - // one fully contains the cell: - crosses = true; - } - } - - if (crosses) { - return PointValues.Relation.CELL_CROSSES_QUERY; - } else { - return PointValues.Relation.CELL_OUTSIDE_QUERY; - } + return range.relate(minPackedValue, maxPackedValue); } }; } @@ -272,22 +216,14 @@ public abstract class MultiRangeQuery extends Query { + bytesPerDim); } + Range range = create(rangeClauses, numDims, bytesPerDim, comparator); boolean allDocsMatch; if (values.getDocCount() == reader.maxDoc()) { final byte[] fieldPackedLower = values.getMinPackedValue(); final byte[] fieldPackedUpper = values.getMaxPackedValue(); - allDocsMatch = true; - for (RangeClause rangeClause : rangeClauses) { - for (int i = 0; i < numDims; ++i) { - int offset = i * bytesPerDim; - if (comparator.compare(rangeClause.lowerValue, offset, fieldPackedLower, offset) > 0 - || comparator.compare(rangeClause.upperValue, offset, fieldPackedUpper, offset) - < 0) { - allDocsMatch = false; - break; - } - } - } + allDocsMatch = + range.relate(fieldPackedLower, fieldPackedUpper) + == PointValues.Relation.CELL_INSIDE_QUERY; } else { allDocsMatch = false; } @@ -311,7 +247,7 @@ public abstract class MultiRangeQuery extends Query { return new ScorerSupplier() { final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); - final PointValues.IntersectVisitor visitor = getIntersectVisitor(result); + final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, range); long cost = -1; @Override @@ -439,4 +375,317 @@ public abstract class MultiRangeQuery extends Query { * @return human readable value for debugging */ protected abstract String toString(int dimension, byte[] value); + + /** + * A range represents anything with a min/max value that can compute its relation with another + * range and can compute if a point is inside it + */ + private interface Range { + /** min value of this range */ + byte[] getMinPackedValue(); + /** max value of this range */ + byte[] getMaxPackedValue(); + /** return true if the provided point is inside the range */ + boolean matches(byte[] packedValue); + /** return the relation between this range and the provided range */ + PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue); + } + + /** An interval tree of Ranges for speeding up computations */ + private static class RangeTree implements Range { + /** minimum value of this Range and its children */ + private final byte[] minPackedValue; + /** maximum value of this Range and its children */ + private final byte[] maxPackedValue; + + /** Left child, it can be null */ + private RangeTree left; + /** Right child, it can be null */ + private RangeTree right; + /** which dimension was this node split on */ + private final int split; + /** Range of this tree node */ + private final Range component; + // Utility variables for computing relationships + private final ArrayUtil.ByteArrayComparator comparator; + private final int numIndexDim; + private final int bytesPerDim; + + private RangeTree( + Range component, + int split, + ArrayUtil.ByteArrayComparator comparator, + int numIndexDim, + int bytesPerDim) { + this.minPackedValue = component.getMinPackedValue().clone(); + this.maxPackedValue = component.getMaxPackedValue().clone(); + this.component = component; + this.split = split; + this.comparator = comparator; + this.numIndexDim = numIndexDim; + this.bytesPerDim = bytesPerDim; + } + + @Override + public byte[] getMinPackedValue() { + return minPackedValue; + } + + @Override + public byte[] getMaxPackedValue() { + return maxPackedValue; + } + + @Override + public boolean matches(byte[] packedValue) { + boolean valid = true; + for (int i = 0; i < numIndexDim; i++) { + int offset = bytesPerDim * i; + if (comparator.compare(packedValue, offset, maxPackedValue, offset) > 0) { + valid = false; + break; + } + } + if (valid) { + if (component.matches(packedValue)) { + return true; + } + if (left != null) { + if (left.matches(packedValue)) { + return true; + } + } + if (right != null + && comparator.compare( + packedValue, split * bytesPerDim, minPackedValue, split * bytesPerDim) + >= 0) { + if (right.matches(packedValue)) { + return true; + } + } + } + return false; + } + + @Override + public PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { + boolean valid = true; + for (int i = 0; i < numIndexDim; i++) { + int offset = bytesPerDim * i; + if (comparator.compare(minPackedValue, offset, this.maxPackedValue, offset) > 0) { + valid = false; + break; + } + } + if (valid) { + PointValues.Relation relation = component.relate(minPackedValue, maxPackedValue); + if (relation != PointValues.Relation.CELL_OUTSIDE_QUERY) { + return relation; + } + if (left != null) { + relation = left.relate(minPackedValue, maxPackedValue); + if (relation != PointValues.Relation.CELL_OUTSIDE_QUERY) { + return relation; + } + } + if (right != null + && comparator.compare( + maxPackedValue, split * bytesPerDim, this.minPackedValue, split * bytesPerDim) + >= 0) { + relation = right.relate(minPackedValue, maxPackedValue); + if (relation != PointValues.Relation.CELL_OUTSIDE_QUERY) { + return relation; + } + } + } + return PointValues.Relation.CELL_OUTSIDE_QUERY; + } + } + + /** Creates a tree from provided clauses */ + static Range create( + List clauses, + int numIndexDim, + int bytesPerDim, + ArrayUtil.ByteArrayComparator comparator) { + if (clauses.size() == 1) { + return getRange(clauses.get(0), numIndexDim, bytesPerDim, comparator); + } + Range[] ranges = new Range[clauses.size()]; + for (int i = 0; i < clauses.size(); i++) { + ranges[i] = getRange(clauses.get(i), numIndexDim, bytesPerDim, comparator); + } + RangeTree root = + createTree(ranges, 0, ranges.length - 1, 0, numIndexDim, bytesPerDim, comparator); + // pull up min values for the root node so it contains a consistent bounding box + for (Range range : ranges) { + for (int i = 0; i < numIndexDim; i++) { + int offset = i * bytesPerDim; + if (comparator.compare(root.minPackedValue, offset, range.getMinPackedValue(), offset) + > 0) { + System.arraycopy( + range.getMinPackedValue(), offset, root.minPackedValue, offset, bytesPerDim); + } + if (comparator.compare(root.maxPackedValue, offset, range.getMaxPackedValue(), offset) + < 0) { + System.arraycopy( + range.getMaxPackedValue(), offset, root.maxPackedValue, offset, bytesPerDim); + } + } + } + return root; + } + + /** Creates tree from sorted ranges (with range low and high inclusive) */ + private static RangeTree createTree( + Range[] components, + int low, + int high, + int split, + int numIndexDim, + int bytesPerDim, + ArrayUtil.ByteArrayComparator comparator) { + if (low > high) { + return null; + } + final int mid = (low + high) >>> 1; + if (low < high) { + int offset = split * bytesPerDim; + Comparator comp = + (left, right) -> { + int ret = + comparator.compare( + left.getMinPackedValue(), offset, right.getMinPackedValue(), offset); + if (ret == 0) { + ret = + comparator.compare( + left.getMaxPackedValue(), offset, right.getMaxPackedValue(), offset); + } + return ret; + }; + ArrayUtil.select(components, low, high + 1, mid, comp); + } + RangeTree newNode = new RangeTree(components[mid], split, comparator, numIndexDim, bytesPerDim); + // find children + split++; + if (split == numIndexDim) { + split = 0; + } + newNode.left = + createTree(components, low, mid - 1, split, numIndexDim, bytesPerDim, comparator); + newNode.right = + createTree(components, mid + 1, high, split, numIndexDim, bytesPerDim, comparator); + + // pull up max values to this node + if (newNode.left != null) { + for (int i = 0; i < numIndexDim; i++) { + int offset = i * bytesPerDim; + if (comparator.compare( + newNode.minPackedValue, offset, newNode.left.getMinPackedValue(), offset) + > 0) { + System.arraycopy( + newNode.left.getMinPackedValue(), + offset, + newNode.minPackedValue, + offset, + bytesPerDim); + } + if (comparator.compare( + newNode.maxPackedValue, offset, newNode.left.getMaxPackedValue(), offset) + < 0) { + System.arraycopy( + newNode.left.getMaxPackedValue(), + offset, + newNode.maxPackedValue, + offset, + bytesPerDim); + } + } + } + if (newNode.right != null) { + for (int i = 0; i < numIndexDim; i++) { + int offset = i * bytesPerDim; + if (comparator.compare( + newNode.minPackedValue, offset, newNode.right.getMinPackedValue(), offset) + > 0) { + System.arraycopy( + newNode.right.getMinPackedValue(), + offset, + newNode.minPackedValue, + offset, + bytesPerDim); + } + if (comparator.compare( + newNode.maxPackedValue, offset, newNode.right.getMaxPackedValue(), offset) + < 0) { + System.arraycopy( + newNode.right.getMaxPackedValue(), + offset, + newNode.maxPackedValue, + offset, + bytesPerDim); + } + } + } + return newNode; + } + + /** Builds a Range object from a range clause */ + private static Range getRange( + RangeClause clause, + int numIndexDim, + int bytesPerDim, + ArrayUtil.ByteArrayComparator comparator) { + return new Range() { + @Override + public byte[] getMinPackedValue() { + return clause.lowerValue; + } + + @Override + public byte[] getMaxPackedValue() { + return clause.upperValue; + } + + @Override + public boolean matches(byte[] packedValue) { + for (int dim = 0; dim < numIndexDim; dim++) { + int offset = dim * bytesPerDim; + if (comparator.compare(packedValue, offset, clause.lowerValue, offset) < 0) { + // Doc's value is too low, in this dimension + return false; + } + if (comparator.compare(packedValue, offset, clause.upperValue, offset) > 0) { + // Doc's value is too high, in this dimension + return false; + } + } + return true; + } + + @Override + public PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { + boolean crosses = false; + + for (int dim = 0; dim < numIndexDim; dim++) { + int offset = dim * bytesPerDim; + + if (comparator.compare(minPackedValue, offset, clause.upperValue, offset) > 0 + || comparator.compare(maxPackedValue, offset, clause.lowerValue, offset) < 0) { + return PointValues.Relation.CELL_OUTSIDE_QUERY; + } + + crosses |= + comparator.compare(minPackedValue, offset, clause.lowerValue, offset) < 0 + || comparator.compare(maxPackedValue, offset, clause.upperValue, offset) > 0; + } + + if (crosses) { + return PointValues.Relation.CELL_CROSSES_QUERY; + } else { + return PointValues.Relation.CELL_INSIDE_QUERY; + } + } + }; + } }