LUCENE-6244: DisjunctionScorer propagates two-phase iterators of its sub scorers.

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1660180 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Adrien Grand 2015-02-16 18:27:25 +00:00
parent d0a220ed00
commit 40a4330b14
11 changed files with 670 additions and 108 deletions

View File

@ -86,6 +86,9 @@ Optimizations
positions lazily if the phrase query is in a conjunction with other queries.
(Robert Muir, Adrien Grand)
* LUCENE-6244: Pure disjunctions now propagate two-phase iterators of the
wrapped scorers (see LUCENE-6198). (Adrien Grand, Robert Muir)
* LUCENE-6241: FSDirectory.listAll() doesnt filter out subdirectories anymore,
for faster performance. Subdirectories don't matter to Lucene. If you need to
filter out non-index files with some custom usage, you may want to look at

View File

@ -376,10 +376,7 @@ public class BooleanWeight extends Weight {
} else {
float coords[] = new float[prohibited.size()+1];
Arrays.fill(coords, 1F);
return new ReqExclScorer(main,
new DisjunctionSumScorer(this,
prohibited.toArray(new Scorer[prohibited.size()]),
coords));
return new ReqExclScorer(main, new DisjunctionSumScorer(this, prohibited, coords, false));
}
}
@ -402,9 +399,7 @@ public class BooleanWeight extends Weight {
if (minShouldMatch > 1) {
return new MinShouldMatchSumScorer(this, optional, minShouldMatch, coords);
} else {
return new DisjunctionSumScorer(this,
optional.toArray(new Scorer[optional.size()]),
coords);
return new DisjunctionSumScorer(this, optional, coords, needsScores);
}
}
}

View File

@ -115,7 +115,8 @@ public class DisjunctionMaxQuery extends Query implements Iterable<Query> {
protected class DisjunctionMaxWeight extends Weight {
/** The Weights for our subqueries, in 1-1 correspondence with disjuncts */
protected ArrayList<Weight> weights = new ArrayList<>(); // The Weight's for our subqueries, in 1-1 correspondence with disjuncts
protected final ArrayList<Weight> weights = new ArrayList<>(); // The Weight's for our subqueries, in 1-1 correspondence with disjuncts
private final boolean needsScores;
/** Construct the Weight for this Query searched by searcher. Recursively construct subquery weights. */
public DisjunctionMaxWeight(IndexSearcher searcher, boolean needsScores) throws IOException {
@ -123,6 +124,7 @@ public class DisjunctionMaxQuery extends Query implements Iterable<Query> {
for (Query disjunctQuery : disjuncts) {
weights.add(disjunctQuery.createWeight(searcher, needsScores));
}
this.needsScores = needsScores;
}
/** Compute the sub of squared weights of us applied to our subqueries. Used for normalization. */
@ -166,7 +168,7 @@ public class DisjunctionMaxQuery extends Query implements Iterable<Query> {
// only one sub-scorer in this segment
return scorers.get(0);
} else {
return new DisjunctionMaxScorer(this, tieBreakerMultiplier, scorers.toArray(new Scorer[scorers.size()]));
return new DisjunctionMaxScorer(this, tieBreakerMultiplier, scorers, needsScores);
}
}

View File

@ -17,6 +17,7 @@ package org.apache.lucene.search;
*/
import java.io.IOException;
import java.util.List;
import org.apache.lucene.search.ScorerPriorityQueue.ScorerWrapper;
@ -41,13 +42,13 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
* @param subScorers
* The sub scorers this Scorer should iterate on
*/
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, Scorer[] subScorers) {
super(weight, subScorers);
DisjunctionMaxScorer(Weight weight, float tieBreakerMultiplier, List<Scorer> subScorers, boolean needsScores) {
super(weight, subScorers, needsScores);
this.tieBreakerMultiplier = tieBreakerMultiplier;
}
@Override
protected float score(ScorerWrapper topList, int freq) throws IOException {
protected float score(ScorerWrapper topList) throws IOException {
float scoreSum = 0;
float scoreMax = 0;
for (ScorerWrapper w = topList; w != null; w = w.next) {

View File

@ -20,6 +20,7 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.lucene.search.ScorerPriorityQueue.ScorerWrapper;
import org.apache.lucene.util.BytesRef;
@ -29,27 +30,199 @@ import org.apache.lucene.util.BytesRef;
*/
abstract class DisjunctionScorer extends Scorer {
private final boolean needsScores;
private final ScorerPriorityQueue subScorers;
private final long cost;
/** The document number of the current match. */
protected int doc = -1;
protected int numScorers;
/** Number of matching scorers for the current match. */
private int freq = -1;
/** Linked list of scorers which are on the current doc */
private ScorerWrapper topScorers;
protected DisjunctionScorer(Weight weight, Scorer subScorers[]) {
protected DisjunctionScorer(Weight weight, List<Scorer> subScorers, boolean needsScores) {
super(weight);
if (subScorers.length <= 1) {
if (subScorers.size() <= 1) {
throw new IllegalArgumentException("There must be at least 2 subScorers");
}
this.subScorers = new ScorerPriorityQueue(subScorers.length);
this.subScorers = new ScorerPriorityQueue(subScorers.size());
long cost = 0;
for (Scorer scorer : subScorers) {
this.subScorers.add(new ScorerWrapper(scorer));
final ScorerWrapper w = new ScorerWrapper(scorer);
cost += w.cost;
this.subScorers.add(w);
}
this.cost = cost;
this.needsScores = needsScores;
}
/**
* A {@link DocIdSetIterator} which is a disjunction of the approximations of
* the provided iterators.
*/
private static class DisjunctionDISIApproximation extends DocIdSetIterator {
final ScorerPriorityQueue subScorers;
final long cost;
DisjunctionDISIApproximation(ScorerPriorityQueue subScorers) {
this.subScorers = subScorers;
long cost = 0;
for (ScorerWrapper w : subScorers) {
cost += w.cost;
}
this.cost = cost;
}
@Override
public long cost() {
return cost;
}
@Override
public int docID() {
return subScorers.top().doc;
}
@Override
public int nextDoc() throws IOException {
ScorerWrapper top = subScorers.top();
final int doc = top.doc;
do {
top.doc = top.approximation.nextDoc();
top = subScorers.updateTop();
} while (top.doc == doc);
return top.doc;
}
@Override
public int advance(int target) throws IOException {
ScorerWrapper top = subScorers.top();
do {
top.doc = top.approximation.advance(target);
top = subScorers.updateTop();
} while (top.doc < target);
return top.doc;
}
}
@Override
public TwoPhaseDocIdSetIterator asTwoPhaseIterator() {
boolean hasApproximation = false;
for (ScorerWrapper w : subScorers) {
if (w.twoPhaseView != null) {
hasApproximation = true;
break;
}
}
if (hasApproximation == false) {
// none of the sub scorers supports approximations
return null;
}
return new TwoPhaseDocIdSetIterator() {
@Override
public DocIdSetIterator approximation() {
// note it is important to share the same pq as this scorer so that
// rebalancing the pq through the approximation will also rebalance
// the pq in this scorer.
return new DisjunctionDISIApproximation(subScorers);
}
@Override
public boolean matches() throws IOException {
ScorerWrapper topScorers = subScorers.topList();
// remove the head of the list as long as it does not match
while (topScorers.twoPhaseView != null && topScorers.twoPhaseView.matches() == false) {
topScorers = topScorers.next;
if (topScorers == null) {
return false;
}
}
// now we know we have at least one match since the first element of 'matchList' matches
if (needsScores) {
// if scores or freqs are needed, we also need to remove scorers
// from the top list that do not actually match
ScorerWrapper previous = topScorers;
for (ScorerWrapper w = topScorers.next; w != null; w = w.next) {
if (w.twoPhaseView != null && w.twoPhaseView.matches() == false) {
// w does not match, remove it
previous.next = w.next;
} else {
previous = w;
}
}
// We need to explicitely set the list of top scorers to avoid the
// laziness of DisjunctionScorer.score() that would take all scorers
// positioned on the same doc as the top of the pq, including
// non-matching scorers
DisjunctionScorer.this.topScorers = topScorers;
}
return true;
}
};
}
@Override
public final long cost() {
return cost;
}
@Override
public final int docID() {
return subScorers.top().doc;
}
@Override
public final int nextDoc() throws IOException {
topScorers = null;
ScorerWrapper top = subScorers.top();
final int doc = top.doc;
do {
top.doc = top.scorer.nextDoc();
top = subScorers.updateTop();
} while (top.doc == doc);
return top.doc;
}
@Override
public final int advance(int target) throws IOException {
topScorers = null;
ScorerWrapper top = subScorers.top();
do {
top.doc = top.scorer.advance(target);
top = subScorers.updateTop();
} while (top.doc < target);
return top.doc;
}
@Override
public final int freq() throws IOException {
if (topScorers == null) {
topScorers = subScorers.topList();
}
int freq = 1;
for (ScorerWrapper w = topScorers.next; w != null; w = w.next) {
freq += 1;
}
return freq;
}
@Override
public final float score() throws IOException {
if (topScorers == null) {
topScorers = subScorers.topList();
}
return score(topScorers);
}
/** Compute the score for the given linked list of scorers. */
protected abstract float score(ScorerWrapper topList) throws IOException;
@Override
public final Collection<ChildScorer> getChildren() {
ArrayList<ChildScorer> children = new ArrayList<>();
@ -78,86 +251,4 @@ abstract class DisjunctionScorer extends Scorer {
public BytesRef getPayload() throws IOException {
return null;
}
@Override
public final long cost() {
long sum = 0;
for (ScorerWrapper scorer : subScorers) {
sum += scorer.cost;
}
return sum;
}
@Override
public final int docID() {
return doc;
}
@Override
public final int nextDoc() throws IOException {
assert doc != NO_MORE_DOCS;
ScorerWrapper top = subScorers.top();
final int doc = this.doc;
while (top.doc == doc) {
top.doc = top.scorer.nextDoc();
if (top.doc == NO_MORE_DOCS) {
subScorers.pop();
if (subScorers.size() == 0) {
return this.doc = NO_MORE_DOCS;
}
top = subScorers.top();
} else {
top = subScorers.updateTop();
}
}
freq = -1;
return this.doc = top.doc;
}
@Override
public final int advance(int target) throws IOException {
assert doc != NO_MORE_DOCS;
ScorerWrapper top = subScorers.top();
while (top.doc < target) {
top.doc = top.scorer.advance(target);
if (top.doc == NO_MORE_DOCS) {
subScorers.pop();
if (subScorers.size() == 0) {
return this.doc = NO_MORE_DOCS;
}
top = subScorers.top();
} else {
top = subScorers.updateTop();
}
}
freq = -1;
return this.doc = top.doc;
}
@Override
public final int freq() throws IOException {
if (freq < 0) {
topScorers = subScorers.topList();
int freq = 1;
for (ScorerWrapper w = topScorers.next; w != null; w = w.next) {
freq += 1;
}
this.freq = freq;
}
return freq;
}
@Override
public final float score() throws IOException {
final int freq = freq(); // compute the top scorers if necessary
return score(topScorers, freq);
}
/** Compute the score for the given linked list of scorers. */
protected abstract float score(ScorerWrapper topList, int freq) throws IOException;
}

View File

@ -18,6 +18,7 @@ package org.apache.lucene.search;
*/
import java.io.IOException;
import java.util.List;
import org.apache.lucene.search.ScorerPriorityQueue.ScorerWrapper;
@ -32,16 +33,18 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
* @param subScorers Array of at least two subscorers.
* @param coord Table of coordination factors
*/
DisjunctionSumScorer(Weight weight, Scorer[] subScorers, float[] coord) {
super(weight, subScorers);
DisjunctionSumScorer(Weight weight, List<Scorer> subScorers, float[] coord, boolean needsScores) {
super(weight, subScorers, needsScores);
this.coord = coord;
}
@Override
protected float score(ScorerWrapper topList, int freq) throws IOException {
protected float score(ScorerWrapper topList) throws IOException {
double score = 0;
int freq = 0;
for (ScorerWrapper w = topList; w != null; w = w.next) {
score += w.scorer.score();
freq += 1;
}
return (float)score * coord[freq];
}

View File

@ -35,10 +35,23 @@ final class ScorerPriorityQueue implements Iterable<org.apache.lucene.search.Sco
int doc; // the current doc, used for comparison
ScorerWrapper next; // reference to a next element, see #topList
// An approximation of the scorer, or the scorer itself if it does not
// support two-phase iteration
final DocIdSetIterator approximation;
// A two-phase view of the scorer, or null if the scorer does not support
// two-phase iteration
final TwoPhaseDocIdSetIterator twoPhaseView;
ScorerWrapper(Scorer scorer) {
this.scorer = scorer;
this.cost = scorer.cost();
this.doc = -1;
this.twoPhaseView = scorer.asTwoPhaseIterator();
if (twoPhaseView != null) {
approximation = twoPhaseView.approximation();
} else {
approximation = scorer;
}
}
}

View File

@ -0,0 +1,169 @@
package org.apache.lucene.search;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
/**
* Basic equivalence tests for approximations.
*/
public class TestApproximationSearchEquivalence extends SearchEquivalenceTestBase {
public void testConjunction() throws Exception {
Term t1 = randomTerm();
Term t2 = randomTerm();
TermQuery q1 = new TermQuery(t1);
TermQuery q2 = new TermQuery(t2);
BooleanQuery bq1 = new BooleanQuery();
bq1.add(q1, Occur.MUST);
bq1.add(q2, Occur.MUST);
BooleanQuery bq2 = new BooleanQuery();
bq2.add(new RandomApproximationQuery(q1, random()), Occur.MUST);
bq2.add(new RandomApproximationQuery(q2, random()), Occur.MUST);
assertSameScores(bq1, bq2);
}
public void testNestedConjunction() throws Exception {
Term t1 = randomTerm();
Term t2 = randomTerm();
Term t3 = randomTerm();
TermQuery q1 = new TermQuery(t1);
TermQuery q2 = new TermQuery(t2);
TermQuery q3 = new TermQuery(t3);
BooleanQuery bq1 = new BooleanQuery();
bq1.add(q1, Occur.MUST);
bq1.add(q2, Occur.MUST);
BooleanQuery bq2 = new BooleanQuery();
bq2.add(bq1, Occur.MUST);
bq2.add(q3, Occur.MUST);
BooleanQuery bq3 = new BooleanQuery();
bq3.add(new RandomApproximationQuery(q1, random()), Occur.MUST);
bq3.add(new RandomApproximationQuery(q2, random()), Occur.MUST);
BooleanQuery bq4 = new BooleanQuery();
bq4.add(bq3, Occur.MUST);
bq4.add(q3, Occur.MUST);
assertSameScores(bq2, bq4);
}
public void testDisjunction() throws Exception {
Term t1 = randomTerm();
Term t2 = randomTerm();
TermQuery q1 = new TermQuery(t1);
TermQuery q2 = new TermQuery(t2);
BooleanQuery bq1 = new BooleanQuery();
bq1.add(q1, Occur.SHOULD);
bq1.add(q2, Occur.SHOULD);
BooleanQuery bq2 = new BooleanQuery();
bq2.add(new RandomApproximationQuery(q1, random()), Occur.SHOULD);
bq2.add(new RandomApproximationQuery(q2, random()), Occur.SHOULD);
assertSameScores(bq1, bq2);
}
public void testNestedDisjunction() throws Exception {
Term t1 = randomTerm();
Term t2 = randomTerm();
Term t3 = randomTerm();
TermQuery q1 = new TermQuery(t1);
TermQuery q2 = new TermQuery(t2);
TermQuery q3 = new TermQuery(t3);
BooleanQuery bq1 = new BooleanQuery();
bq1.add(q1, Occur.SHOULD);
bq1.add(q2, Occur.SHOULD);
BooleanQuery bq2 = new BooleanQuery();
bq2.add(bq1, Occur.SHOULD);
bq2.add(q3, Occur.SHOULD);
BooleanQuery bq3 = new BooleanQuery();
bq3.add(new RandomApproximationQuery(q1, random()), Occur.SHOULD);
bq3.add(new RandomApproximationQuery(q2, random()), Occur.SHOULD);
BooleanQuery bq4 = new BooleanQuery();
bq4.add(bq3, Occur.SHOULD);
bq4.add(q3, Occur.SHOULD);
assertSameScores(bq2, bq4);
}
public void testDisjunctionInConjunction() throws Exception {
Term t1 = randomTerm();
Term t2 = randomTerm();
Term t3 = randomTerm();
TermQuery q1 = new TermQuery(t1);
TermQuery q2 = new TermQuery(t2);
TermQuery q3 = new TermQuery(t3);
BooleanQuery bq1 = new BooleanQuery();
bq1.add(q1, Occur.SHOULD);
bq1.add(q2, Occur.SHOULD);
BooleanQuery bq2 = new BooleanQuery();
bq2.add(bq1, Occur.MUST);
bq2.add(q3, Occur.MUST);
BooleanQuery bq3 = new BooleanQuery();
bq3.add(new RandomApproximationQuery(q1, random()), Occur.SHOULD);
bq3.add(new RandomApproximationQuery(q2, random()), Occur.SHOULD);
BooleanQuery bq4 = new BooleanQuery();
bq4.add(bq3, Occur.MUST);
bq4.add(q3, Occur.MUST);
assertSameScores(bq2, bq4);
}
public void testConjunctionInDisjunction() throws Exception {
Term t1 = randomTerm();
Term t2 = randomTerm();
Term t3 = randomTerm();
TermQuery q1 = new TermQuery(t1);
TermQuery q2 = new TermQuery(t2);
TermQuery q3 = new TermQuery(t3);
BooleanQuery bq1 = new BooleanQuery();
bq1.add(q1, Occur.MUST);
bq1.add(q2, Occur.MUST);
BooleanQuery bq2 = new BooleanQuery();
bq2.add(bq1, Occur.SHOULD);
bq2.add(q3, Occur.SHOULD);
BooleanQuery bq3 = new BooleanQuery();
bq3.add(new RandomApproximationQuery(q1, random()), Occur.MUST);
bq3.add(new RandomApproximationQuery(q2, random()), Occur.MUST);
BooleanQuery bq4 = new BooleanQuery();
bq4.add(bq3, Occur.SHOULD);
bq4.add(q3, Occur.SHOULD);
assertSameScores(bq2, bq4);
}
}

View File

@ -591,7 +591,7 @@ public class TestBooleanQuery extends LuceneTestCase {
dir.close();
}
public void testConjunctionSupportsApproximations() throws IOException {
public void testConjunctionPropagatesApproximations() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
@ -619,4 +619,33 @@ public class TestBooleanQuery extends LuceneTestCase {
w.close();
dir.close();
}
public void testDisjunctionPropagatesApproximations() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
Field f = newTextField("field", "a b c", Field.Store.NO);
doc.add(f);
w.addDocument(doc);
w.commit();
DirectoryReader reader = w.getReader();
final IndexSearcher searcher = new IndexSearcher(reader);
PhraseQuery pq = new PhraseQuery();
pq.add(new Term("field", "a"));
pq.add(new Term("field", "b"));
BooleanQuery q = new BooleanQuery();
q.add(pq, Occur.SHOULD);
q.add(new TermQuery(new Term("field", "c")), Occur.SHOULD);
final Weight weight = searcher.createNormalizedWeight(q, random().nextBoolean());
final Scorer scorer = weight.scorer(reader.leaves().get(0), null);
assertNotNull(scorer.asTwoPhaseIterator());
reader.close();
w.close();
dir.close();
}
}

View File

@ -0,0 +1,227 @@
package org.apache.lucene.search;
import java.io.IOException;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.util.Random;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import com.carrotsearch.randomizedtesting.generators.RandomInts;
/**
* A {@link Query} that adds random approximations to its scorers.
*/
public class RandomApproximationQuery extends Query {
private final Query query;
private final Random random;
public RandomApproximationQuery(Query query, Random random) {
this.query = query;
this.random = random;
}
@Override
public Query rewrite(IndexReader reader) throws IOException {
final Query rewritten = query.rewrite(reader);
if (rewritten != query) {
return new RandomApproximationQuery(rewritten, random);
}
return this;
}
@Override
public boolean equals(Object obj) {
if (obj instanceof RandomApproximationQuery == false) {
return false;
}
final RandomApproximationQuery that = (RandomApproximationQuery) obj;
if (this.getBoost() != that.getBoost()) {
return false;
}
if (this.query.equals(that.query) == false) {
return false;
}
return true;
}
@Override
public int hashCode() {
return 31 * query.hashCode() + Float.floatToIntBits(getBoost());
}
@Override
public String toString(String field) {
return query.toString(field);
}
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException {
final Weight weight = query.createWeight(searcher, needsScores);
return new Weight(RandomApproximationQuery.this) {
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
return weight.explain(context, doc);
}
@Override
public float getValueForNormalization() throws IOException {
return weight.getValueForNormalization();
}
@Override
public void normalize(float norm, float topLevelBoost) {
weight.normalize(norm, topLevelBoost);
}
@Override
public Scorer scorer(LeafReaderContext context, Bits acceptDocs) throws IOException {
final Scorer scorer = weight.scorer(context, acceptDocs);
if (scorer == null) {
return null;
}
final RandomTwoPhaseView twoPhaseView = new RandomTwoPhaseView(random, scorer);
return new Scorer(this) {
@Override
public TwoPhaseDocIdSetIterator asTwoPhaseIterator() {
return twoPhaseView;
}
@Override
public float score() throws IOException {
return scorer.score();
}
@Override
public int freq() throws IOException {
return scorer.freq();
}
@Override
public int nextPosition() throws IOException {
return scorer.nextPosition();
}
@Override
public int startOffset() throws IOException {
return scorer.startOffset();
}
@Override
public int endOffset() throws IOException {
return scorer.endOffset();
}
@Override
public BytesRef getPayload() throws IOException {
return scorer.getPayload();
}
@Override
public int docID() {
return scorer.docID();
}
@Override
public int nextDoc() throws IOException {
return scorer.nextDoc();
}
@Override
public int advance(int target) throws IOException {
return scorer.advance(target);
}
@Override
public long cost() {
return scorer.cost();
}
};
}
};
}
private static class RandomTwoPhaseView extends TwoPhaseDocIdSetIterator {
private final DocIdSetIterator disi;
private final RandomApproximation approximation;
RandomTwoPhaseView(Random random, DocIdSetIterator disi) {
this.disi = disi;
this.approximation = new RandomApproximation(random, disi);
}
@Override
public DocIdSetIterator approximation() {
return approximation;
}
@Override
public boolean matches() throws IOException {
return approximation.doc == disi.docID();
}
}
private static class RandomApproximation extends DocIdSetIterator {
private final Random random;
private final DocIdSetIterator disi;
int doc = -1;
public RandomApproximation(Random random, DocIdSetIterator disi) {
this.random = random;
this.disi = disi;
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
if (disi.docID() < target) {
disi.advance(target);
}
if (disi.docID() == NO_MORE_DOCS) {
return doc = NO_MORE_DOCS;
}
return doc = RandomInts.randomIntBetween(random, target, disi.docID());
}
@Override
public long cost() {
return disi.cost();
}
}
}

View File

@ -17,6 +17,7 @@ package org.apache.lucene.search;
* limitations under the License.
*/
import java.io.IOException;
import java.util.BitSet;
import java.util.Random;
@ -193,4 +194,32 @@ public abstract class SearchEquivalenceTestBase extends LuceneTestCase {
assertTrue(bitset.get(td1.scoreDocs[i].doc));
}
}
/**
* Assert that two queries return the same documents and with the same scores.
*/
protected void assertSameScores(Query q1, Query q2) throws Exception {
assertSameSet(q1, q2);
assertSameScores(q1, q2, null);
// also test with a filter to test advancing
assertSameScores(q1, q2, randomFilter());
}
protected void assertSameScores(Query q1, Query q2, Filter filter) throws Exception {
if (filter != null && random().nextBoolean()) {
q1 = new FilteredQuery(q1, filter, TestUtil.randomFilterStrategy(random()));
q2 = new FilteredQuery(q2, filter, TestUtil.randomFilterStrategy(random()));
filter = null;
}
// not efficient, but simple!
TopDocs td1 = s1.search(q1, filter, reader.maxDoc());
TopDocs td2 = s2.search(q2, filter, reader.maxDoc());
assertEquals(td1.totalHits, td2.totalHits);
for (int i = 0; i < td1.scoreDocs.length; ++i) {
assertEquals(td1.scoreDocs[i].doc, td2.scoreDocs[i].doc);
assertEquals(td1.scoreDocs[i].score, td2.scoreDocs[i].score, 10e-5);
}
}
}