LUCENE-7055: Add ScorerProvider to get an estimation of the cost of scorers before building them.

This commit is contained in:
Adrien Grand 2017-01-16 15:47:53 +01:00
parent 38af094d17
commit 86233cb95d
29 changed files with 1255 additions and 205 deletions

View File

@ -73,6 +73,13 @@ Bug Fixes
* LUCENE-7630: Fix (Edge)NGramTokenFilter to no longer drop payloads * LUCENE-7630: Fix (Edge)NGramTokenFilter to no longer drop payloads
and preserve all attributes. (Nathan Gass via Uwe Schindler) and preserve all attributes. (Nathan Gass via Uwe Schindler)
Improvements
* LUCENE-7055: Added Weight#scorerSupplier, which allows to estimate the cost
of a Scorer before actually building it, in order to optimize how the query
should be run, eg. using points or doc values depending on costs of other
parts of the query. (Adrien Grand)
======================= Lucene 6.4.0 ======================= ======================= Lucene 6.4.0 =======================
API Changes API Changes

View File

@ -286,6 +286,56 @@ final class SimpleTextBKDReader extends PointValues implements Accountable {
} }
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
return estimatePointCount(getIntersectState(visitor), 1, minPackedValue, maxPackedValue);
}
private long estimatePointCount(IntersectState state,
int nodeID, byte[] cellMinPacked, byte[] cellMaxPacked) {
Relation r = state.visitor.compare(cellMinPacked, cellMaxPacked);
if (r == Relation.CELL_OUTSIDE_QUERY) {
// This cell is fully outside of the query shape: stop recursing
return 0L;
} else if (nodeID >= leafNodeOffset) {
// Assume all points match and there are no dups
return maxPointsInLeafNode;
} else {
// Non-leaf node: recurse on the split left and right nodes
int address = nodeID * bytesPerIndexEntry;
int splitDim;
if (numDims == 1) {
splitDim = 0;
} else {
splitDim = splitPackedValues[address++] & 0xff;
}
assert splitDim < numDims;
// TODO: can we alloc & reuse this up front?
byte[] splitPackedValue = new byte[packedBytesLength];
// Recurse on left sub-tree:
System.arraycopy(cellMaxPacked, 0, splitPackedValue, 0, packedBytesLength);
System.arraycopy(splitPackedValues, address, splitPackedValue, splitDim*bytesPerDim, bytesPerDim);
final long leftCost = estimatePointCount(state,
2*nodeID,
cellMinPacked, splitPackedValue);
// Recurse on right sub-tree:
System.arraycopy(cellMinPacked, 0, splitPackedValue, 0, packedBytesLength);
System.arraycopy(splitPackedValues, address, splitPackedValue, splitDim*bytesPerDim, bytesPerDim);
final long rightCost = estimatePointCount(state,
2*nodeID+1,
splitPackedValue, cellMaxPacked);
return leftCost + rightCost;
}
}
/** Copies the split value for this node into the provided byte array */ /** Copies the split value for this node into the provided byte array */
public void copySplitValue(int nodeID, byte[] splitPackedValue) { public void copySplitValue(int nodeID, byte[] splitPackedValue) {
int address = nodeID * bytesPerIndexEntry; int address = nodeID * bytesPerIndexEntry;

View File

@ -127,6 +127,11 @@ public abstract class PointsWriter implements Closeable {
} }
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
throw new UnsupportedOperationException();
}
@Override @Override
public byte[] getMinPackedValue() { public byte[] getMinPackedValue() {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -42,6 +42,8 @@ import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.DocumentStoredFieldVisitor; import org.apache.lucene.document.DocumentStoredFieldVisitor;
import org.apache.lucene.index.CheckIndex.Status.DocValuesStatus; import org.apache.lucene.index.CheckIndex.Status.DocValuesStatus;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.Relation;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafFieldComparator; import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
@ -1810,6 +1812,19 @@ public final class CheckIndex implements Closeable {
long size = values.size(); long size = values.size();
int docCount = values.getDocCount(); int docCount = values.getDocCount();
final long crossCost = values.estimatePointCount(new ConstantRelationIntersectVisitor(Relation.CELL_CROSSES_QUERY));
if (crossCost < size) {
throw new RuntimeException("estimatePointCount should return >= size when all cells match");
}
final long insideCost = values.estimatePointCount(new ConstantRelationIntersectVisitor(Relation.CELL_INSIDE_QUERY));
if (insideCost < size) {
throw new RuntimeException("estimatePointCount should return >= size when all cells fully match");
}
final long outsideCost = values.estimatePointCount(new ConstantRelationIntersectVisitor(Relation.CELL_OUTSIDE_QUERY));
if (outsideCost != 0) {
throw new RuntimeException("estimatePointCount should return 0 when no cells match");
}
VerifyPointsVisitor visitor = new VerifyPointsVisitor(fieldInfo.name, reader.maxDoc(), values); VerifyPointsVisitor visitor = new VerifyPointsVisitor(fieldInfo.name, reader.maxDoc(), values);
values.intersect(visitor); values.intersect(visitor);
@ -2002,6 +2017,28 @@ public final class CheckIndex implements Closeable {
} }
} }
private static class ConstantRelationIntersectVisitor implements IntersectVisitor {
private final Relation relation;
ConstantRelationIntersectVisitor(Relation relation) {
this.relation = relation;
}
@Override
public void visit(int docID) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public void visit(int docID, byte[] packedValue) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return relation;
}
}
/** /**
* Test stored fields. * Test stored fields.

View File

@ -26,6 +26,7 @@ import org.apache.lucene.document.Field;
import org.apache.lucene.document.FloatPoint; import org.apache.lucene.document.FloatPoint;
import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.LongPoint; import org.apache.lucene.document.LongPoint;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.StringHelper;
import org.apache.lucene.util.bkd.BKDWriter; import org.apache.lucene.util.bkd.BKDWriter;
@ -220,6 +221,12 @@ public abstract class PointValues {
* to test whether each document is deleted, if necessary. */ * to test whether each document is deleted, if necessary. */
public abstract void intersect(IntersectVisitor visitor) throws IOException; public abstract void intersect(IntersectVisitor visitor) throws IOException;
/** Estimate the number of points that would be visited by {@link #intersect}
* with the given {@link IntersectVisitor}. This should run many times faster
* than {@link #intersect(IntersectVisitor)}.
* @see DocIdSetIterator#cost */
public abstract long estimatePointCount(IntersectVisitor visitor);
/** Returns minimum value for each dimension, packed, or null if {@link #size} is <code>0</code> */ /** Returns minimum value for each dimension, packed, or null if {@link #size} is <code>0</code> */
public abstract byte[] getMinPackedValue() throws IOException; public abstract byte[] getMinPackedValue() throws IOException;

View File

