mirror of https://github.com/apache/lucene.git
Pruning of estimating the point value count since BooleanScorerSupplier
This commit is contained in:
parent
12ca4779b9
commit
939893227c
|
@ -20,6 +20,8 @@ import static org.apache.lucene.geo.GeoEncodingUtils.decodeLatitude;
|
||||||
import static org.apache.lucene.geo.GeoEncodingUtils.decodeLongitude;
|
import static org.apache.lucene.geo.GeoEncodingUtils.decodeLongitude;
|
||||||
import static org.apache.lucene.geo.GeoEncodingUtils.encodeLatitude;
|
import static org.apache.lucene.geo.GeoEncodingUtils.encodeLatitude;
|
||||||
import static org.apache.lucene.geo.GeoEncodingUtils.encodeLongitude;
|
import static org.apache.lucene.geo.GeoEncodingUtils.encodeLongitude;
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO;
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.lucene.geo.GeoEncodingUtils;
|
import org.apache.lucene.geo.GeoEncodingUtils;
|
||||||
|
@ -40,6 +42,7 @@ import org.apache.lucene.search.QueryVisitor;
|
||||||
import org.apache.lucene.search.ScoreMode;
|
import org.apache.lucene.search.ScoreMode;
|
||||||
import org.apache.lucene.search.Scorer;
|
import org.apache.lucene.search.Scorer;
|
||||||
import org.apache.lucene.search.ScorerSupplier;
|
import org.apache.lucene.search.ScorerSupplier;
|
||||||
|
import org.apache.lucene.search.TotalHits;
|
||||||
import org.apache.lucene.search.Weight;
|
import org.apache.lucene.search.Weight;
|
||||||
import org.apache.lucene.util.BitSetIterator;
|
import org.apache.lucene.util.BitSetIterator;
|
||||||
import org.apache.lucene.util.DocIdSetBuilder;
|
import org.apache.lucene.util.DocIdSetBuilder;
|
||||||
|
@ -139,7 +142,7 @@ final class LatLonPointDistanceQuery extends Query {
|
||||||
|
|
||||||
return new ScorerSupplier() {
|
return new ScorerSupplier() {
|
||||||
|
|
||||||
long cost = -1;
|
TotalHits estimatedCount;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Scorer get(long leadCost) throws IOException {
|
public Scorer get(long leadCost) throws IOException {
|
||||||
|
@ -162,11 +165,28 @@ final class LatLonPointDistanceQuery extends Query {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long cost() {
|
public long cost() {
|
||||||
if (cost == -1) {
|
if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) {
|
||||||
cost = values.estimateDocCount(visitor);
|
estimatedCount =
|
||||||
|
new TotalHits(values.estimateDocCount(visitor, Long.MAX_VALUE), EQUAL_TO);
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
}
|
}
|
||||||
assert cost >= 0;
|
return estimatedCount.value();
|
||||||
return cost;
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) {
|
||||||
|
if (estimatedCount == null
|
||||||
|
|| (estimatedCount.value() < upperBound
|
||||||
|
&& estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) {
|
||||||
|
long cost = values.estimateDocCount(visitor, Long.MAX_VALUE);
|
||||||
|
if (cost < upperBound) {
|
||||||
|
estimatedCount = new TotalHits(cost, EQUAL_TO);
|
||||||
|
} else if (estimatedCount == null || cost > estimatedCount.value()) {
|
||||||
|
estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO);
|
||||||
|
}
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
|
}
|
||||||
|
return estimatedCount;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.document;
|
package org.apache.lucene.document;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO;
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
@ -34,6 +37,7 @@ import org.apache.lucene.search.QueryVisitor;
|
||||||
import org.apache.lucene.search.ScoreMode;
|
import org.apache.lucene.search.ScoreMode;
|
||||||
import org.apache.lucene.search.Scorer;
|
import org.apache.lucene.search.Scorer;
|
||||||
import org.apache.lucene.search.ScorerSupplier;
|
import org.apache.lucene.search.ScorerSupplier;
|
||||||
|
import org.apache.lucene.search.TotalHits;
|
||||||
import org.apache.lucene.search.Weight;
|
import org.apache.lucene.search.Weight;
|
||||||
import org.apache.lucene.util.ArrayUtil;
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
import org.apache.lucene.util.ArrayUtil.ByteArrayComparator;
|
import org.apache.lucene.util.ArrayUtil.ByteArrayComparator;
|
||||||
|
@ -477,7 +481,7 @@ public abstract class RangeFieldQuery extends Query {
|
||||||
|
|
||||||
final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
||||||
final IntersectVisitor visitor = getIntersectVisitor(result);
|
final IntersectVisitor visitor = getIntersectVisitor(result);
|
||||||
long cost = -1;
|
TotalHits estimatedCount = null;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Scorer get(long leadCost) throws IOException {
|
public Scorer get(long leadCost) throws IOException {
|
||||||
|
@ -488,12 +492,29 @@ public abstract class RangeFieldQuery extends Query {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long cost() {
|
public long cost() {
|
||||||
if (cost == -1) {
|
if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) {
|
||||||
// Computing the cost may be expensive, so only do it if necessary
|
estimatedCount =
|
||||||
cost = values.estimateDocCount(visitor);
|
new TotalHits(values.estimateDocCount(visitor, Long.MAX_VALUE), EQUAL_TO);
|
||||||
assert cost >= 0;
|
assert estimatedCount.value() >= 0;
|
||||||
}
|
}
|
||||||
return cost;
|
return estimatedCount.value();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) {
|
||||||
|
if (estimatedCount == null
|
||||||
|
|| (estimatedCount.value() < upperBound
|
||||||
|
&& estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) {
|
||||||
|
// Computing the cost may be expensive, so only do it if necessary
|
||||||
|
long cost = values.estimateDocCount(visitor, upperBound);
|
||||||
|
if (cost < upperBound) {
|
||||||
|
estimatedCount = new TotalHits(cost, EQUAL_TO);
|
||||||
|
} else if (estimatedCount == null || cost > estimatedCount.value()) {
|
||||||
|
estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO);
|
||||||
|
}
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
|
}
|
||||||
|
return estimatedCount;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.document;
|
package org.apache.lucene.document;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO;
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import org.apache.lucene.geo.Component2D;
|
import org.apache.lucene.geo.Component2D;
|
||||||
|
@ -36,6 +39,7 @@ import org.apache.lucene.search.QueryVisitor;
|
||||||
import org.apache.lucene.search.ScoreMode;
|
import org.apache.lucene.search.ScoreMode;
|
||||||
import org.apache.lucene.search.Scorer;
|
import org.apache.lucene.search.Scorer;
|
||||||
import org.apache.lucene.search.ScorerSupplier;
|
import org.apache.lucene.search.ScorerSupplier;
|
||||||
|
import org.apache.lucene.search.TotalHits;
|
||||||
import org.apache.lucene.search.Weight;
|
import org.apache.lucene.search.Weight;
|
||||||
import org.apache.lucene.util.DocIdSetBuilder;
|
import org.apache.lucene.util.DocIdSetBuilder;
|
||||||
|
|
||||||
|
@ -144,7 +148,7 @@ final class XYPointInGeometryQuery extends Query {
|
||||||
|
|
||||||
return new ScorerSupplier() {
|
return new ScorerSupplier() {
|
||||||
|
|
||||||
long cost = -1;
|
TotalHits estimatedCount;
|
||||||
DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
||||||
final IntersectVisitor visitor = getIntersectVisitor(result, tree);
|
final IntersectVisitor visitor = getIntersectVisitor(result, tree);
|
||||||
|
|
||||||
|
@ -156,12 +160,29 @@ final class XYPointInGeometryQuery extends Query {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long cost() {
|
public long cost() {
|
||||||
if (cost == -1) {
|
if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) {
|
||||||
// Computing the cost may be expensive, so only do it if necessary
|
// Computing the cost may be expensive, so only do it if necessary
|
||||||
cost = values.estimateDocCount(visitor);
|
estimatedCount =
|
||||||
assert cost >= 0;
|
new TotalHits(values.estimateDocCount(visitor, Long.MAX_VALUE), EQUAL_TO);
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
}
|
}
|
||||||
return cost;
|
return estimatedCount.value();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) {
|
||||||
|
if (estimatedCount == null
|
||||||
|
|| (estimatedCount.value() < upperBound
|
||||||
|
&& estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) {
|
||||||
|
long cost = values.estimateDocCount(visitor, upperBound);
|
||||||
|
if (cost < upperBound) {
|
||||||
|
estimatedCount = new TotalHits(cost, EQUAL_TO);
|
||||||
|
} else if (estimatedCount == null || cost > estimatedCount.value()) {
|
||||||
|
estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO);
|
||||||
|
}
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
|
}
|
||||||
|
return estimatedCount;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -385,9 +385,17 @@ public abstract class PointValues {
|
||||||
* IntersectVisitor}. This should run many times faster than {@link #intersect(IntersectVisitor)}.
|
* IntersectVisitor}. This should run many times faster than {@link #intersect(IntersectVisitor)}.
|
||||||
*/
|
*/
|
||||||
public final long estimatePointCount(IntersectVisitor visitor) {
|
public final long estimatePointCount(IntersectVisitor visitor) {
|
||||||
|
return estimatePointCount(visitor, Long.MAX_VALUE);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Estimate the number of points within the given {@link IntersectVisitor} and a maximum of
|
||||||
|
* {upperBound}
|
||||||
|
*/
|
||||||
|
public final long estimatePointCount(IntersectVisitor visitor, long upperBound) {
|
||||||
try {
|
try {
|
||||||
final PointTree pointTree = getPointTree();
|
final PointTree pointTree = getPointTree();
|
||||||
final long count = estimatePointCount(visitor, pointTree, Long.MAX_VALUE);
|
final long count = estimatePointCount(visitor, pointTree, upperBound);
|
||||||
assert pointTree.moveToParent() == false;
|
assert pointTree.moveToParent() == false;
|
||||||
return count;
|
return count;
|
||||||
} catch (IOException ioe) {
|
} catch (IOException ioe) {
|
||||||
|
@ -449,7 +457,15 @@ public abstract class PointValues {
|
||||||
* @see DocIdSetIterator#cost
|
* @see DocIdSetIterator#cost
|
||||||
*/
|
*/
|
||||||
public final long estimateDocCount(IntersectVisitor visitor) {
|
public final long estimateDocCount(IntersectVisitor visitor) {
|
||||||
long estimatedPointCount = estimatePointCount(visitor);
|
return estimateDocCount(visitor, Long.MAX_VALUE);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Estimate the number of documents that would be matched by {@link #intersect} with the given
|
||||||
|
* {upperBound}
|
||||||
|
*/
|
||||||
|
public final long estimateDocCount(IntersectVisitor visitor, long upperBound) {
|
||||||
|
long estimatedPointCount = estimatePointCount(visitor, upperBound);
|
||||||
int docCount = getDocCount();
|
int docCount = getDocCount();
|
||||||
double size = size();
|
double size = size();
|
||||||
if (estimatedPointCount >= size) {
|
if (estimatedPointCount >= size) {
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.search;
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO;
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -24,7 +27,7 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.OptionalLong;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
import org.apache.lucene.search.BooleanClause.Occur;
|
import org.apache.lucene.search.BooleanClause.Occur;
|
||||||
import org.apache.lucene.search.Weight.DefaultBulkScorer;
|
import org.apache.lucene.search.Weight.DefaultBulkScorer;
|
||||||
|
@ -35,7 +38,7 @@ final class BooleanScorerSupplier extends ScorerSupplier {
|
||||||
private final ScoreMode scoreMode;
|
private final ScoreMode scoreMode;
|
||||||
private final int minShouldMatch;
|
private final int minShouldMatch;
|
||||||
private final int maxDoc;
|
private final int maxDoc;
|
||||||
private long cost = -1;
|
private TotalHits estimatedCount = null;
|
||||||
private boolean topLevelScoringClause;
|
private boolean topLevelScoringClause;
|
||||||
|
|
||||||
BooleanScorerSupplier(
|
BooleanScorerSupplier(
|
||||||
|
@ -69,21 +72,40 @@ final class BooleanScorerSupplier extends ScorerSupplier {
|
||||||
this.maxDoc = maxDoc;
|
this.maxDoc = maxDoc;
|
||||||
}
|
}
|
||||||
|
|
||||||
private long computeCost() {
|
private TotalHits computeCost(long upperBound) {
|
||||||
OptionalLong minRequiredCost =
|
|
||||||
|
TotalHits minRequiredCost = null;
|
||||||
|
TotalHits totalHits = null;
|
||||||
|
for (ScorerSupplier scorerSupplier :
|
||||||
Stream.concat(subs.get(Occur.MUST).stream(), subs.get(Occur.FILTER).stream())
|
Stream.concat(subs.get(Occur.MUST).stream(), subs.get(Occur.FILTER).stream())
|
||||||
.mapToLong(ScorerSupplier::cost)
|
.collect(Collectors.toList())) {
|
||||||
.min();
|
totalHits = scorerSupplier.isEstimatedPointCountGreaterThanOrEqualTo(upperBound);
|
||||||
if (minRequiredCost.isPresent() && minShouldMatch == 0) {
|
if (totalHits.relation() == EQUAL_TO && totalHits.value() < upperBound) {
|
||||||
return minRequiredCost.getAsLong();
|
upperBound = totalHits.value();
|
||||||
|
minRequiredCost = totalHits;
|
||||||
|
} else if (minRequiredCost == null) {
|
||||||
|
minRequiredCost = totalHits;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (minRequiredCost != null && minShouldMatch == 0) {
|
||||||
|
return minRequiredCost;
|
||||||
} else {
|
} else {
|
||||||
final Collection<ScorerSupplier> optionalScorers = subs.get(Occur.SHOULD);
|
final Collection<ScorerSupplier> optionalScorers = subs.get(Occur.SHOULD);
|
||||||
final long shouldCost =
|
final TotalHits shouldCost =
|
||||||
ScorerUtil.costWithMinShouldMatch(
|
ScorerUtil.costWithMinShouldMatch(
|
||||||
optionalScorers.stream().mapToLong(ScorerSupplier::cost),
|
optionalScorers, optionalScorers.size(), minShouldMatch, upperBound);
|
||||||
optionalScorers.size(),
|
|
||||||
minShouldMatch);
|
if (shouldCost.relation() == EQUAL_TO) {
|
||||||
return Math.min(minRequiredCost.orElse(Long.MAX_VALUE), shouldCost);
|
return shouldCost;
|
||||||
|
} else if (minRequiredCost != null && minRequiredCost.relation() == EQUAL_TO) {
|
||||||
|
return minRequiredCost;
|
||||||
|
} else if (minRequiredCost != null) {
|
||||||
|
// or we should return small one? it doesn't matter
|
||||||
|
return (shouldCost.value() > minRequiredCost.value() ? shouldCost : minRequiredCost);
|
||||||
|
} else {
|
||||||
|
return shouldCost;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,10 +125,22 @@ final class BooleanScorerSupplier extends ScorerSupplier {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long cost() {
|
public long cost() {
|
||||||
if (cost == -1) {
|
if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) {
|
||||||
cost = computeCost();
|
estimatedCount = computeCost(Long.MAX_VALUE);
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
}
|
}
|
||||||
return cost;
|
return estimatedCount.value();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) {
|
||||||
|
if (estimatedCount == null
|
||||||
|
|| (estimatedCount.value() < upperBound
|
||||||
|
&& estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) {
|
||||||
|
estimatedCount = computeCost(upperBound);
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
|
}
|
||||||
|
return estimatedCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -126,7 +160,10 @@ final class BooleanScorerSupplier extends ScorerSupplier {
|
||||||
|
|
||||||
private Scorer getInternal(long leadCost) throws IOException {
|
private Scorer getInternal(long leadCost) throws IOException {
|
||||||
// three cases: conjunction, disjunction, or mix
|
// three cases: conjunction, disjunction, or mix
|
||||||
leadCost = Math.min(leadCost, cost());
|
estimatedCount = isEstimatedPointCountGreaterThanOrEqualTo(leadCost);
|
||||||
|
if (estimatedCount.relation() == EQUAL_TO && estimatedCount.value() < leadCost) {
|
||||||
|
leadCost = estimatedCount.value();
|
||||||
|
}
|
||||||
|
|
||||||
// pure conjunction
|
// pure conjunction
|
||||||
if (subs.get(Occur.SHOULD).isEmpty()) {
|
if (subs.get(Occur.SHOULD).isEmpty()) {
|
||||||
|
@ -202,10 +239,11 @@ final class BooleanScorerSupplier extends ScorerSupplier {
|
||||||
// there will be no matches in the end) so we should only use
|
// there will be no matches in the end) so we should only use
|
||||||
// BooleanScorer if matches are very dense
|
// BooleanScorer if matches are very dense
|
||||||
costThreshold = maxDoc / 3;
|
costThreshold = maxDoc / 3;
|
||||||
}
|
|
||||||
|
|
||||||
if (cost() < costThreshold) {
|
TotalHits estimatedCount = isEstimatedPointCountGreaterThanOrEqualTo(costThreshold);
|
||||||
return null;
|
if (estimatedCount.relation() == EQUAL_TO) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
positiveScorer = optionalBulkScorer();
|
positiveScorer = optionalBulkScorer();
|
||||||
|
@ -315,10 +353,16 @@ final class BooleanScorerSupplier extends ScorerSupplier {
|
||||||
return scorer;
|
return scorer;
|
||||||
}
|
}
|
||||||
|
|
||||||
long leadCost =
|
long leadCost = Long.MAX_VALUE;
|
||||||
subs.get(Occur.MUST).stream().mapToLong(ScorerSupplier::cost).min().orElse(Long.MAX_VALUE);
|
TotalHits estimatedCount;
|
||||||
leadCost =
|
for (ScorerSupplier scorerSupplier :
|
||||||
subs.get(Occur.FILTER).stream().mapToLong(ScorerSupplier::cost).min().orElse(leadCost);
|
Stream.concat(subs.get(Occur.MUST).stream(), subs.get(Occur.FILTER).stream())
|
||||||
|
.collect(Collectors.toList())) {
|
||||||
|
estimatedCount = scorerSupplier.isEstimatedPointCountGreaterThanOrEqualTo(leadCost);
|
||||||
|
if (estimatedCount.relation() == EQUAL_TO && estimatedCount.value() < leadCost) {
|
||||||
|
leadCost = estimatedCount.value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
List<Scorer> requiredNoScoring = new ArrayList<>();
|
List<Scorer> requiredNoScoring = new ArrayList<>();
|
||||||
for (ScorerSupplier ss : subs.get(Occur.FILTER)) {
|
for (ScorerSupplier ss : subs.get(Occur.FILTER)) {
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.search;
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO;
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.UncheckedIOException;
|
import java.io.UncheckedIOException;
|
||||||
import java.util.AbstractCollection;
|
import java.util.AbstractCollection;
|
||||||
|
@ -176,7 +179,7 @@ public abstract class PointInSetQuery extends Query implements Accountable {
|
||||||
// We optimize this common case, effectively doing a merge sort of the indexed values vs
|
// We optimize this common case, effectively doing a merge sort of the indexed values vs
|
||||||
// the queried set:
|
// the queried set:
|
||||||
return new ScorerSupplier() {
|
return new ScorerSupplier() {
|
||||||
long cost = -1; // calculate lazily, only once
|
TotalHits estimatedCount = null; // calculate lazily, only once
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Scorer get(long leadCost) throws IOException {
|
public Scorer get(long leadCost) throws IOException {
|
||||||
|
@ -189,15 +192,42 @@ public abstract class PointInSetQuery extends Query implements Accountable {
|
||||||
@Override
|
@Override
|
||||||
public long cost() {
|
public long cost() {
|
||||||
try {
|
try {
|
||||||
if (cost == -1) {
|
if (estimatedCount == null
|
||||||
|
|| estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) {
|
||||||
// Computing the cost may be expensive, so only do it if necessary
|
// Computing the cost may be expensive, so only do it if necessary
|
||||||
DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
||||||
cost =
|
estimatedCount =
|
||||||
values.estimateDocCount(
|
new TotalHits(
|
||||||
new MergePointVisitor(sortedPackedPoints.iterator(), result));
|
values.estimateDocCount(
|
||||||
assert cost >= 0;
|
new MergePointVisitor(sortedPackedPoints.iterator(), result),
|
||||||
|
Long.MAX_VALUE),
|
||||||
|
EQUAL_TO);
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
}
|
}
|
||||||
return cost;
|
return estimatedCount.value();
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new UncheckedIOException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) {
|
||||||
|
try {
|
||||||
|
if (estimatedCount == null
|
||||||
|
|| (estimatedCount.value() < upperBound
|
||||||
|
&& estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) {
|
||||||
|
DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
||||||
|
long cost =
|
||||||
|
values.estimateDocCount(
|
||||||
|
new MergePointVisitor(sortedPackedPoints.iterator(), result), upperBound);
|
||||||
|
if (cost < upperBound) {
|
||||||
|
estimatedCount = new TotalHits(cost, EQUAL_TO);
|
||||||
|
} else if (estimatedCount == null || cost > estimatedCount.value()) {
|
||||||
|
estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO);
|
||||||
|
}
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
|
}
|
||||||
|
return estimatedCount;
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new UncheckedIOException(e);
|
throw new UncheckedIOException(e);
|
||||||
}
|
}
|
||||||
|
@ -211,7 +241,7 @@ public abstract class PointInSetQuery extends Query implements Accountable {
|
||||||
// index, which is probably tricky!
|
// index, which is probably tricky!
|
||||||
|
|
||||||
return new ScorerSupplier() {
|
return new ScorerSupplier() {
|
||||||
long cost = -1; // calculate lazily, only once
|
TotalHits estimatedCount = null; // calculate lazily, only once
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Scorer get(long leadCost) throws IOException {
|
public Scorer get(long leadCost) throws IOException {
|
||||||
|
@ -228,18 +258,49 @@ public abstract class PointInSetQuery extends Query implements Accountable {
|
||||||
@Override
|
@Override
|
||||||
public long cost() {
|
public long cost() {
|
||||||
try {
|
try {
|
||||||
if (cost == -1) {
|
if (estimatedCount == null
|
||||||
|
|| estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) {
|
||||||
DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
||||||
SinglePointVisitor visitor = new SinglePointVisitor(result);
|
SinglePointVisitor visitor = new SinglePointVisitor(result);
|
||||||
TermIterator iterator = sortedPackedPoints.iterator();
|
TermIterator iterator = sortedPackedPoints.iterator();
|
||||||
cost = 0;
|
long cost = 0;
|
||||||
for (BytesRef point = iterator.next(); point != null; point = iterator.next()) {
|
for (BytesRef point = iterator.next(); point != null; point = iterator.next()) {
|
||||||
visitor.setPoint(point);
|
visitor.setPoint(point);
|
||||||
cost += values.estimateDocCount(visitor);
|
cost += values.estimateDocCount(visitor, Long.MAX_VALUE);
|
||||||
|
}
|
||||||
|
assert cost >= 0;
|
||||||
|
estimatedCount = new TotalHits(cost, EQUAL_TO);
|
||||||
|
}
|
||||||
|
return estimatedCount.value();
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new UncheckedIOException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) {
|
||||||
|
try {
|
||||||
|
if (estimatedCount == null
|
||||||
|
|| (estimatedCount.value() < upperBound
|
||||||
|
&& estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) {
|
||||||
|
DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
||||||
|
SinglePointVisitor visitor = new SinglePointVisitor(result);
|
||||||
|
TermIterator iterator = sortedPackedPoints.iterator();
|
||||||
|
long cost = 0;
|
||||||
|
for (BytesRef point = iterator.next(); point != null; point = iterator.next()) {
|
||||||
|
visitor.setPoint(point);
|
||||||
|
cost += values.estimateDocCount(visitor, upperBound);
|
||||||
|
if (cost >= upperBound) {
|
||||||
|
estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (cost < upperBound) {
|
||||||
|
estimatedCount = new TotalHits(cost, EQUAL_TO);
|
||||||
}
|
}
|
||||||
assert cost >= 0;
|
assert cost >= 0;
|
||||||
}
|
}
|
||||||
return cost;
|
return estimatedCount;
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new UncheckedIOException(e);
|
throw new UncheckedIOException(e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.search;
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO;
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
@ -360,7 +363,7 @@ public abstract class PointRangeQuery extends Query {
|
||||||
|
|
||||||
final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
||||||
final IntersectVisitor visitor = getIntersectVisitor(result);
|
final IntersectVisitor visitor = getIntersectVisitor(result);
|
||||||
long cost = -1;
|
TotalHits estimatedCount = null;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Scorer get(long leadCost) throws IOException {
|
public Scorer get(long leadCost) throws IOException {
|
||||||
|
@ -385,12 +388,28 @@ public abstract class PointRangeQuery extends Query {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long cost() {
|
public long cost() {
|
||||||
if (cost == -1) {
|
if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) {
|
||||||
// Computing the cost may be expensive, so only do it if necessary
|
estimatedCount =
|
||||||
cost = values.estimateDocCount(visitor);
|
new TotalHits(values.estimateDocCount(visitor, Long.MAX_VALUE), EQUAL_TO);
|
||||||
assert cost >= 0;
|
assert estimatedCount.value() >= 0;
|
||||||
}
|
}
|
||||||
return cost;
|
return estimatedCount.value();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) {
|
||||||
|
if (estimatedCount == null
|
||||||
|
|| (estimatedCount.value() < upperBound
|
||||||
|
&& estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) {
|
||||||
|
long cost = values.estimateDocCount(visitor, upperBound);
|
||||||
|
if (cost < upperBound) {
|
||||||
|
estimatedCount = new TotalHits(cost, EQUAL_TO);
|
||||||
|
} else if (estimatedCount == null || cost > estimatedCount.value()) {
|
||||||
|
estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO);
|
||||||
|
}
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
|
}
|
||||||
|
return estimatedCount;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,6 +53,10 @@ public abstract class ScorerSupplier {
|
||||||
*/
|
*/
|
||||||
public abstract long cost();
|
public abstract long cost();
|
||||||
|
|
||||||
|
public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) {
|
||||||
|
return new TotalHits(cost(), TotalHits.Relation.EQUAL_TO);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inform this {@link ScorerSupplier} that its returned scorers produce scores that get passed to
|
* Inform this {@link ScorerSupplier} that its returned scorers produce scores that get passed to
|
||||||
* the collector, as opposed to partial scores that then need to get combined (e.g. summed up).
|
* the collector, as opposed to partial scores that then need to get combined (e.g. summed up).
|
||||||
|
|
|
@ -16,6 +16,10 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.search;
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO;
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
import java.util.stream.LongStream;
|
import java.util.stream.LongStream;
|
||||||
import java.util.stream.StreamSupport;
|
import java.util.stream.StreamSupport;
|
||||||
import org.apache.lucene.util.PriorityQueue;
|
import org.apache.lucene.util.PriorityQueue;
|
||||||
|
@ -46,4 +50,33 @@ class ScorerUtil {
|
||||||
costs.forEach(pq::insertWithOverflow);
|
costs.forEach(pq::insertWithOverflow);
|
||||||
return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum();
|
return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static TotalHits costWithMinShouldMatch(
|
||||||
|
Collection<ScorerSupplier> collection, int numScorers, int minShouldMatch, long upperBound) {
|
||||||
|
int queueSize = Math.min(numScorers - minShouldMatch + 1, collection.size());
|
||||||
|
final PriorityQueue<Long> pq =
|
||||||
|
new PriorityQueue<Long>(queueSize) {
|
||||||
|
@Override
|
||||||
|
protected boolean lessThan(Long a, Long b) {
|
||||||
|
return a > b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Keep track of the last eliminated value that was added to the priority queue.
|
||||||
|
long leastTopNScoreBound = upperBound;
|
||||||
|
for (ScorerSupplier supplier : collection) {
|
||||||
|
TotalHits totalHits = supplier.isEstimatedPointCountGreaterThanOrEqualTo(leastTopNScoreBound);
|
||||||
|
if (totalHits.relation() == EQUAL_TO) {
|
||||||
|
Long oldCost = pq.insertWithOverflow(totalHits.value());
|
||||||
|
if (oldCost != null && leastTopNScoreBound > oldCost) {
|
||||||
|
leastTopNScoreBound = oldCost;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
long cost = StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum();
|
||||||
|
if (pq.size() < queueSize || cost > upperBound) {
|
||||||
|
return new TotalHits(Math.max(cost, upperBound), GREATER_THAN_OR_EQUAL_TO);
|
||||||
|
} else {
|
||||||
|
return new TotalHits(cost, EQUAL_TO);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,9 @@
|
||||||
|
|
||||||
package org.apache.lucene.sandbox.search;
|
package org.apache.lucene.sandbox.search;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO;
|
||||||
|
import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -36,6 +39,7 @@ import org.apache.lucene.search.QueryVisitor;
|
||||||
import org.apache.lucene.search.ScoreMode;
|
import org.apache.lucene.search.ScoreMode;
|
||||||
import org.apache.lucene.search.Scorer;
|
import org.apache.lucene.search.Scorer;
|
||||||
import org.apache.lucene.search.ScorerSupplier;
|
import org.apache.lucene.search.ScorerSupplier;
|
||||||
|
import org.apache.lucene.search.TotalHits;
|
||||||
import org.apache.lucene.search.Weight;
|
import org.apache.lucene.search.Weight;
|
||||||
import org.apache.lucene.util.ArrayUtil;
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
import org.apache.lucene.util.DocIdSetBuilder;
|
import org.apache.lucene.util.DocIdSetBuilder;
|
||||||
|
@ -352,7 +356,7 @@ public abstract class MultiRangeQuery extends Query implements Cloneable {
|
||||||
|
|
||||||
final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
||||||
final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, range);
|
final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, range);
|
||||||
long cost = -1;
|
TotalHits estimatedCount = null;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Scorer get(long leadCost) throws IOException {
|
public Scorer get(long leadCost) throws IOException {
|
||||||
|
@ -363,12 +367,30 @@ public abstract class MultiRangeQuery extends Query implements Cloneable {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long cost() {
|
public long cost() {
|
||||||
if (cost == -1) {
|
if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) {
|
||||||
// Computing the cost may be expensive, so only do it if necessary
|
estimatedCount =
|
||||||
cost = values.estimateDocCount(visitor) * rangeClauses.size();
|
new TotalHits(
|
||||||
assert cost >= 0;
|
values.estimateDocCount(visitor, Long.MAX_VALUE) * rangeClauses.size(),
|
||||||
|
EQUAL_TO);
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
}
|
}
|
||||||
return cost;
|
return estimatedCount.value();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) {
|
||||||
|
if (estimatedCount == null
|
||||||
|
|| (estimatedCount.value() < upperBound
|
||||||
|
&& estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) {
|
||||||
|
long cost = values.estimateDocCount(visitor, upperBound);
|
||||||
|
if (cost < upperBound) {
|
||||||
|
estimatedCount = new TotalHits(cost, EQUAL_TO);
|
||||||
|
} else if (estimatedCount == null || cost > estimatedCount.value()) {
|
||||||
|
estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO);
|
||||||
|
}
|
||||||
|
assert estimatedCount.value() >= 0;
|
||||||
|
}
|
||||||
|
return estimatedCount;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue