LUCENE-8990: Add estimateDocCount(visitor) method to PointValues (#905)

This commit is contained in:
Ignacio Vera 2019-10-04 10:13:55 +02:00 committed by GitHub
parent d4ab808a8a
commit 9942544a7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 317 additions and 104 deletions

View File

@ -83,6 +83,10 @@ API Changes
And don't call if docFreq <= 0. The previous implementation survives as deprecated and final. It's removed in 9.0.
(Bruno Roustant, David Smiley, Alan Woodward)
* LUCENE-8990: PointValues#estimateDocCount(visitor) estimates the number of documents that would be matched by
the given IntersectVisitor. The method is used to compute the cost() of ScorerSuppliers instead of
PointValues#estimatePointCount(visitor). (Ignacio Vera, Adrien Grand)
New Features
* LUCENE-8936: Add SpanishMinimalStemFilter (vinod kumar via Tomoko Uchida)

View File

@ -177,7 +177,7 @@ final class LatLonPointDistanceQuery extends Query {
@Override
public long cost() {
if (cost == -1) {
cost = values.estimatePointCount(visitor);
cost = values.estimateDocCount(visitor);
}
assert cost >= 0;
return cost;

View File

@ -190,7 +190,7 @@ final class LatLonPointInPolygonQuery extends Query {
public long cost() {
if (cost == -1) {
// Computing the cost may be expensive, so only do it if necessary
cost = values.estimatePointCount(visitor);
cost = values.estimateDocCount(visitor);
assert cost >= 0;
}
return cost;

View File

@ -361,7 +361,7 @@ abstract class RangeFieldQuery extends Query {
public long cost() {
if (cost == -1) {
// Computing the cost may be expensive, so only do it if necessary
cost = values.estimatePointCount(visitor);
cost = values.estimateDocCount(visitor);
assert cost >= 0;
}
return cost;

View File

@ -232,10 +232,36 @@ public abstract class PointValues {
public abstract void intersect(IntersectVisitor visitor) throws IOException;
/** Estimate the number of points that would be visited by {@link #intersect}
* with the given {@link IntersectVisitor}. This should run many times faster
* than {@link #intersect(IntersectVisitor)}. */
public abstract long estimatePointCount(IntersectVisitor visitor);
/** Estimate the number of documents that would be matched by {@link #intersect}
* with the given {@link IntersectVisitor}. This should run many times faster
* than {@link #intersect(IntersectVisitor)}.
* @see DocIdSetIterator#cost */
public abstract long estimatePointCount(IntersectVisitor visitor);
/**
 * Estimates how many documents would be matched by the given visitor, derived from
 * {@link #estimatePointCount(IntersectVisitor)}, {@link #getDocCount()} and {@link #size()}.
 *
 * @param visitor the visitor whose matches are being estimated; it is only handed on to
 *                {@link #estimatePointCount(IntersectVisitor)}
 * @return the estimated number of matching documents; at least 1 whenever the point
 *         estimate is non-zero
 */
public long estimateDocCount(IntersectVisitor visitor) {
  final long pointEstimate = estimatePointCount(visitor);
  final int totalDocs = getDocCount();
  // kept as a double so that the ratios below are computed in floating point
  final double totalPoints = size();

  if (pointEstimate >= totalPoints) {
    // the estimate covers every point, so match all docs
    return totalDocs;
  }
  if (totalPoints == totalDocs || pointEstimate == 0L) {
    // nothing matches, or there is exactly one point per doc:
    // the point estimate already is the doc estimate
    return pointEstimate;
  }

  // Multi-valued case. Uses the urn-problem solution described in
  // https://math.stackexchange.com/questions/1175295/urn-problem-probability-of-drawing-balls-of-k-unique-colors
  // approximated for points per doc << size(), which gives
  //   D * (1 - ((N - n) / N)^(N/D))
  // with D the total number of docs, N the total number of points and n the estimated point count.
  final double unmatchedFraction = (totalPoints - pointEstimate) / totalPoints;
  final double pointsPerDoc = totalPoints / totalDocs;
  final long docEstimate = (long) (totalDocs * (1d - Math.pow(unmatchedFraction, pointsPerDoc)));
  if (docEstimate == 0L) {
    // a non-zero point estimate must not be rounded down to "no documents"
    return 1L;
  }
  return docEstimate;
}
/** Returns minimum value for each dimension, packed, or null if {@link #size} is <code>0</code> */
public abstract byte[] getMinPackedValue() throws IOException;

View File

@ -316,7 +316,7 @@ public abstract class PointRangeQuery extends Query {
public long cost() {
if (cost == -1) {
// Computing the cost may be expensive, so only do it if necessary
cost = values.estimatePointCount(visitor);
cost = values.estimateDocCount(visitor);
assert cost >= 0;
}
return cost;

View File

@ -109,15 +109,19 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
byte[] uniquePointValue = new byte[3];
random().nextBytes(uniquePointValue);
final int numDocs = atLeast(10000); // make sure we have several leaves
final boolean multiValues = random().nextBoolean();
for (int i = 0; i < numDocs; ++i) {
Document doc = new Document();
if (i == numDocs / 2) {
doc.add(new BinaryPoint("f", uniquePointValue));
} else {
do {
random().nextBytes(pointValue);
} while (Arrays.equals(pointValue, uniquePointValue));
doc.add(new BinaryPoint("f", pointValue));
final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
for (int j = 0; j < numValues; j ++) {
do {
random().nextBytes(pointValue);
} while (Arrays.equals(pointValue, uniquePointValue));
doc.add(new BinaryPoint("f", pointValue));
}
}
w.addDocument(doc);
}
@ -128,58 +132,72 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
PointValues points = lr.getPointValues("f");
// If all points match, then the point count is numLeaves * maxPointsInLeafNode
final int numLeaves = (int) Math.ceil((double) numDocs / maxPointsInLeafNode);
assertEquals(numLeaves * maxPointsInLeafNode,
points.estimatePointCount(new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return Relation.CELL_INSIDE_QUERY;
}
}));
final int numLeaves = (int) Math.ceil((double) points.size() / maxPointsInLeafNode);
IntersectVisitor allPointsVisitor = new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return Relation.CELL_INSIDE_QUERY;
}
};
assertEquals(numLeaves * maxPointsInLeafNode, points.estimatePointCount(allPointsVisitor));
assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
IntersectVisitor noPointsVisitor = new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return Relation.CELL_OUTSIDE_QUERY;
}
};
// Return 0 if no points match
assertEquals(0,
points.estimatePointCount(new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return Relation.CELL_OUTSIDE_QUERY;
}
}));
assertEquals(0, points.estimatePointCount(noPointsVisitor));
assertEquals(0, points.estimateDocCount(noPointsVisitor));
IntersectVisitor onePointMatchVisitor = new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
if (Arrays.compareUnsigned(uniquePointValue, 0, 3, maxPackedValue, 0, 3) > 0 ||
Arrays.compareUnsigned(uniquePointValue, 0, 3, minPackedValue, 0, 3) < 0) {
return Relation.CELL_OUTSIDE_QUERY;
}
return Relation.CELL_CROSSES_QUERY;
}
};
// If only one point matches, then the point count is (maxPointsInLeafNode + 1) / 2
// in general, or maybe 2x that if the point is a split value
final long pointCount = points.estimatePointCount(new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
if (Arrays.compareUnsigned(uniquePointValue, 0, 3, maxPackedValue, 0, 3) > 0 ||
Arrays.compareUnsigned(uniquePointValue, 0, 3, minPackedValue, 0, 3) < 0) {
return Relation.CELL_OUTSIDE_QUERY;
}
return Relation.CELL_CROSSES_QUERY;
}
});
final long pointCount = points.estimatePointCount(onePointMatchVisitor);
assertTrue(""+pointCount,
pointCount == (maxPointsInLeafNode + 1) / 2 || // common case
pointCount == 2*((maxPointsInLeafNode + 1) / 2)); // if the point is a split value
final long docCount = points.estimateDocCount(onePointMatchVisitor);
if (multiValues) {
assertEquals(docCount, (long) (docCount * (1d - Math.pow( (numDocs - pointCount) / points.size() , points.size() / docCount))));
} else {
assertEquals(pointCount, docCount);
}
r.close();
dir.close();
}
@ -198,16 +216,20 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
random().nextBytes(uniquePointValue[0]);
random().nextBytes(uniquePointValue[1]);
final int numDocs = atLeast(10000); // make sure we have several leaves
final boolean multiValues = random().nextBoolean();
for (int i = 0; i < numDocs; ++i) {
Document doc = new Document();
if (i == numDocs / 2) {
doc.add(new BinaryPoint("f", uniquePointValue));
} else {
do {
random().nextBytes(pointValue[0]);
random().nextBytes(pointValue[1]);
} while (Arrays.equals(pointValue[0], uniquePointValue[0]) || Arrays.equals(pointValue[1], uniquePointValue[1]));
doc.add(new BinaryPoint("f", pointValue));
final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
for (int j = 0; j < numValues; j ++) {
do {
random().nextBytes(pointValue[0]);
random().nextBytes(pointValue[1]);
} while (Arrays.equals(pointValue[0], uniquePointValue[0]) || Arrays.equals(pointValue[1], uniquePointValue[1]));
doc.add(new BinaryPoint("f", pointValue));
}
}
w.addDocument(doc);
}
@ -218,67 +240,161 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
PointValues points = lr.getPointValues("f");
// With >1 dims, the tree is balanced
int actualMaxPointsInLeafNode = numDocs;
long actualMaxPointsInLeafNode = points.size();
while (actualMaxPointsInLeafNode > maxPointsInLeafNode) {
actualMaxPointsInLeafNode = (actualMaxPointsInLeafNode + 1) / 2;
}
IntersectVisitor allPointsVisitor = new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return Relation.CELL_INSIDE_QUERY;
}
};
// If all points match, then the point count is numLeaves * maxPointsInLeafNode
final int numLeaves = Integer.highestOneBit((numDocs - 1) / actualMaxPointsInLeafNode) << 1;
assertEquals(numLeaves * actualMaxPointsInLeafNode,
points.estimatePointCount(new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return Relation.CELL_INSIDE_QUERY;
}
}));
final int numLeaves = (int) Long.highestOneBit( ((points.size() - 1) / actualMaxPointsInLeafNode)) << 1;
assertEquals(numLeaves * actualMaxPointsInLeafNode, points.estimatePointCount(allPointsVisitor));
assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
IntersectVisitor noPointsVisitor = new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return Relation.CELL_OUTSIDE_QUERY;
}
};
// Return 0 if no points match
assertEquals(0,
points.estimatePointCount(new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
assertEquals(0, points.estimatePointCount(noPointsVisitor));
assertEquals(0, points.estimateDocCount(noPointsVisitor));
IntersectVisitor onePointMatchVisitor = new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
for (int dim = 0; dim < 2; ++dim) {
if (Arrays.compareUnsigned(uniquePointValue[dim], 0, 3, maxPackedValue, dim * 3, dim * 3 + 3) > 0 ||
Arrays.compareUnsigned(uniquePointValue[dim], 0, 3, minPackedValue, dim * 3, dim * 3 + 3) < 0) {
return Relation.CELL_OUTSIDE_QUERY;
}
}));
}
return Relation.CELL_CROSSES_QUERY;
}
};
// If only one point matches, then the point count is (actualMaxPointsInLeafNode + 1) / 2
// in general, or maybe 2x that if the point is a split value
final long pointCount = points.estimatePointCount(new IntersectVisitor() {
@Override
public void visit(int docID, byte[] packedValue) throws IOException {}
@Override
public void visit(int docID) throws IOException {}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
for (int dim = 0; dim < 2; ++dim) {
if (Arrays.compareUnsigned(uniquePointValue[dim], 0, 3, maxPackedValue, dim * 3, dim * 3 + 3) > 0 ||
Arrays.compareUnsigned(uniquePointValue[dim], 0, 3, minPackedValue, dim * 3, dim * 3 + 3) < 0) {
return Relation.CELL_OUTSIDE_QUERY;
}
}
return Relation.CELL_CROSSES_QUERY;
}
});
final long pointCount = points.estimatePointCount(onePointMatchVisitor);
assertTrue(""+pointCount,
pointCount == (actualMaxPointsInLeafNode + 1) / 2 || // common case
pointCount == 2*((actualMaxPointsInLeafNode + 1) / 2)); // if the point is a split value
final long docCount = points.estimateDocCount(onePointMatchVisitor);
if (multiValues) {
assertEquals(docCount, (long) (docCount * (1d - Math.pow( (numDocs - pointCount) / points.size() , points.size() / docCount))));
} else {
assertEquals(pointCount, docCount);
}
r.close();
dir.close();
}
/**
 * Exercises {@code estimateDocCount} at the extremes of its inputs. The stub arguments are
 * (size, docCount, estimatedPointCount); the visitor is never consulted, so {@code null} is passed.
 */