@ -90,6 +90,11 @@ class PointValuesWriter {
} }
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
throw new UnsupportedOperationException();
}
@Override @Override
public byte[] getMinPackedValue() { public byte[] getMinPackedValue() {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
@ -208,6 +213,11 @@ class PointValuesWriter {
}); });
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
return in.estimatePointCount(visitor);
}
@Override @Override
public byte[] getMinPackedValue() throws IOException { public byte[] getMinPackedValue() throws IOException {
return in.getMinPackedValue(); return in.getMinPackedValue();

View File

@ -327,6 +327,11 @@ class SortingLeafReader extends FilterLeafReader {
}); });
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
return in.estimatePointCount(visitor);
}
@Override @Override
public byte[] getMinPackedValue() throws IOException { public byte[] getMinPackedValue() throws IOException {
return in.getMinPackedValue(); return in.getMinPackedValue();

View File

@ -0,0 +1,217 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.OptionalLong;
import java.util.stream.Stream;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.util.PriorityQueue;
final class Boolean2ScorerSupplier extends ScorerSupplier {
private final BooleanWeight weight;
private final Map<BooleanClause.Occur, Collection<ScorerSupplier>> subs;
private final boolean needsScores;
private final int minShouldMatch;
private long cost = -1;
Boolean2ScorerSupplier(BooleanWeight weight,
Map<Occur, Collection<ScorerSupplier>> subs,
boolean needsScores, int minShouldMatch) {
if (minShouldMatch < 0) {
throw new IllegalArgumentException("minShouldMatch must be positive, but got: " + minShouldMatch);
}
if (minShouldMatch != 0 && minShouldMatch >= subs.get(Occur.SHOULD).size()) {
throw new IllegalArgumentException("minShouldMatch must be strictly less than the number of SHOULD clauses");
}
if (needsScores == false && minShouldMatch == 0 && subs.get(Occur.SHOULD).size() > 0
&& subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() > 0) {
throw new IllegalArgumentException("Cannot pass purely optional clauses if scores are not needed");
}
if (subs.get(Occur.SHOULD).size() + subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() == 0) {
throw new IllegalArgumentException("There should be at least one positive clause");
}
this.weight = weight;
this.subs = subs;
this.needsScores = needsScores;
this.minShouldMatch = minShouldMatch;
}
private long computeCost() {
OptionalLong minRequiredCost = Stream.concat(
subs.get(Occur.MUST).stream(),
subs.get(Occur.FILTER).stream())
.mapToLong(ScorerSupplier::cost)
.min();
if (minRequiredCost.isPresent() && minShouldMatch == 0) {
return minRequiredCost.getAsLong();
} else {
final Collection<ScorerSupplier> optionalScorers = subs.get(Occur.SHOULD);
final long shouldCost = MinShouldMatchSumScorer.cost(
optionalScorers.stream().mapToLong(ScorerSupplier::cost),
optionalScorers.size(), minShouldMatch);
return Math.min(minRequiredCost.orElse(Long.MAX_VALUE), shouldCost);
}
}
@Override
public long cost() {
if (cost == -1) {
cost = computeCost();
}
return cost;
}
@Override
public Scorer get(boolean randomAccess) throws IOException {
// three cases: conjunction, disjunction, or mix
// pure conjunction
if (subs.get(Occur.SHOULD).isEmpty()) {
return excl(req(subs.get(Occur.FILTER), subs.get(Occur.MUST), randomAccess), subs.get(Occur.MUST_NOT));
}
// pure disjunction
if (subs.get(Occur.FILTER).isEmpty() && subs.get(Occur.MUST).isEmpty()) {
return excl(opt(subs.get(Occur.SHOULD), minShouldMatch, needsScores, randomAccess), subs.get(Occur.MUST_NOT));
}
// conjunction-disjunction mix:
// we create the required and optional pieces, and then
// combine the two: if minNrShouldMatch > 0, then it's a conjunction: because the
// optional side must match. otherwise it's required + optional
if (minShouldMatch > 0) {
boolean reqRandomAccess = true;
boolean msmRandomAccess = true;
if (randomAccess == false) {
// We need to figure out whether the MUST/FILTER or the SHOULD clauses would lead the iteration
final long reqCost = Stream.concat(
subs.get(Occur.MUST).stream(),
subs.get(Occur.FILTER).stream())
.mapToLong(ScorerSupplier::cost)
.min().getAsLong();
final long msmCost = MinShouldMatchSumScorer.cost(
subs.get(Occur.SHOULD).stream().mapToLong(ScorerSupplier::cost),
subs.get(Occur.SHOULD).size(), minShouldMatch);
reqRandomAccess = reqCost > msmCost;
msmRandomAccess = msmCost > reqCost;
}
Scorer req = excl(req(subs.get(Occur.FILTER), subs.get(Occur.MUST), reqRandomAccess), subs.get(Occur.MUST_NOT));
Scorer opt = opt(subs.get(Occur.SHOULD), minShouldMatch, needsScores, msmRandomAccess);
return new ConjunctionScorer(weight, Arrays.asList(req, opt), Arrays.asList(req, opt));
} else {
assert needsScores;
return new ReqOptSumScorer(
excl(req(subs.get(Occur.FILTER), subs.get(Occur.MUST), randomAccess), subs.get(Occur.MUST_NOT)),
opt(subs.get(Occur.SHOULD), minShouldMatch, needsScores, true));
}
}
/** Create a new scorer for the given required clauses. Note that
* {@code requiredScoring} is a subset of {@code required} containing
* required clauses that should participate in scoring. */
private Scorer req(Collection<ScorerSupplier> requiredNoScoring, Collection<ScorerSupplier> requiredScoring, boolean randomAccess) throws IOException {
if (requiredNoScoring.size() + requiredScoring.size() == 1) {
Scorer req = (requiredNoScoring.isEmpty() ? requiredScoring : requiredNoScoring).iterator().next().get(randomAccess);
if (needsScores == false) {
return req;
}
if (requiredScoring.isEmpty()) {
// Scores are needed but we only have a filter clause
// BooleanWeight expects that calling score() is ok so we need to wrap
// to prevent score() from being propagated
return new FilterScorer(req) {
@Override
public float score() throws IOException {
return 0f;
}
@Override
public int freq() throws IOException {
return 0;
}
};
}
return req;
} else {
long minCost = Math.min(
requiredNoScoring.stream().mapToLong(ScorerSupplier::cost).min().orElse(Long.MAX_VALUE),
requiredScoring.stream().mapToLong(ScorerSupplier::cost).min().orElse(Long.MAX_VALUE));
List<Scorer> requiredScorers = new ArrayList<>();
List<Scorer> scoringScorers = new ArrayList<>();
for (ScorerSupplier s : requiredNoScoring) {
requiredScorers.add(s.get(randomAccess || s.cost() > minCost));
}
for (ScorerSupplier s : requiredScoring) {
Scorer scorer = s.get(randomAccess || s.cost() > minCost);
requiredScorers.add(scorer);
scoringScorers.add(scorer);
}
return new ConjunctionScorer(weight, requiredScorers, scoringScorers);
}
}
private Scorer excl(Scorer main, Collection<ScorerSupplier> prohibited) throws IOException {
if (prohibited.isEmpty()) {
return main;
} else {
return new ReqExclScorer(main, opt(prohibited, 1, false, true));
}
}
private Scorer opt(Collection<ScorerSupplier> optional, int minShouldMatch,
boolean needsScores, boolean randomAccess) throws IOException {
if (optional.size() == 1) {
return optional.iterator().next().get(randomAccess);
} else if (minShouldMatch > 1) {
final List<Scorer> optionalScorers = new ArrayList<>();
final PriorityQueue<ScorerSupplier> pq = new PriorityQueue<ScorerSupplier>(subs.get(Occur.SHOULD).size() - minShouldMatch + 1) {
@Override
protected boolean lessThan(ScorerSupplier a, ScorerSupplier b) {
return a.cost() > b.cost();
}
};
for (ScorerSupplier scorer : subs.get(Occur.SHOULD)) {
ScorerSupplier overflow = pq.insertWithOverflow(scorer);
if (overflow != null) {
optionalScorers.add(overflow.get(true));
}
}
for (ScorerSupplier scorer : pq) {
optionalScorers.add(scorer.get(randomAccess));
}
return new MinShouldMatchSumScorer(weight, optionalScorers, minShouldMatch);
} else {
final List<Scorer> optionalScorers = new ArrayList<>();
for (ScorerSupplier scorer : optional) {
optionalScorers.add(scorer.get(randomAccess));
}
return new DisjunctionSumScorer(weight, optionalScorers, needsScores);
}
}
}

View File

@ -19,9 +19,11 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Collection;
import java.util.EnumMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
@ -265,7 +267,9 @@ final class BooleanWeight extends Weight {
if (prohibited.isEmpty()) { if (prohibited.isEmpty()) {
return positiveScorer; return positiveScorer;
} else { } else {
Scorer prohibitedScorer = opt(prohibited, 1); Scorer prohibitedScorer = prohibited.size() == 1
? prohibited.get(0)
: new DisjunctionSumScorer(this, prohibited, false);
if (prohibitedScorer.twoPhaseIterator() != null) { if (prohibitedScorer.twoPhaseIterator() != null) {
// ReqExclBulkScorer can't deal efficiently with two-phased prohibited clauses // ReqExclBulkScorer can't deal efficiently with two-phased prohibited clauses
return null; return null;
@ -288,50 +292,48 @@ final class BooleanWeight extends Weight {
@Override @Override
public Scorer scorer(LeafReaderContext context) throws IOException { public Scorer scorer(LeafReaderContext context) throws IOException {
// initially the user provided value, ScorerSupplier scorerSupplier = scorerSupplier(context);
// but if minNrShouldMatch == optional.size(), if (scorerSupplier == null) {
// we will optimize and move these to required, making this 0 return null;
}
return scorerSupplier.get(false);
}
@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
int minShouldMatch = query.getMinimumNumberShouldMatch(); int minShouldMatch = query.getMinimumNumberShouldMatch();
List<Scorer> required = new ArrayList<>(); final Map<Occur, Collection<ScorerSupplier>> scorers = new EnumMap<>(Occur.class);
// clauses that are required AND participate in scoring, subset of 'required' for (Occur occur : Occur.values()) {
List<Scorer> requiredScoring = new ArrayList<>(); scorers.put(occur, new ArrayList<>());
List<Scorer> prohibited = new ArrayList<>(); }
List<Scorer> optional = new ArrayList<>();
Iterator<BooleanClause> cIter = query.iterator(); Iterator<BooleanClause> cIter = query.iterator();
for (Weight w : weights) { for (Weight w : weights) {
BooleanClause c = cIter.next(); BooleanClause c = cIter.next();
Scorer subScorer = w.scorer(context); ScorerSupplier subScorer = w.scorerSupplier(context);
if (subScorer == null) { if (subScorer == null) {
if (c.isRequired()) { if (c.isRequired()) {
return null; return null;
} }
} else if (c.isRequired()) {
required.add(subScorer);
if (c.isScoring()) {
requiredScoring.add(subScorer);
}
} else if (c.isProhibited()) {
prohibited.add(subScorer);
} else { } else {
optional.add(subScorer); scorers.get(c.getOccur()).add(subScorer);
} }
} }
// scorer simplifications: // scorer simplifications:
if (optional.size() == minShouldMatch) { if (scorers.get(Occur.SHOULD).size() == minShouldMatch) {
// any optional clauses are in fact required // any optional clauses are in fact required
required.addAll(optional); scorers.get(Occur.MUST).addAll(scorers.get(Occur.SHOULD));
requiredScoring.addAll(optional); scorers.get(Occur.SHOULD).clear();
optional.clear();
minShouldMatch = 0; minShouldMatch = 0;
} }
if (required.isEmpty() && optional.isEmpty()) { if (scorers.get(Occur.FILTER).isEmpty() && scorers.get(Occur.MUST).isEmpty() && scorers.get(Occur.SHOULD).isEmpty()) {
// no required and optional clauses. // no required and optional clauses.
return null; return null;
} else if (optional.size() < minShouldMatch) { } else if (scorers.get(Occur.SHOULD).size() < minShouldMatch) {
// either >1 req scorer, or there are 0 req scorers and at least 1 // either >1 req scorer, or there are 0 req scorers and at least 1
// optional scorer. Therefore if there are not enough optional scorers // optional scorer. Therefore if there are not enough optional scorers
// no documents will be matched by the query // no documents will be matched by the query
@ -339,87 +341,11 @@ final class BooleanWeight extends Weight {
} }
// we don't need scores, so if we have required clauses, drop optional clauses completely // we don't need scores, so if we have required clauses, drop optional clauses completely
if (!needsScores && minShouldMatch == 0 && required.size() > 0) { if (!needsScores && minShouldMatch == 0 && scorers.get(Occur.MUST).size() + scorers.get(Occur.FILTER).size() > 0) {
optional.clear(); scorers.get(Occur.SHOULD).clear();
} }
// three cases: conjunction, disjunction, or mix
// pure conjunction
if (optional.isEmpty()) {
return excl(req(required, requiredScoring), prohibited);
}
// pure disjunction
if (required.isEmpty()) {
return excl(opt(optional, minShouldMatch), prohibited);
}
// conjunction-disjunction mix:
// we create the required and optional pieces, and then
// combine the two: if minNrShouldMatch > 0, then it's a conjunction: because the
// optional side must match. otherwise it's required + optional
Scorer req = excl(req(required, requiredScoring), prohibited);
Scorer opt = opt(optional, minShouldMatch);
if (minShouldMatch > 0) { return new Boolean2ScorerSupplier(this, scorers, needsScores, minShouldMatch);
return new ConjunctionScorer(this, Arrays.asList(req, opt), Arrays.asList(req, opt));
} else {
return new ReqOptSumScorer(req, opt);
}
} }
/** Create a new scorer for the given required clauses. Note that
* {@code requiredScoring} is a subset of {@code required} containing
* required clauses that should participate in scoring. */
private Scorer req(List<Scorer> required, List<Scorer> requiredScoring) {
if (required.size() == 1) {
Scorer req = required.get(0);
if (needsScores == false) {
return req;
}
if (requiredScoring.isEmpty()) {
// Scores are needed but we only have a filter clause
// BooleanWeight expects that calling score() is ok so we need to wrap
// to prevent score() from being propagated
return new FilterScorer(req) {
@Override
public float score() throws IOException {
return 0f;
}
@Override
public int freq() throws IOException {
return 0;
}
};
}
return req;
} else {
return new ConjunctionScorer(this, required, requiredScoring);
}
}
private Scorer excl(Scorer main, List<Scorer> prohibited) throws IOException {
if (prohibited.isEmpty()) {
return main;
} else if (prohibited.size() == 1) {
return new ReqExclScorer(main, prohibited.get(0));
} else {
return new ReqExclScorer(main, new DisjunctionSumScorer(this, prohibited, false));
}
}
private Scorer opt(List<Scorer> optional, int minShouldMatch) throws IOException {
if (optional.size() == 1) {
return optional.get(0);
} else if (minShouldMatch > 1) {
return new MinShouldMatchSumScorer(this, optional, minShouldMatch);
} else {
return new DisjunctionSumScorer(this, optional, needsScores);
}
}
} }

View File

@ -41,7 +41,7 @@ public final class ConjunctionDISI extends DocIdSetIterator {
* returned {@link DocIdSetIterator} might leverage two-phase iteration in * returned {@link DocIdSetIterator} might leverage two-phase iteration in
* which case it is possible to retrieve the {@link TwoPhaseIterator} using * which case it is possible to retrieve the {@link TwoPhaseIterator} using
* {@link TwoPhaseIterator#unwrap}. */ * {@link TwoPhaseIterator#unwrap}. */
public static DocIdSetIterator intersectScorers(List<Scorer> scorers) { public static DocIdSetIterator intersectScorers(Collection<Scorer> scorers) {
if (scorers.size() < 2) { if (scorers.size() < 2) {
throw new IllegalArgumentException("Cannot make a ConjunctionDISI of less than 2 iterators"); throw new IllegalArgumentException("Cannot make a ConjunctionDISI of less than 2 iterators");
} }

View File

@ -20,7 +20,6 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.List;
/** Scorer for conjunctions, sets of queries, all of which are required. */ /** Scorer for conjunctions, sets of queries, all of which are required. */
class ConjunctionScorer extends Scorer { class ConjunctionScorer extends Scorer {
@ -29,7 +28,7 @@ class ConjunctionScorer extends Scorer {
final Scorer[] scorers; final Scorer[] scorers;
/** Create a new {@link ConjunctionScorer}, note that {@code scorers} must be a subset of {@code required}. */ /** Create a new {@link ConjunctionScorer}, note that {@code scorers} must be a subset of {@code required}. */
ConjunctionScorer(Weight weight, List<Scorer> required, List<Scorer> scorers) { ConjunctionScorer(Weight weight, Collection<Scorer> required, Collection<Scorer> scorers) {
super(weight); super(weight);
assert required.containsAll(scorers); assert required.containsAll(scorers);
this.disi = ConjunctionDISI.intersectScorers(required); this.disi = ConjunctionDISI.intersectScorers(required);

View File

@ -125,28 +125,48 @@ public final class ConstantScoreQuery extends Query {
} }
@Override @Override
public Scorer scorer(LeafReaderContext context) throws IOException { public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
final Scorer innerScorer = innerWeight.scorer(context); ScorerSupplier innerScorerSupplier = innerWeight.scorerSupplier(context);
if (innerScorer == null) { if (innerScorerSupplier == null) {
return null; return null;
} }
final float score = score(); return new ScorerSupplier() {
return new FilterScorer(innerScorer) {
@Override @Override
public float score() throws IOException { public Scorer get(boolean randomAccess) throws IOException {
return score; final Scorer innerScorer = innerScorerSupplier.get(randomAccess);
final float score = score();
return new FilterScorer(innerScorer) {
@Override
public float score() throws IOException {
return score;
}
@Override
public int freq() throws IOException {
return 1;
}
@Override
public Collection<ChildScorer> getChildren() {
return Collections.singleton(new ChildScorer(innerScorer, "constant"));
}
};
} }
@Override @Override
public int freq() throws IOException { public long cost() {
return 1; return innerScorerSupplier.cost();
}
@Override
public Collection<ChildScorer> getChildren() {
return Collections.singleton(new ChildScorer(innerScorer, "constant"));
} }
}; };
} }
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
ScorerSupplier scorerSupplier = scorerSupplier(context);
if (scorerSupplier == null) {
return null;
}
return scorerSupplier.get(false);
}
}; };
} else { } else {
return innerWeight; return innerWeight;

View File

@ -22,6 +22,8 @@ import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.LongStream;
import java.util.stream.StreamSupport;
import org.apache.lucene.util.PriorityQueue; import org.apache.lucene.util.PriorityQueue;
@ -47,7 +49,7 @@ import static org.apache.lucene.search.DisiPriorityQueue.rightNode;
*/ */
final class MinShouldMatchSumScorer extends Scorer { final class MinShouldMatchSumScorer extends Scorer {
private static long cost(Collection<Scorer> scorers, int minShouldMatch) { static long cost(LongStream costs, int numScorers, int minShouldMatch) {
// the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m // the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m
// could be rewritten to: // could be rewritten to:
// (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m)) // (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m))
@ -61,20 +63,14 @@ final class MinShouldMatchSumScorer extends Scorer {
// If we recurse infinitely, we find out that the cost of a msm query is the sum of the // If we recurse infinitely, we find out that the cost of a msm query is the sum of the
// costs of the num_scorers - minShouldMatch + 1 least costly scorers // costs of the num_scorers - minShouldMatch + 1 least costly scorers
final PriorityQueue<Scorer> pq = new PriorityQueue<Scorer>(scorers.size() - minShouldMatch + 1) { final PriorityQueue<Long> pq = new PriorityQueue<Long>(numScorers - minShouldMatch + 1) {
@Override @Override
protected boolean lessThan(Scorer a, Scorer b) { protected boolean lessThan(Long a, Long b) {
return a.iterator().cost() > b.iterator().cost(); return a > b;
} }
}; };
for (Scorer scorer : scorers) { costs.forEach(pq::insertWithOverflow);
pq.insertWithOverflow(scorer); return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum();
}
long cost = 0;
for (Scorer scorer = pq.pop(); scorer != null; scorer = pq.pop()) {
cost += scorer.iterator().cost();
}
return cost;
} }
final int minShouldMatch; final int minShouldMatch;
@ -124,7 +120,7 @@ final class MinShouldMatchSumScorer extends Scorer {
children.add(new ChildScorer(scorer, "SHOULD")); children.add(new ChildScorer(scorer, "SHOULD"));
} }
this.childScorers = Collections.unmodifiableCollection(children); this.childScorers = Collections.unmodifiableCollection(children);
this.cost = cost(scorers, minShouldMatch); this.cost = cost(scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost), scorers.size(), minShouldMatch);
} }
@Override @Override

View File

@ -104,71 +104,67 @@ public abstract class PointRangeQuery extends Query {
return new ConstantScoreWeight(this, boost) { return new ConstantScoreWeight(this, boost) {
private DocIdSet buildMatchingDocIdSet(LeafReader reader, PointValues values) throws IOException { private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) {
DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); return new IntersectVisitor() {
values.intersect( DocIdSetBuilder.BulkAdder adder;
new IntersectVisitor() {
DocIdSetBuilder.BulkAdder adder; @Override
public void grow(int count) {
adder = result.grow(count);
}
@Override @Override
public void grow(int count) { public void visit(int docID) {
adder = result.grow(count); adder.add(docID);
}
@Override
public void visit(int docID, byte[] packedValue) {
for(int dim=0;dim<numDims;dim++) {
int offset = dim*bytesPerDim;
if (StringHelper.compare(bytesPerDim, packedValue, offset, lowerPoint, offset) < 0) {
// Doc's value is too low, in this dimension
return;
}
if (StringHelper.compare(bytesPerDim, packedValue, offset, upperPoint, offset) > 0) {
// Doc's value is too high, in this dimension
return;
}
}
// Doc is in-bounds
adder.add(docID);
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
boolean crosses = false;
for(int dim=0;dim<numDims;dim++) {
int offset = dim*bytesPerDim;
if (StringHelper.compare(bytesPerDim, minPackedValue, offset, upperPoint, offset) > 0 ||
StringHelper.compare(bytesPerDim, maxPackedValue, offset, lowerPoint, offset) < 0) {
return Relation.CELL_OUTSIDE_QUERY;
} }
@Override crosses |= StringHelper.compare(bytesPerDim, minPackedValue, offset, lowerPoint, offset) < 0 ||
public void visit(int docID) { StringHelper.compare(bytesPerDim, maxPackedValue, offset, upperPoint, offset) > 0;
adder.add(docID); }
}
@Override if (crosses) {
public void visit(int docID, byte[] packedValue) { return Relation.CELL_CROSSES_QUERY;
for(int dim=0;dim<numDims;dim++) { } else {
int offset = dim*bytesPerDim; return Relation.CELL_INSIDE_QUERY;
if (StringHelper.compare(bytesPerDim, packedValue, offset, lowerPoint, offset) < 0) { }
// Doc's value is too low, in this dimension }
return; };
}
if (StringHelper.compare(bytesPerDim, packedValue, offset, upperPoint, offset) > 0) {
// Doc's value is too high, in this dimension
return;
}
}
// Doc is in-bounds
adder.add(docID);
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
boolean crosses = false;
for(int dim=0;dim<numDims;dim++) {
int offset = dim*bytesPerDim;
if (StringHelper.compare(bytesPerDim, minPackedValue, offset, upperPoint, offset) > 0 ||
StringHelper.compare(bytesPerDim, maxPackedValue, offset, lowerPoint, offset) < 0) {
return Relation.CELL_OUTSIDE_QUERY;
}
crosses |= StringHelper.compare(bytesPerDim, minPackedValue, offset, lowerPoint, offset) < 0 ||
StringHelper.compare(bytesPerDim, maxPackedValue, offset, upperPoint, offset) > 0;
}
if (crosses) {
return Relation.CELL_CROSSES_QUERY;
} else {
return Relation.CELL_INSIDE_QUERY;
}
}
});
return result.build();
} }
@Override @Override
public Scorer scorer(LeafReaderContext context) throws IOException { public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
LeafReader reader = context.reader(); LeafReader reader = context.reader();
PointValues values = reader.getPointValues(field); PointValues values = reader.getPointValues(field);
@ -201,15 +197,55 @@ public abstract class PointRangeQuery extends Query {
allDocsMatch = false; allDocsMatch = false;
} }
DocIdSetIterator iterator; final Weight weight = this;
if (allDocsMatch) { if (allDocsMatch) {
// all docs have a value and all points are within bounds, so everything matches // all docs have a value and all points are within bounds, so everything matches
iterator = DocIdSetIterator.all(reader.maxDoc()); return new ScorerSupplier() {
@Override
public Scorer get(boolean randomAccess) {
return new ConstantScoreScorer(weight, score(),
DocIdSetIterator.all(reader.maxDoc()));
}
@Override
public long cost() {
return reader.maxDoc();
}
};
} else { } else {
iterator = buildMatchingDocIdSet(reader, values).iterator(); return new ScorerSupplier() {
}
return new ConstantScoreScorer(this, score(), iterator); final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
final IntersectVisitor visitor = getIntersectVisitor(result);
long cost = -1;
@Override
public Scorer get(boolean randomAccess) throws IOException {
values.intersect(visitor);
DocIdSetIterator iterator = result.build().iterator();
return new ConstantScoreScorer(weight, score(), iterator);
}
@Override
public long cost() {
if (cost == -1) {
// Computing the cost may be expensive, so only do it if necessary
cost = values.estimatePointCount(visitor);
assert cost >= 0;
}
return cost;
}
};
}
}
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
ScorerSupplier scorerSupplier = scorerSupplier(context);
if (scorerSupplier == null) {
return null;
}
return scorerSupplier.get(false);
} }
}; };
} }

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
/**
* A supplier of {@link Scorer}. This allows to get an estimate of the cost before
* building the {@link Scorer}.
*/
public abstract class ScorerSupplier {
/**
* Get the {@link Scorer}. This may not return {@code null} and must be called
* at most once.
* @param randomAccess A hint about the expected usage of the {@link Scorer}.
* If {@link DocIdSetIterator#advance} or {@link TwoPhaseIterator} will be
* used to check whether given doc ids match, then pass {@code true}.
* Otherwise if the {@link Scorer} will be mostly used to lead the iteration
* using {@link DocIdSetIterator#nextDoc()}, then {@code false} should be
* passed. Under doubt, pass {@code false} which usually has a better
* worst-case.
*/
public abstract Scorer get(boolean randomAccess) throws IOException;
/**
* Get an estimate of the {@link Scorer} that would be returned by {@link #get}.
* This may be a costly operation, so it should only be called if necessary.
* @see DocIdSetIterator#cost
*/
public abstract long cost();
}

View File

@ -102,6 +102,31 @@ public abstract class Weight {
*/ */
public abstract Scorer scorer(LeafReaderContext context) throws IOException; public abstract Scorer scorer(LeafReaderContext context) throws IOException;
/**
* Optional method.
* Get a {@link ScorerSupplier}, which allows to know the cost of the {@link Scorer}
* before building it. The default implementation calls {@link #scorer} and
* builds a {@link ScorerSupplier} wrapper around it.
* @see #scorer
*/
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
final Scorer scorer = scorer(context);
if (scorer == null) {
return null;
}
return new ScorerSupplier() {
@Override
public Scorer get(boolean randomAccess) {
return scorer;
}
@Override
public long cost() {
return scorer.iterator().cost();
}
};
}
/** /**
* Optional method, to return a {@link BulkScorer} to * Optional method, to return a {@link BulkScorer} to
* score the query and send hits to a {@link Collector}. * score the query and send hits to a {@link Collector}.

View File

@ -482,10 +482,16 @@ public final class BKDReader extends PointValues implements Accountable {
} }
} }
@Override
public void intersect(IntersectVisitor visitor) throws IOException { public void intersect(IntersectVisitor visitor) throws IOException {
intersect(getIntersectState(visitor), minPackedValue, maxPackedValue); intersect(getIntersectState(visitor), minPackedValue, maxPackedValue);
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
return estimatePointCount(getIntersectState(visitor), minPackedValue, maxPackedValue);
}
/** Fast path: this is called when the query box fully encompasses all cells under this node. */ /** Fast path: this is called when the query box fully encompasses all cells under this node. */
private void addAll(IntersectState state) throws IOException { private void addAll(IntersectState state) throws IOException {
//System.out.println("R: addAll nodeID=" + nodeID); //System.out.println("R: addAll nodeID=" + nodeID);
@ -696,6 +702,59 @@ public final class BKDReader extends PointValues implements Accountable {
} }
} }
private long estimatePointCount(IntersectState state, byte[] cellMinPacked, byte[] cellMaxPacked) {
/*
System.out.println("\nR: intersect nodeID=" + state.index.getNodeID());
for(int dim=0;dim<numDims;dim++) {
System.out.println(" dim=" + dim + "\n cellMin=" + new BytesRef(cellMinPacked, dim*bytesPerDim, bytesPerDim) + "\n cellMax=" + new BytesRef(cellMaxPacked, dim*bytesPerDim, bytesPerDim));
}
*/
Relation r = state.visitor.compare(cellMinPacked, cellMaxPacked);
if (r == Relation.CELL_OUTSIDE_QUERY) {
// This cell is fully outside of the query shape: stop recursing
return 0L;
} else if (state.index.isLeafNode()) {
// Assume all points match and there are no dups
return maxPointsInLeafNode;
} else {
// Non-leaf node: recurse on the split left and right nodes
int splitDim = state.index.getSplitDim();
assert splitDim >= 0: "splitDim=" + splitDim;
assert splitDim < numDims;
byte[] splitPackedValue = state.index.getSplitPackedValue();
BytesRef splitDimValue = state.index.getSplitDimValue();
assert splitDimValue.length == bytesPerDim;
//System.out.println(" splitDimValue=" + splitDimValue + " splitDim=" + splitDim);
// make sure cellMin <= splitValue <= cellMax:
assert StringHelper.compare(bytesPerDim, cellMinPacked, splitDim*bytesPerDim, splitDimValue.bytes, splitDimValue.offset) <= 0: "bytesPerDim=" + bytesPerDim + " splitDim=" + splitDim + " numDims=" + numDims;
assert StringHelper.compare(bytesPerDim, cellMaxPacked, splitDim*bytesPerDim, splitDimValue.bytes, splitDimValue.offset) >= 0: "bytesPerDim=" + bytesPerDim + " splitDim=" + splitDim + " numDims=" + numDims;
// Recurse on left sub-tree:
System.arraycopy(cellMaxPacked, 0, splitPackedValue, 0, packedBytesLength);
System.arraycopy(splitDimValue.bytes, splitDimValue.offset, splitPackedValue, splitDim*bytesPerDim, bytesPerDim);
state.index.pushLeft();
final long leftCost = estimatePointCount(state, cellMinPacked, splitPackedValue);
state.index.pop();
// Restore the split dim value since it may have been overwritten while recursing:
System.arraycopy(splitPackedValue, splitDim*bytesPerDim, splitDimValue.bytes, splitDimValue.offset, bytesPerDim);
// Recurse on right sub-tree:
System.arraycopy(cellMinPacked, 0, splitPackedValue, 0, packedBytesLength);
System.arraycopy(splitDimValue.bytes, splitDimValue.offset, splitPackedValue, splitDim*bytesPerDim, bytesPerDim);
state.index.pushRight();
final long rightCost = estimatePointCount(state, splitPackedValue, cellMaxPacked);
state.index.pop();
return leftCost + rightCost;
}
}
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
if (packedIndex != null) { if (packedIndex != null) {

View File

@ -0,0 +1,332 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.Map;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import com.carrotsearch.randomizedtesting.generators.RandomPicks;
public class TestBoolean2ScorerSupplier extends LuceneTestCase {
private static class FakeScorer extends Scorer {
private final DocIdSetIterator it;
FakeScorer(long cost) {
super(null);
this.it = DocIdSetIterator.all(Math.toIntExact(cost));
}
@Override
public int docID() {
return it.docID();
}
@Override
public float score() throws IOException {
return 1;
}
@Override
public int freq() throws IOException {
return 1;
}
@Override
public DocIdSetIterator iterator() {
return it;
}
@Override
public String toString() {
return "FakeScorer(cost=" + it.cost() + ")";
}
}
private static class FakeScorerSupplier extends ScorerSupplier {
private final long cost;
private final Boolean randomAccess;
FakeScorerSupplier(long cost) {
this.cost = cost;
this.randomAccess = null;
}
FakeScorerSupplier(long cost, boolean randomAccess) {
this.cost = cost;
this.randomAccess = randomAccess;
}
@Override
public Scorer get(boolean randomAccess) throws IOException {
if (this.randomAccess != null) {
assertEquals(this.toString(), this.randomAccess, randomAccess);
}
return new FakeScorer(cost);
}
@Override
public long cost() {
return cost;
}
@Override
public String toString() {
return "FakeLazyScorer(cost=" + cost + ",randomAccess=" + randomAccess + ")";
}
}
public void testConjunctionCost() {
Map<Occur, Collection<ScorerSupplier>> subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
subs.get(RandomPicks.randomFrom(random(), Arrays.asList(Occur.FILTER, Occur.MUST))).add(new FakeScorerSupplier(42));
assertEquals(42, new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 0).cost());
subs.get(RandomPicks.randomFrom(random(), Arrays.asList(Occur.FILTER, Occur.MUST))).add(new FakeScorerSupplier(12));
assertEquals(12, new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 0).cost());
subs.get(RandomPicks.randomFrom(random(), Arrays.asList(Occur.FILTER, Occur.MUST))).add(new FakeScorerSupplier(20));
assertEquals(12, new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 0).cost());
}
public void testDisjunctionCost() throws IOException {
Map<Occur, Collection<ScorerSupplier>> subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(42));
ScorerSupplier s = new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 0);
assertEquals(42, s.cost());
assertEquals(42, s.get(random().nextBoolean()).iterator().cost());
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(12));
s = new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 0);
assertEquals(42 + 12, s.cost());
assertEquals(42 + 12, s.get(random().nextBoolean()).iterator().cost());
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(20));
s = new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 0);
assertEquals(42 + 12 + 20, s.cost());
assertEquals(42 + 12 + 20, s.get(random().nextBoolean()).iterator().cost());
}
public void testDisjunctionWithMinShouldMatchCost() throws IOException {
Map<Occur, Collection<ScorerSupplier>> subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(42));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(12));
ScorerSupplier s = new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 1);
assertEquals(42 + 12, s.cost());
assertEquals(42 + 12, s.get(random().nextBoolean()).iterator().cost());
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(20));
s = new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 1);
assertEquals(42 + 12 + 20, s.cost());
assertEquals(42 + 12 + 20, s.get(random().nextBoolean()).iterator().cost());
s = new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 2);
assertEquals(12 + 20, s.cost());
assertEquals(12 + 20, s.get(random().nextBoolean()).iterator().cost());
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(30));
s = new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 1);
assertEquals(42 + 12 + 20 + 30, s.cost());
assertEquals(42 + 12 + 20 + 30, s.get(random().nextBoolean()).iterator().cost());
s = new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 2);
assertEquals(12 + 20 + 30, s.cost());
assertEquals(12 + 20 + 30, s.get(random().nextBoolean()).iterator().cost());
s = new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 3);
assertEquals(12 + 20, s.cost());
assertEquals(12 + 20, s.get(random().nextBoolean()).iterator().cost());
}
public void testDuelCost() throws Exception {
final int iters = atLeast(1000);
for (int iter = 0; iter < iters; ++iter) {
Map<Occur, Collection<ScorerSupplier>> subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
int numClauses = TestUtil.nextInt(random(), 1, 10);
int numShoulds = 0;
int numRequired = 0;
for (int j = 0; j < numClauses; ++j) {
Occur occur = RandomPicks.randomFrom(random(), Occur.values());
subs.get(occur).add(new FakeScorerSupplier(random().nextInt(100)));
if (occur == Occur.SHOULD) {
++numShoulds;
} else if (occur == Occur.FILTER || occur == Occur.MUST) {
numRequired++;
}
}
boolean needsScores = random().nextBoolean();
if (needsScores == false && numRequired > 0) {
numClauses -= numShoulds;
numShoulds = 0;
subs.get(Occur.SHOULD).clear();
}
if (numShoulds + numRequired == 0) {
// only negative clauses, invalid
continue;
}
int minShouldMatch = numShoulds == 0 ? 0 : TestUtil.nextInt(random(), 0, numShoulds - 1);
Boolean2ScorerSupplier supplier = new Boolean2ScorerSupplier(null,
subs, needsScores, minShouldMatch);
long cost1 = supplier.cost();
long cost2 = supplier.get(false).iterator().cost();
assertEquals("clauses=" + subs + ", minShouldMatch=" + minShouldMatch, cost1, cost2);
}
}
// test the tester...
public void testFakeScorerSupplier() {
FakeScorerSupplier randomAccessSupplier = new FakeScorerSupplier(random().nextInt(100), true);
expectThrows(AssertionError.class, () -> randomAccessSupplier.get(false));
FakeScorerSupplier sequentialSupplier = new FakeScorerSupplier(random().nextInt(100), false);
expectThrows(AssertionError.class, () -> sequentialSupplier.get(true));
}
public void testConjunctionRandomAccess() throws IOException {
Map<Occur, Collection<ScorerSupplier>> subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
// If sequential access is required, only the least costly clause does not use random-access
subs.get(RandomPicks.randomFrom(random(), Arrays.asList(Occur.FILTER, Occur.MUST))).add(new FakeScorerSupplier(42, true));
subs.get(RandomPicks.randomFrom(random(), Arrays.asList(Occur.FILTER, Occur.MUST))).add(new FakeScorerSupplier(12, false));
new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 0).get(false); // triggers assertions as a side-effect
subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
// If random access is required, then we propagate to sub clauses
subs.get(RandomPicks.randomFrom(random(), Arrays.asList(Occur.FILTER, Occur.MUST))).add(new FakeScorerSupplier(42, true));
subs.get(RandomPicks.randomFrom(random(), Arrays.asList(Occur.FILTER, Occur.MUST))).add(new FakeScorerSupplier(12, true));
new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 0).get(true); // triggers assertions as a side-effect
}
public void testDisjunctionRandomAccess() throws IOException {
// disjunctions propagate
for (boolean randomAccess : new boolean[] {false, true}) {
Map<Occur, Collection<ScorerSupplier>> subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(42, randomAccess));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(12, randomAccess));
new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 0).get(randomAccess); // triggers assertions as a side-effect
}
}
public void testDisjunctionWithMinShouldMatchRandomAccess() throws IOException {
Map<Occur, Collection<ScorerSupplier>> subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
// Only the most costly clause uses random-access in that case:
// most of time, we will find agreement between the 2 least costly
// clauses and only then check whether the 3rd one matches too
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(42, true));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(12, false));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(30, false));
new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 2).get(false); // triggers assertions as a side-effect
subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
// When random-access is true, just propagate
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(42, true));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(12, true));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(30, true));
new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 2).get(true); // triggers assertions as a side-effect
subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(42, true));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(12, false));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(30, false));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(20, false));
new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 2).get(false); // triggers assertions as a side-effect
subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(42, true));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(12, false));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(30, true));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(20, false));
new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 3).get(false); // triggers assertions as a side-effect
}
public void testProhibitedRandomAccess() throws IOException {
for (boolean randomAccess : new boolean[] {false, true}) {
Map<Occur, Collection<ScorerSupplier>> subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
// The MUST_NOT clause always uses random-access
subs.get(Occur.MUST).add(new FakeScorerSupplier(42, randomAccess));
subs.get(Occur.MUST_NOT).add(new FakeScorerSupplier(TestUtil.nextInt(random(), 1, 100), true));
new Boolean2ScorerSupplier(null, subs, random().nextBoolean(), 0).get(randomAccess); // triggers assertions as a side-effect
}
}
public void testMixedRandomAccess() throws IOException {
for (boolean randomAccess : new boolean[] {false, true}) {
Map<Occur, Collection<ScorerSupplier>> subs = new EnumMap<>(Occur.class);
for (Occur occur : Occur.values()) {
subs.put(occur, new ArrayList<>());
}
// The SHOULD clause always uses random-access if there is a MUST clause
subs.get(Occur.MUST).add(new FakeScorerSupplier(42, randomAccess));
subs.get(Occur.SHOULD).add(new FakeScorerSupplier(TestUtil.nextInt(random(), 1, 100), true));
new Boolean2ScorerSupplier(null, subs, true, 0).get(randomAccess); // triggers assertions as a side-effect
}
}
}

