LUCENE-6310: FilterScorer always supports two-phase iteration; add two-phase support to some long-tail scorers

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1663095 13f79535-47bb-0310-9956-ffa450edef68
Robert Muir 2015-03-01 13:24:59 +00:00
parent cd66868e93
commit f64e74c0c7
10 changed files with 199 additions and 144 deletions


@@ -38,11 +38,6 @@ class BooleanTopLevelScorers {
this.boost = boost;
}
@Override
public TwoPhaseIterator asTwoPhaseIterator() {
return in.asTwoPhaseIterator();
}
@Override
public float score() throws IOException {
return in.score() * boost;


@@ -182,11 +182,6 @@ public class ConstantScoreQuery extends Query {
this.score = score;
}
@Override
public TwoPhaseIterator asTwoPhaseIterator() {
return in.asTwoPhaseIterator();
}
@Override
public int freq() throws IOException {
return 1;


@@ -62,17 +62,17 @@ public abstract class FilterScorer extends Scorer {
}
@Override
public int docID() {
public final int docID() {
return in.docID();
}
@Override
public int nextDoc() throws IOException {
public final int nextDoc() throws IOException {
return in.nextDoc();
}
@Override
public int advance(int target) throws IOException {
public final int advance(int target) throws IOException {
return in.advance(target);
}
@@ -80,5 +80,9 @@ public abstract class FilterScorer extends Scorer {
public long cost() {
return in.cost();
}
@Override
public final TwoPhaseIterator asTwoPhaseIterator() {
return in.asTwoPhaseIterator();
}
}
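Not part of the patch: a minimal sketch of what the final delegation above buys a wrapper. Because docID(), nextDoc(), advance() and asTwoPhaseIterator() now all forward to the wrapped scorer, a subclass that only customizes score() keeps the inner scorer's two-phase approximation instead of silently dropping it. The BoostingScorer name is hypothetical.

import java.io.IOException;

import org.apache.lucene.search.FilterScorer;
import org.apache.lucene.search.Scorer;

// Hypothetical wrapper: only score() is overridden. Iteration and
// asTwoPhaseIterator() are final in FilterScorer and delegate to `in`,
// so wrapping e.g. a sloppy-phrase scorer preserves its cheap approximation.
class BoostingScorer extends FilterScorer {
  private final float boost;

  BoostingScorer(Scorer in, float boost) {
    super(in);
    this.boost = boost;
  }

  @Override
  public float score() throws IOException {
    return in.score() * boost;
  }
}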


@@ -18,6 +18,7 @@ package org.apache.lucene.search;
*/
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Set;
@@ -154,13 +155,12 @@ public class FilteredQuery extends Query {
* than document scoring or if the filter has a linear running time to compute
* the next matching doc like exact geo distances.
*/
private static final class QueryFirstScorer extends FilterScorer {
private static final class QueryFirstScorer extends Scorer {
private final Scorer scorer;
private int scorerDoc = -1;
private final Bits filterBits;
protected QueryFirstScorer(Weight weight, Bits filterBits, Scorer other) {
super(other, weight);
super(weight);
this.scorer = other;
this.filterBits = filterBits;
}
@@ -170,8 +170,8 @@ public class FilteredQuery extends Query {
int doc;
for(;;) {
doc = scorer.nextDoc();
if (doc == Scorer.NO_MORE_DOCS || filterBits.get(doc)) {
return scorerDoc = doc;
if (doc == DocIdSetIterator.NO_MORE_DOCS || filterBits.get(doc)) {
return doc;
}
}
}
@@ -179,15 +179,31 @@ public class FilteredQuery extends Query {
@Override
public int advance(int target) throws IOException {
int doc = scorer.advance(target);
if (doc != Scorer.NO_MORE_DOCS && !filterBits.get(doc)) {
return scorerDoc = nextDoc();
if (doc != DocIdSetIterator.NO_MORE_DOCS && !filterBits.get(doc)) {
return nextDoc();
} else {
return scorerDoc = doc;
return doc;
}
}
@Override
public int docID() {
return scorerDoc;
return scorer.docID();
}
@Override
public float score() throws IOException {
return scorer.score();
}
@Override
public int freq() throws IOException {
return scorer.freq();
}
@Override
public long cost() {
return scorer.cost();
}
@Override
@@ -195,6 +211,37 @@ public class FilteredQuery extends Query {
return Collections.singleton(new ChildScorer(scorer, "FILTERED"));
}
@Override
public TwoPhaseIterator asTwoPhaseIterator() {
TwoPhaseIterator inner = scorer.asTwoPhaseIterator();
if (inner != null) {
// we are like a simplified conjunction here, handle the nested case:
return new TwoPhaseIterator() {
@Override
public DocIdSetIterator approximation() {
return inner.approximation();
}
@Override
public boolean matches() throws IOException {
// check the approximation matches first, then check bits last.
return inner.matches() && filterBits.get(scorer.docID());
}
};
} else {
// scorer doesn't have an approximation, just use it so that the bits are applied last.
return new TwoPhaseIterator() {
@Override
public DocIdSetIterator approximation() {
return scorer;
}
@Override
public boolean matches() throws IOException {
return filterBits.get(scorer.docID());
}
};
}
}
}
private static class QueryFirstBulkScorer extends BulkScorer {
@@ -242,53 +289,29 @@ public class FilteredQuery extends Query {
* jumping past the target document. When both land on the same document, it's
* collected.
*/
private static final class LeapFrogScorer extends FilterScorer {
private final DocIdSetIterator secondary;
private final DocIdSetIterator primary;
private static final class LeapFrogScorer extends Scorer {
private final ConjunctionDISI conjunction;
private final Scorer scorer;
private int primaryDoc = -1;
private int secondaryDoc = -1;
protected LeapFrogScorer(Weight weight, DocIdSetIterator primary, DocIdSetIterator secondary, Scorer scorer) {
super(scorer, weight);
this.primary = primary;
this.secondary = secondary;
super(weight);
conjunction = ConjunctionDISI.intersect(Arrays.asList(primary, secondary));
this.scorer = scorer;
}
private final int advanceToNextCommonDoc() throws IOException {
for (;;) {
if (secondaryDoc < primaryDoc) {
secondaryDoc = secondary.advance(primaryDoc);
} else if (secondaryDoc == primaryDoc) {
return primaryDoc;
} else {
primaryDoc = primary.advance(secondaryDoc);
}
}
}
@Override
public final int nextDoc() throws IOException {
primaryDoc = primaryNext();
return advanceToNextCommonDoc();
}
protected int primaryNext() throws IOException {
return primary.nextDoc();
public int nextDoc() throws IOException {
return conjunction.nextDoc();
}
@Override
public final int advance(int target) throws IOException {
if (target > primaryDoc) {
primaryDoc = primary.advance(target);
}
return advanceToNextCommonDoc();
return conjunction.advance(target);
}
@Override
public final int docID() {
return secondaryDoc;
return conjunction.docID();
}
@Override
@@ -298,7 +321,22 @@ public class FilteredQuery extends Query {
@Override
public long cost() {
return Math.min(primary.cost(), secondary.cost());
return conjunction.cost();
}
@Override
public float score() throws IOException {
return scorer.score();
}
@Override
public int freq() throws IOException {
return scorer.freq();
}
@Override
public TwoPhaseIterator asTwoPhaseIterator() {
return conjunction.asTwoPhaseIterator();
}
}
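Not from the patch: a rough sketch of the caller side of the two-phase contract that QueryFirstScorer now participates in, assuming the approximation()/matches() API shown above. The approximation is iterated cheaply and matches() (where the filter Bits are finally consulted) is only paid for candidate docs. TwoPhaseDrainer and countMatches are illustrative names.

import java.io.IOException;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;

// Hypothetical helper: walk the cheap approximation and confirm each candidate
// with matches(); fall back to plain iteration when no approximation is exposed.
final class TwoPhaseDrainer {
  private TwoPhaseDrainer() {}

  static long countMatches(Scorer scorer) throws IOException {
    long count = 0;
    TwoPhaseIterator twoPhase = scorer.asTwoPhaseIterator();
    if (twoPhase == null) {
      // no approximation: the scorer's own iterator is already exact
      for (int doc = scorer.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = scorer.nextDoc()) {
        count++;
      }
      return count;
    }
    DocIdSetIterator approximation = twoPhase.approximation();
    for (int doc = approximation.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = approximation.nextDoc()) {
      if (twoPhase.matches()) { // e.g. filterBits.get(doc) in QueryFirstScorer
        count++;
      }
    }
    return count;
  }
}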


@@ -27,7 +27,7 @@ import java.util.Collections;
* This <code>Scorer</code> implements {@link Scorer#advance(int)},
* and it uses the advance() on the given scorers.
*/
class ReqExclScorer extends FilterScorer {
class ReqExclScorer extends Scorer {
private final Scorer reqScorer;
// approximations of the scorers, or the scorers themselves if they don't support approximations
@@ -42,7 +42,7 @@ class ReqExclScorer extends FilterScorer {
* @param exclScorer indicates exclusion.
*/
public ReqExclScorer(Scorer reqScorer, Scorer exclScorer) {
super(reqScorer);
super(reqScorer.weight);
this.reqScorer = reqScorer;
reqTwoPhaseIterator = reqScorer.asTwoPhaseIterator();
if (reqTwoPhaseIterator == null) {
@@ -106,10 +106,16 @@ class ReqExclScorer extends FilterScorer {
return reqScorer.docID();
}
/** Returns the score of the current document matching the query.
* Initially invalid, until {@link #nextDoc()} is called the first time.
* @return The score of the required scorer.
*/
@Override
public int freq() throws IOException {
return reqScorer.freq();
}
@Override
public long cost() {
return reqScorer.cost();
}
@Override
public float score() throws IOException {
return reqScorer.score(); // reqScorer may be null when next() or skipTo() have already returned false
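For orientation (not the actual ReqExclScorer code): a simplified sketch of the required-but-not-excluded confirmation that sits on top of the required clause's approximation. ReqNotExclCheck is a hypothetical name, and the exclusion side is shown as a plain iterator rather than its own two-phase view.

import java.io.IOException;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TwoPhaseIterator;

// Hypothetical check: a candidate from the required approximation matches only
// if the excluded iterator is not on the same doc and the required clause
// (if it has its own two-phase view) really matches.
final class ReqNotExclCheck extends TwoPhaseIterator {
  private final DocIdSetIterator reqApproximation;
  private final TwoPhaseIterator reqTwoPhase; // null if the required scorer is exact
  private final DocIdSetIterator excl;

  ReqNotExclCheck(DocIdSetIterator reqApproximation, TwoPhaseIterator reqTwoPhase, DocIdSetIterator excl) {
    this.reqApproximation = reqApproximation;
    this.reqTwoPhase = reqTwoPhase;
    this.excl = excl;
  }

  @Override
  public DocIdSetIterator approximation() {
    return reqApproximation;
  }

  @Override
  public boolean matches() throws IOException {
    int doc = reqApproximation.docID();
    if (excl.docID() < doc) {
      excl.advance(doc);
    }
    if (excl.docID() == doc) {
      return false; // excluded
    }
    return reqTwoPhase == null || reqTwoPhase.matches();
  }
}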


@@ -17,7 +17,6 @@ package org.apache.lucene.search.join;
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@@ -28,20 +27,15 @@ import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.Filter;
import org.apache.lucene.search.FilteredQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryWrapperFilter;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.LuceneTestCase;
public class TestBlockJoinValidation extends LuceneTestCase {
@@ -125,41 +119,6 @@ public class TestBlockJoinValidation extends LuceneTestCase {
}
}
// a filter for which other queries don't have special rewrite rules
private static class FilterWrapper extends Filter {
private final Filter in;
FilterWrapper(Filter in) {
this.in = in;
}
@Override
public DocIdSet getDocIdSet(LeafReaderContext context, Bits acceptDocs) throws IOException {
return in.getDocIdSet(context, acceptDocs);
}
@Override
public String toString(String field) {
return in.toString(field);
}
}
public void testValidationForToChildBjqWithChildFilterQuery() throws Exception {
Query parentQueryWithRandomChild = createParentQuery();
ToChildBlockJoinQuery blockJoinQuery = new ToChildBlockJoinQuery(parentQueryWithRandomChild, parentsFilter);
Filter childFilter = new FilterWrapper(new QueryWrapperFilter(new TermQuery(new Term("common_field", "1"))));
try {
indexSearcher.search(new FilteredQuery(blockJoinQuery, childFilter), 1);
fail("didn't get expected exception");
} catch (IllegalStateException expected) {
assertTrue(expected.getMessage() != null && expected.getMessage().contains(ToChildBlockJoinQuery.ILLEGAL_ADVANCE_ON_PARENT));
}
}
public void testAdvanceValidationForToChildBjq() throws Exception {
int randomChildNumber = getRandomChildNumber(0);
// we need to make advance method meet wrong document, so random child number


@@ -30,9 +30,11 @@ import org.apache.lucene.queries.function.FunctionQuery;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.ComplexExplanation;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilterScorer;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.ToStringUtils;
@@ -274,76 +276,46 @@ public class CustomScoreQuery extends Query {
/**
* A scorer that applies a (callback) function on scores of the subQuery.
*/
private class CustomScorer extends Scorer {
private class CustomScorer extends FilterScorer {
private final float qWeight;
private final Scorer subQueryScorer;
private final Scorer[] valSrcScorers;
private final CustomScoreProvider provider;
private final float[] vScores; // reused in score() to avoid allocating this array for each doc
// TODO : can we use FilterScorer here instead?
private int valSrcDocID = -1; // we lazily advance subscorers.
// constructor
private CustomScorer(CustomScoreProvider provider, CustomWeight w, float qWeight,
Scorer subQueryScorer, Scorer[] valSrcScorers) {
super(w);
super(subQueryScorer, w);
this.qWeight = qWeight;
this.subQueryScorer = subQueryScorer;
this.valSrcScorers = valSrcScorers;
this.vScores = new float[valSrcScorers.length];
this.provider = provider;
}
@Override
public int nextDoc() throws IOException {
int doc = subQueryScorer.nextDoc();
if (doc != NO_MORE_DOCS) {
public float score() throws IOException {
// lazily advance to current doc.
int doc = docID();
if (doc > valSrcDocID) {
for (Scorer valSrcScorer : valSrcScorers) {
valSrcScorer.advance(doc);
}
valSrcDocID = doc;
}
return doc;
}
@Override
public int docID() {
return subQueryScorer.docID();
}
/*(non-Javadoc) @see org.apache.lucene.search.Scorer#score() */
@Override
public float score() throws IOException {
// TODO: this thing technically takes any Query, so what about when subs don't match?
for (int i = 0; i < valSrcScorers.length; i++) {
vScores[i] = valSrcScorers[i].score();
}
return qWeight * provider.customScore(subQueryScorer.docID(), subQueryScorer.score(), vScores);
}
@Override
public int freq() throws IOException {
return subQueryScorer.freq();
}
@Override
public Collection<ChildScorer> getChildren() {
return Collections.singleton(new ChildScorer(subQueryScorer, "CUSTOM"));
}
@Override
public int advance(int target) throws IOException {
int doc = subQueryScorer.advance(target);
if (doc != NO_MORE_DOCS) {
for (Scorer valSrcScorer : valSrcScorers) {
valSrcScorer.advance(doc);
}
}
return doc;
}
@Override
public long cost() {
return subQueryScorer.cost();
}
}
@Override
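Outside the patch, the idea behind the lazy advance in CustomScorer.score() above: when an enclosing two-phase check rejects a candidate, score() never runs, so subordinate scorers should only be moved once a score is actually requested. A minimal sketch of that guard, with hypothetical names:

import java.io.IOException;

import org.apache.lucene.search.DocIdSetIterator;

// Hypothetical guard: move the companion iterator only on demand and only forward.
final class LazyCompanion {
  private final DocIdSetIterator companion;
  private int companionDoc = -1;

  LazyCompanion(DocIdSetIterator companion) {
    this.companion = companion;
  }

  void ensurePositioned(int doc) throws IOException {
    if (doc > companionDoc) {
      companion.advance(doc); // assumes the companion matches every doc, like a FunctionQuery scorer
      companionDoc = doc;
    }
  }
}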


@@ -30,6 +30,7 @@ import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.RunAutomaton;
// TODO: add two-phase and needsScores support. maybe use conjunctionDISI internally?
class TermAutomatonScorer extends Scorer {
private final EnumAndScorer[] subs;
private final EnumAndScorer[] subsOnDoc;


@@ -27,6 +27,8 @@ import java.util.WeakHashMap;
/** Wraps a Scorer with additional checks */
public class AssertingScorer extends Scorer {
// TODO: add asserts for two-phase intersection
static enum IteratorState { START, ITERATING, FINISHED };


@@ -17,6 +17,7 @@ package org.apache.lucene.search;
* limitations under the License.
*/
import java.io.IOException;
import java.util.BitSet;
import java.util.Random;
@@ -28,10 +29,13 @@ import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BitDocIdSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import org.apache.lucene.util.automaton.Automata;
@@ -140,7 +144,86 @@ public abstract class SearchEquivalenceTestBase extends LuceneTestCase {
* Returns a random filter over the document set
*/
protected Filter randomFilter() {
return new QueryWrapperFilter(TermRangeQuery.newStringRange("field", "a", "" + randomChar(), true, true));
final Query query;
if (random().nextBoolean()) {
query = TermRangeQuery.newStringRange("field", "a", "" + randomChar(), true, true);
} else {
// use a query with a two-phase approximation
PhraseQuery phrase = new PhraseQuery();
phrase.add(new Term("field", "" + randomChar()));
phrase.add(new Term("field", "" + randomChar()));
phrase.setSlop(100);
query = phrase;
}
// now wrap the query as a filter. QWF has its own codepath
if (random().nextBoolean()) {
return new QueryWrapperFilter(query);
} else {
return new SlowWrapperFilter(query, random().nextBoolean());
}
}
static class SlowWrapperFilter extends Filter {
final Query query;
final boolean useBits;
SlowWrapperFilter(Query query, boolean useBits) {
this.query = query;
this.useBits = useBits;
}
@Override
public Query rewrite(IndexReader reader) throws IOException {
Query q = query.rewrite(reader);
if (q != query) {
return new SlowWrapperFilter(q, useBits);
} else {
return this;
}
}
@Override
public DocIdSet getDocIdSet(LeafReaderContext context, Bits acceptDocs) throws IOException {
// get a private context that is used to rewrite, createWeight and score eventually
final LeafReaderContext privateContext = context.reader().getContext();
final Weight weight = new IndexSearcher(privateContext).createNormalizedWeight(query, false);
return new DocIdSet() {
@Override
public DocIdSetIterator iterator() throws IOException {
return weight.scorer(privateContext, acceptDocs);
}
@Override
public long ramBytesUsed() {
return 0L;
}
@Override
public Bits bits() throws IOException {
if (useBits) {
BitDocIdSet.Builder builder = new BitDocIdSet.Builder(context.reader().maxDoc());
DocIdSetIterator disi = iterator();
if (disi != null) {
builder.or(disi);
}
BitDocIdSet bitset = builder.build();
if (bitset == null) {
return new Bits.MatchNoBits(context.reader().maxDoc());
} else {
return bitset.bits();
}
} else {
return null;
}
}
};
}
@Override
public String toString(String field) {
return "SlowQWF(" + query + ")";
}
}
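A hedged usage sketch (not in the patch) of how such a random filter is typically exercised by subclasses of this base class; assertSameSet is assumed to be the set-equality helper the base class provides, and the class and test names below are made up.

import org.apache.lucene.index.Term;
import org.apache.lucene.search.Filter;
import org.apache.lucene.search.FilteredQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;

// Hypothetical equivalence test: whichever FilteredQuery strategy runs, and
// whether or not the filter exposes random-access bits, the matched set of
// documents must be identical.
public class TestRandomFilterEquivalence extends SearchEquivalenceTestBase {
  public void testFilterStrategiesAgree() throws Exception {
    Query q = new TermQuery(new Term("field", "" + randomChar()));
    Filter f = randomFilter();
    assertSameSet(
        new FilteredQuery(q, f, FilteredQuery.QUERY_FIRST_FILTER_STRATEGY),
        new FilteredQuery(q, f, FilteredQuery.LEAP_FROG_QUERY_FIRST_STRATEGY));
  }
}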
/**