public void testDocCountEdgeCases() {
  // single doc, estimate equal to the total number of points
  assertEquals(1, getPointValues(Long.MAX_VALUE, 1, Long.MAX_VALUE).estimateDocCount(null));

  // single doc, estimate of a single point
  assertEquals(1, getPointValues(Long.MAX_VALUE, 1, 1).estimateDocCount(null));

  // maximum doc count, estimate equal to the total number of points
  assertEquals(Integer.MAX_VALUE,
      getPointValues(Long.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE).estimateDocCount(null));

  // maximum doc count, estimate of half the points
  assertEquals(Integer.MAX_VALUE,
      getPointValues(Long.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE / 2).estimateDocCount(null));

  // maximum doc count, estimate of a single point
  assertEquals(1, getPointValues(Long.MAX_VALUE, Integer.MAX_VALUE, 1).estimateDocCount(null));
}
/**
 * Feeds {@code estimateDocCount} random (size, docCount, estimatedPointCount) triples and checks
 * that the result stays within simple lower and upper bounds.
 */
public void testRandomDocCount() {
  for (int iter = 0; iter < 100; iter++) {
    final long totalPoints = TestUtil.nextLong(random(), 1, Long.MAX_VALUE);
    // cap the doc count at both Integer.MAX_VALUE and the number of points
    final int docUpperBound = Math.toIntExact(Math.min(totalPoints, Integer.MAX_VALUE));
    final int totalDocs = TestUtil.nextInt(random(), 1, docUpperBound);
    final long pointEstimate = TestUtil.nextLong(random(), 0, totalPoints);

    final long docEstimate =
        getPointValues(totalPoints, totalDocs, pointEstimate).estimateDocCount(null);

    // integer average of points per document, used for the lower bound
    final long avgPointsPerDoc = totalPoints / totalDocs;

    assertTrue(docEstimate <= pointEstimate);
    assertTrue(docEstimate <= docUpperBound);
    assertTrue(docEstimate >= pointEstimate / avgPointsPerDoc);
  }
}
/**
 * Builds a stub {@link PointValues} that reports fixed statistics, so that
 * {@code estimateDocCount} can be tested without an index.
 *
 * @param size                value returned by {@code size()}
 * @param docCount            value returned by {@code getDocCount()}
 * @param estimatedPointCount value returned by {@code estimatePointCount}, whatever the visitor
 * @return a stub whose remaining methods all throw {@link UnsupportedOperationException}
 */