View File

@ -206,8 +206,8 @@ public class TestBooleanQueryVisitSubscorers extends LuceneTestCase {
" MUST ConstantScoreScorer\n" + " MUST ConstantScoreScorer\n" +
" MUST MinShouldMatchSumScorer\n" + " MUST MinShouldMatchSumScorer\n" +
" SHOULD TermScorer body:nutch\n" + " SHOULD TermScorer body:nutch\n" +
" SHOULD TermScorer body:web\n" + " SHOULD TermScorer body:crawler\n" +
" SHOULD TermScorer body:crawler", " SHOULD TermScorer body:web",
summary); summary);
} }
} }

View File

@ -18,6 +18,7 @@ package org.apache.lucene.search;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.util.Arrays;
import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.LuceneTestCase;
import org.junit.Test; import org.junit.Test;
@ -35,7 +36,7 @@ public class TestFilterWeight extends LuceneTestCase {
final int modifiers = superClassMethod.getModifiers(); final int modifiers = superClassMethod.getModifiers();
if (Modifier.isFinal(modifiers)) continue; if (Modifier.isFinal(modifiers)) continue;
if (Modifier.isStatic(modifiers)) continue; if (Modifier.isStatic(modifiers)) continue;
if (superClassMethod.getName().equals("bulkScorer")) { if (Arrays.asList("bulkScorer", "scorerSupplier").contains(superClassMethod.getName())) {
try { try {
final Method subClassMethod = subClass.getDeclaredMethod( final Method subClassMethod = subClass.getDeclaredMethod(
superClassMethod.getName(), superClassMethod.getName(),

View File

@ -311,6 +311,11 @@ public class TestDocIdSetBuilder extends LuceneTestCase {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
throw new UnsupportedOperationException();
}
@Override @Override
public byte[] getMinPackedValue() throws IOException { public byte[] getMinPackedValue() throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -220,6 +220,11 @@ public class TestMutablePointsReaderUtils extends LuceneTestCase {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
throw new UnsupportedOperationException();
}
@Override @Override
public byte[] getMinPackedValue() throws IOException { public byte[] getMinPackedValue() throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -1521,6 +1521,11 @@ public class MemoryIndex {
} }
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
return 1L;
}
@Override @Override
public byte[] getMinPackedValue() throws IOException { public byte[] getMinPackedValue() throws IOException {
BytesRef[] values = info.pointValues; BytesRef[] values = info.pointValues;

View File

@ -23,8 +23,10 @@ import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
/** /**
@ -33,10 +35,11 @@ import org.apache.lucene.util.BytesRef;
* dense case where most documents match this query, it <b>might</b> be as * dense case where most documents match this query, it <b>might</b> be as
* fast or faster than a regular {@link PointRangeQuery}. * fast or faster than a regular {@link PointRangeQuery}.
* *
* <p> * <b>NOTE:</b> This query is typically best used within a
* <b>NOTE</b>: be very careful using this query: it is * {@link IndexOrDocValuesQuery} alongside a query that uses an indexed
* typically much slower than using {@code TermsQuery}, * structure such as {@link PointValues points} or {@link Terms terms},
* but in certain specialized cases may be faster. * which allows to run the query on doc values when that would be more
* efficient, and using an index otherwise.
* *
* @lucene.experimental * @lucene.experimental
*/ */

View File

@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
/**
* A query that uses either an index (points or terms) or doc values in order
* to run a range query, depending which one is more efficient.
*/
public final class IndexOrDocValuesQuery extends Query {
private final Query indexQuery, dvQuery;
/**
* Constructor that takes both a query that executes on an index structure
* like the inverted index or the points tree, and another query that
* executes on doc values. Both queries must match the same documents and
* attribute constant scores.
*/
public IndexOrDocValuesQuery(Query indexQuery, Query dvQuery) {
this.indexQuery = indexQuery;
this.dvQuery = dvQuery;
}
@Override
public String toString(String field) {
return indexQuery.toString(field);
}
@Override
public boolean equals(Object obj) {
if (sameClassAs(obj) == false) {
return false;
}
IndexOrDocValuesQuery that = (IndexOrDocValuesQuery) obj;
return indexQuery.equals(that.indexQuery) && dvQuery.equals(that.dvQuery);
}
@Override
public int hashCode() {
int h = classHash();
h = 31 * h + indexQuery.hashCode();
h = 31 * h + dvQuery.hashCode();
return h;
}
@Override
public Query rewrite(IndexReader reader) throws IOException {
Query indexRewrite = indexQuery.rewrite(reader);
Query dvRewrite = dvQuery.rewrite(reader);
if (indexQuery != indexRewrite || dvQuery != dvRewrite) {
return new IndexOrDocValuesQuery(indexRewrite, dvRewrite);
}
return this;
}
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
final Weight indexWeight = indexQuery.createWeight(searcher, needsScores, boost);
final Weight dvWeight = dvQuery.createWeight(searcher, needsScores, boost);
return new ConstantScoreWeight(this, boost) {
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
return indexWeight.bulkScorer(context);
}
@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
final ScorerSupplier indexScorerSupplier = indexWeight.scorerSupplier(context);
final ScorerSupplier dvScorerSupplier = dvWeight.scorerSupplier(context);
if (indexScorerSupplier == null || dvScorerSupplier == null) {
return null;
}
return new ScorerSupplier() {
@Override
public Scorer get(boolean randomAccess) throws IOException {
return (randomAccess ? dvScorerSupplier : indexScorerSupplier).get(randomAccess);
}
@Override
public long cost() {
return Math.min(indexScorerSupplier.cost(), dvScorerSupplier.cost());
}
};
}
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
ScorerSupplier scorerSupplier = scorerSupplier(context);
if (scorerSupplier == null) {
return null;
}
return scorerSupplier.get(false);
}
};
}
}

View File

@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
public class TestIndexOrDocValuesQuery extends LuceneTestCase {
public void testUseIndexForSelectiveQueries() throws IOException {
Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig()
// relies on costs and PointValues.estimateCost so we need the default codec
.setCodec(TestUtil.getDefaultCodec()));
for (int i = 0; i < 2000; ++i) {
Document doc = new Document();
if (i == 42) {
doc.add(new StringField("f1", "bar", Store.NO));
doc.add(new LongPoint("f2", 42L));
doc.add(new NumericDocValuesField("f2", 42L));
} else if (i == 100) {
doc.add(new StringField("f1", "foo", Store.NO));
doc.add(new LongPoint("f2", 2L));
doc.add(new NumericDocValuesField("f2", 2L));
} else {
doc.add(new StringField("f1", "bar", Store.NO));
doc.add(new LongPoint("f2", 2L));
doc.add(new NumericDocValuesField("f2", 2L));
}
w.addDocument(doc);
}
w.forceMerge(1);
IndexReader reader = DirectoryReader.open(w);
IndexSearcher searcher = newSearcher(reader);
searcher.setQueryCache(null);
// The term query is more selective, so the IndexOrDocValuesQuery should use doc values
final Query q1 = new BooleanQuery.Builder()
.add(new TermQuery(new Term("f1", "foo")), Occur.MUST)
.add(new IndexOrDocValuesQuery(LongPoint.newExactQuery("f2", 2), new DocValuesNumbersQuery("f2", 2L)), Occur.MUST)
.build();
final Weight w1 = searcher.createNormalizedWeight(q1, random().nextBoolean());
final Scorer s1 = w1.scorer(reader.leaves().get(0));
assertNotNull(s1.twoPhaseIterator()); // means we use doc values
// The term query is less selective, so the IndexOrDocValuesQuery should use points
final Query q2 = new BooleanQuery.Builder()
.add(new TermQuery(new Term("f1", "bar")), Occur.MUST)
.add(new IndexOrDocValuesQuery(LongPoint.newExactQuery("f2", 42), new DocValuesNumbersQuery("f2", 42L)), Occur.MUST)
.build();
final Weight w2 = searcher.createNormalizedWeight(q2, random().nextBoolean());
final Scorer s2 = w2.scorer(reader.leaves().get(0));
assertNull(s2.twoPhaseIterator()); // means we use points
reader.close();
w.close();
dir.close();
}
}

View File

@ -133,6 +133,11 @@ class CrankyPointsFormat extends PointsFormat {
} }
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
return delegate.estimatePointCount(visitor);
}
@Override @Override
public byte[] getMinPackedValue() throws IOException { public byte[] getMinPackedValue() throws IOException {
if (random.nextInt(100) == 0) { if (random.nextInt(100) == 0) {

View File

@ -883,6 +883,13 @@ public class AssertingLeafReader extends FilterLeafReader {
in.intersect(new AssertingIntersectVisitor(in.getNumDimensions(), in.getBytesPerDimension(), visitor)); in.intersect(new AssertingIntersectVisitor(in.getNumDimensions(), in.getBytesPerDimension(), visitor));
} }
@Override
public long estimatePointCount(IntersectVisitor visitor) {
long cost = in.estimatePointCount(visitor);
assert cost >= 0;
return cost;
}
@Override @Override
public byte[] getMinPackedValue() throws IOException { public byte[] getMinPackedValue() throws IOException {
return Objects.requireNonNull(in.getMinPackedValue()); return Objects.requireNonNull(in.getMinPackedValue());

View File

@ -33,9 +33,45 @@ class AssertingWeight extends FilterWeight {
@Override @Override
public Scorer scorer(LeafReaderContext context) throws IOException { public Scorer scorer(LeafReaderContext context) throws IOException {
final Scorer inScorer = in.scorer(context); if (random.nextBoolean()) {
assert inScorer == null || inScorer.docID() == -1; final Scorer inScorer = in.scorer(context);
return AssertingScorer.wrap(new Random(random.nextLong()), inScorer, needsScores); assert inScorer == null || inScorer.docID() == -1;
return AssertingScorer.wrap(new Random(random.nextLong()), inScorer, needsScores);
} else {
final ScorerSupplier scorerSupplier = scorerSupplier(context);
if (scorerSupplier == null) {
return null;
}
if (random.nextBoolean()) {
// Evil: make sure computing the cost has no side effects
scorerSupplier.cost();
}
return scorerSupplier.get(false);
}
}
@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
final ScorerSupplier inScorerSupplier = in.scorerSupplier(context);
if (inScorerSupplier == null) {
return null;
}
return new ScorerSupplier() {
private boolean getCalled = false;
@Override
public Scorer get(boolean randomAccess) throws IOException {
assert getCalled == false;
getCalled = true;
return AssertingScorer.wrap(new Random(random.nextLong()), inScorerSupplier.get(randomAccess), needsScores);
}
@Override
public long cost() {
final long cost = inScorerSupplier.cost();
assert cost >= 0;
return cost;
}
};
} }
@Override @Override