LUCENE-10456: Implement Weight#count for MultiRangeQuery (#731)

This commit is contained in:
xiaoping 2022-04-05 15:23:59 +08:00 committed by GitHub
parent f249046a1d
commit 898ec1659d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 202 additions and 4 deletions

View File

@ -73,6 +73,9 @@ New Features
implementation. `Monitor` can be created with a readonly `QueryIndex` in order to
have readonly `Monitor` instances. (Niko Usai)
* LUCENE-10456: Implement rewrite and Weight#count for MultiRangeQuery
by merging overlapping ranges . (Jianping Weng)
Improvements
---------------------

View File

@ -23,6 +23,7 @@ import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
@ -30,6 +31,7 @@ import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
@ -41,12 +43,11 @@ import org.apache.lucene.util.DocIdSetBuilder;
/**
* Abstract class for range queries involving multiple ranges against physical points such as {@code
* IntPoints} All ranges are logically ORed together TODO: Add capability for handling overlapping
* ranges at rewrite time
* IntPoints} All ranges are logically ORed together
*
* @lucene.experimental
*/
public abstract class MultiRangeQuery extends Query {
public abstract class MultiRangeQuery extends Query implements Cloneable {
/** Representation of a single clause in a MultiRangeQuery */
public static final class RangeClause {
byte[] lowerValue;
@ -140,7 +141,7 @@ public abstract class MultiRangeQuery extends Query {
final String field;
final int numDims;
final int bytesPerDim;
final List<RangeClause> rangeClauses;
List<RangeClause> rangeClauses;
/**
* Expert: create a multidimensional range query with multiple connected ranges
*
@ -163,6 +164,79 @@ public abstract class MultiRangeQuery extends Query {
}
}
/**
* Merges the overlapping ranges and returns unconnected ranges by calling {@link
* #mergeOverlappingRanges}
*/
@Override
public Query rewrite(IndexReader reader) throws IOException {
if (numDims != 1) {
return this;
}
List<RangeClause> mergedRanges = mergeOverlappingRanges(rangeClauses, bytesPerDim);
if (mergedRanges != rangeClauses) {
try {
MultiRangeQuery clone = (MultiRangeQuery) super.clone();
clone.rangeClauses = mergedRanges;
return clone;
} catch (CloneNotSupportedException e) {
throw new AssertionError(e);
}
} else {
return this;
}
}
/**
* Merges overlapping ranges and returns unconnected ranges
*
* @param rangeClauses some overlapping ranges
* @param bytesPerDim bytes per Dimension of the point value
* @return unconnected ranges
*/
static List<RangeClause> mergeOverlappingRanges(List<RangeClause> rangeClauses, int bytesPerDim) {
if (rangeClauses.size() <= 1) {
return rangeClauses;
}
List<RangeClause> originRangeClause = new ArrayList<>(rangeClauses);
final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim);
originRangeClause.sort(
new Comparator<RangeClause>() {
@Override
public int compare(RangeClause o1, RangeClause o2) {
int result = comparator.compare(o1.lowerValue, 0, o2.lowerValue, 0);
if (result == 0) {
return comparator.compare(o1.upperValue, 0, o2.upperValue, 0);
} else {
return result;
}
}
});
List<RangeClause> finalRangeClause = new ArrayList<>();
RangeClause current = originRangeClause.get(0);
for (int i = 1; i < originRangeClause.size(); i++) {
RangeClause nextClause = originRangeClause.get(i);
if (comparator.compare(nextClause.lowerValue, 0, current.upperValue, 0) > 0) {
finalRangeClause.add(current);
current = nextClause;
} else {
if (comparator.compare(nextClause.upperValue, 0, current.upperValue, 0) > 0) {
current = new RangeClause(current.lowerValue, nextClause.upperValue);
}
}
}
finalRangeClause.add(current);
/**
* in {@link #rewrite} it compares the returned rangeClauses with origin rangeClauses to decide
* if rewrite should return a new query or the origin query
*/
if (finalRangeClause.size() != rangeClauses.size()) {
return finalRangeClause;
} else {
return rangeClauses;
}
}
/*
* TODO: Organize ranges similar to how EdgeTree does, to avoid linear scan of ranges
*/
@ -314,6 +388,38 @@ public abstract class MultiRangeQuery extends Query {
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
@Override
public int count(LeafReaderContext context) throws IOException {
if (numDims != 1 || context.reader().hasDeletions() == true) {
return super.count(context);
}
PointValues pointValues = context.reader().getPointValues(field);
if (pointValues == null || pointValues.size() != pointValues.getDocCount()) {
return super.count(context);
}
int total = 0;
for (RangeClause rangeClause : rangeClauses) {
PointRangeQuery pointRangeQuery =
new PointRangeQuery(field, rangeClause.lowerValue, rangeClause.upperValue, numDims) {
@Override
protected String toString(int dimension, byte[] value) {
return MultiRangeQuery.this.toString(dimension, value);
}
};
int count =
pointRangeQuery
.createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, 1f)
.count(context);
if (count != -1) {
total += count;
} else {
return super.count(context);
}
}
return total;
}
};
}

View File

@ -19,6 +19,7 @@ package org.apache.lucene.sandbox.search;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import java.io.IOException;
import java.util.Random;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.DoublePoint;
import org.apache.lucene.document.FloatPoint;
@ -33,8 +34,10 @@ import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollectorManager;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.QueryUtils;
@ -761,4 +764,90 @@ public class TestMultiRangeQueries extends LuceneTestCase {
assertNotEquals(query1.hashCode(), query3.hashCode());
}
}
private void addRandomDocs(RandomIndexWriter w) throws IOException {
Random random = random();
for (int i = 0, end = random.nextInt(100, 500); i < end; i++) {
int numPoints = RandomNumbers.randomIntBetween(random(), 1, 200);
long value = RandomNumbers.randomLongBetween(random(), 0, 2000);
for (int j = 0; j < numPoints; j++) {
Document doc = new Document();
doc.add(new LongPoint("point", value));
w.addDocument(doc);
}
}
w.flush();
w.forceMerge(1);
}
/** The hit doc count of the rewritten query should be the same as origin query's */
public void testRandomRewrite() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
int dims = 1;
addRandomDocs(w);
IndexReader reader = w.getReader();
IndexSearcher searcher = newSearcher(reader);
int numIters = atLeast(100);
for (int n = 0; n < numIters; n++) {
int numRanges = RandomNumbers.randomIntBetween(random(), 1, 20);
LongPointMultiRangeBuilder builder1 = new LongPointMultiRangeBuilder("point", dims);
BooleanQuery.Builder builder2 = new BooleanQuery.Builder();
for (int i = 0; i < numRanges; i++) {
long[] lower = new long[dims];
long[] upper = new long[dims];
for (int j = 0; j < dims; j++) {
lower[j] = RandomNumbers.randomLongBetween(random(), 0, 2000);
upper[j] = lower[j] + RandomNumbers.randomLongBetween(random(), 0, 2000);
}
builder1.add(lower, upper);
builder2.add(LongPoint.newRangeQuery("point", lower, upper), BooleanClause.Occur.SHOULD);
}
MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader);
BooleanQuery booleanQuery = builder2.build();
int count = searcher.search(multiRangeQuery, new TotalHitCountCollectorManager());
int booleanCount = searcher.search(booleanQuery, new TotalHitCountCollectorManager());
assertEquals(booleanCount, count);
}
IOUtils.close(reader, w, dir);
}
public void testOneDimensionCount() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
int dims = 1;
addRandomDocs(w);
IndexReader reader = w.getReader();
IndexSearcher searcher = newSearcher(reader);
int numIters = atLeast(100);
for (int n = 0; n < numIters; n++) {
int numRanges = RandomNumbers.randomIntBetween(random(), 1, 20);
LongPointMultiRangeBuilder builder1 = new LongPointMultiRangeBuilder("point", dims);
BooleanQuery.Builder builder2 = new BooleanQuery.Builder();
for (int i = 0; i < numRanges; i++) {
long[] lower = new long[dims];
long[] upper = new long[dims];
for (int j = 0; j < dims; j++) {
lower[j] = RandomNumbers.randomLongBetween(random(), 0, 2000);
upper[j] = lower[j] + RandomNumbers.randomLongBetween(random(), 0, 2000);
}
builder1.add(lower, upper);
builder2.add(LongPoint.newRangeQuery("point", lower, upper), BooleanClause.Occur.SHOULD);
}
MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader);
BooleanQuery booleanQuery = builder2.build();
int count =
multiRangeQuery
.createWeight(searcher, ScoreMode.COMPLETE, 1.0f)
.count(searcher.getLeafContexts().get(0));
int booleanCount = searcher.count(booleanQuery);
assertEquals(booleanCount, count);
}
IOUtils.close(reader, w, dir);
}
}