private PointValues getPointValues(long size, int docCount, long estimatedPointCount) {
  return new PointValues() {

    // --- the three canned statistics ---

    @Override
    public long size() {
      return size;
    }

    @Override
    public int getDocCount() {
      return docCount;
    }

    @Override
    public long estimatePointCount(IntersectVisitor visitor) {
      return estimatedPointCount;
    }

    // --- everything else is unsupported by this stub ---

    @Override
    public void intersect(IntersectVisitor visitor) {
      throw new UnsupportedOperationException();
    }

    @Override
    public byte[] getMinPackedValue() throws IOException {
      throw new UnsupportedOperationException();
    }

    @Override
    public byte[] getMaxPackedValue() throws IOException {
      throw new UnsupportedOperationException();
    }

    @Override
    public int getNumDataDimensions() throws IOException {
      throw new UnsupportedOperationException();
    }

    @Override
    public int getNumIndexDimensions() throws IOException {
      throw new UnsupportedOperationException();
    }

    @Override
    public int getBytesPerDimension() throws IOException {
      throw new UnsupportedOperationException();
    }
  };
}
}

View File

@ -22,6 +22,7 @@ import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
@ -86,4 +87,70 @@ public class TestIndexOrDocValuesQuery extends LuceneTestCase {
dir.close();
}
/**
 * Checks which side of an {@link IndexOrDocValuesQuery} is chosen when the point field holds
 * many values per document.
 *
 * <p>Index layout (2000 docs, force-merged into one segment):
 * <ul>
 *   <li>docs 0-999: f1=bar, 500 points and doc values of 42 each</li>
 *   <li>doc 1001: f1=foo, one point of 2 and one doc value of 42</li>
 *   <li>all remaining docs: f1=bar, 100 points and doc values of 2 each</li>
 * </ul>
 *
 * <p>The directory, writer and reader are held in try-with-resources so they are released
 * even when an assertion fails.
 *
 * @throws IOException if indexing or searching fails
 */
public void testUseIndexForSelectiveMultiValueQueries() throws IOException {
  try (Directory dir = newDirectory();
       IndexWriter w = new IndexWriter(dir, newIndexWriterConfig()
           // relies on costs and PointValues.estimateDocCount so we need the default codec
           .setCodec(TestUtil.getDefaultCodec()))) {
    for (int i = 0; i < 2000; ++i) {
      Document doc = new Document();
      if (i < 1000) {
        doc.add(new StringField("f1", "bar", Store.NO));
        for (int j = 0; j < 500; j++) {
          doc.add(new LongPoint("f2", 42L));
          doc.add(new SortedNumericDocValuesField("f2", 42L));
        }
      } else if (i == 1001) {
        doc.add(new StringField("f1", "foo", Store.NO));
        doc.add(new LongPoint("f2", 2L));
        doc.add(new SortedNumericDocValuesField("f2", 42L));
      } else {
        doc.add(new StringField("f1", "bar", Store.NO));
        for (int j = 0; j < 100; j++) {
          doc.add(new LongPoint("f2", 2L));
          doc.add(new SortedNumericDocValuesField("f2", 2L));
        }
      }
      w.addDocument(doc);
    }
    w.forceMerge(1);

    try (IndexReader reader = DirectoryReader.open(w)) {
      IndexSearcher searcher = newSearcher(reader);
      // disable caching so the scorer reflects the cost-based choice
      searcher.setQueryCache(null);

      // f1:bar is less selective than f2 == 2, so the IndexOrDocValuesQuery should use points
      final Query q1 = new BooleanQuery.Builder()
          .add(new TermQuery(new Term("f1", "bar")), Occur.MUST)
          .add(new IndexOrDocValuesQuery(LongPoint.newExactQuery("f2", 2), SortedNumericDocValuesField.newSlowRangeQuery("f2", 2L, 2L)), Occur.MUST)
          .build();

      final Weight w1 = searcher.createWeight(searcher.rewrite(q1), ScoreMode.COMPLETE, 1);
      final Scorer s1 = w1.scorer(searcher.getIndexReader().leaves().get(0));
      assertNull(s1.twoPhaseIterator()); // means we use points

      // f1:bar is also less selective than f2 == 42 when the cost is counted in documents
      // (the value 42 has 500 points per matching doc), so points should be used again
      final Query q2 = new BooleanQuery.Builder()
          .add(new TermQuery(new Term("f1", "bar")), Occur.MUST)
          .add(new IndexOrDocValuesQuery(LongPoint.newExactQuery("f2", 42), SortedNumericDocValuesField.newSlowRangeQuery("f2", 42, 42L)), Occur.MUST)
          .build();

      final Weight w2 = searcher.createWeight(searcher.rewrite(q2), ScoreMode.COMPLETE, 1);
      final Scorer s2 = w2.scorer(searcher.getIndexReader().leaves().get(0));
      assertNull(s2.twoPhaseIterator()); // means we use points

      // f1:foo is more selective, so the IndexOrDocValuesQuery should use doc values
      final Query q3 = new BooleanQuery.Builder()
          .add(new TermQuery(new Term("f1", "foo")), Occur.MUST)
          .add(new IndexOrDocValuesQuery(LongPoint.newExactQuery("f2", 42), SortedNumericDocValuesField.newSlowRangeQuery("f2", 42, 42L)), Occur.MUST)
          .build();

      final Weight w3 = searcher.createWeight(searcher.rewrite(q3), ScoreMode.COMPLETE, 1);
      final Scorer s3 = w3.scorer(searcher.getIndexReader().leaves().get(0));
      assertNotNull(s3.twoPhaseIterator()); // means we use doc values
    }
  }
}
}

View File

@ -283,7 +283,7 @@ abstract class ShapeQuery extends Query {
public long cost() {
if (cost == -1) {
// Computing the cost may be expensive, so only do it if necessary
cost = values.estimatePointCount(getEstimateVisitor(query));
cost = values.estimateDocCount(getEstimateVisitor(query));
assert cost >= 0;
}
return cost;

View File

@ -278,7 +278,7 @@ public abstract class MultiRangeQuery extends Query {
public long cost() {
if (cost == -1) {
// Computing the cost may be expensive, so only do it if necessary
cost = values.estimatePointCount(visitor) * rangeClauses.size();
cost = values.estimateDocCount(visitor) * rangeClauses.size();
assert cost >= 0;
}
return cost;