diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 6ab38d51c63..72b969ab14d 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -208,6 +208,11 @@ New features * LUCENE-2559: Added SegmentReader.reopen methods (John Wang via Mike McCandless) +* LUCENE-2590: Added Scorer.visitSubScorers, and Scorer.freq. Along + with a custom Collector these experimental methods make it possible + to gather the hit-count per sub-clause and per document while a + search is running. (Simon Willnauer, Mike McCandless) + Optimizations * LUCENE-2410: ~20% speedup on exact (slop=0) PhraseQuery matching. diff --git a/lucene/src/java/org/apache/lucene/search/BooleanQuery.java b/lucene/src/java/org/apache/lucene/search/BooleanQuery.java index 04a88fe69a4..b1c643f4734 100644 --- a/lucene/src/java/org/apache/lucene/search/BooleanQuery.java +++ b/lucene/src/java/org/apache/lucene/search/BooleanQuery.java @@ -320,7 +320,7 @@ public class BooleanQuery extends Query implements Iterable { // Check if we can return a BooleanScorer if (!scoreDocsInOrder && topScorer && required.size() == 0 && prohibited.size() < 32) { - return new BooleanScorer(similarity, minNrShouldMatch, optional, prohibited, maxCoord); + return new BooleanScorer(this, similarity, minNrShouldMatch, optional, prohibited, maxCoord); } if (required.size() == 0 && optional.size() == 0) { @@ -334,7 +334,7 @@ public class BooleanQuery extends Query implements Iterable { } // Return a BooleanScorer2 - return new BooleanScorer2(similarity, minNrShouldMatch, required, prohibited, optional, maxCoord); + return new BooleanScorer2(this, similarity, minNrShouldMatch, required, prohibited, optional, maxCoord); } @Override diff --git a/lucene/src/java/org/apache/lucene/search/BooleanScorer.java b/lucene/src/java/org/apache/lucene/search/BooleanScorer.java index a6e99d0e645..8edca949c77 100644 --- a/lucene/src/java/org/apache/lucene/search/BooleanScorer.java +++ b/lucene/src/java/org/apache/lucene/search/BooleanScorer.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.List; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanClause.Occur; /* Description from Doug Cutting (excerpted from * LUCENE-1483): @@ -115,6 +116,7 @@ final class BooleanScorer extends Scorer { float score; int doc = NO_MORE_DOCS; + int freq; public BucketScorer() { super(null); } @@ -124,6 +126,9 @@ final class BooleanScorer extends Scorer { @Override public int docID() { return doc; } + @Override + public float freq() { return freq; } + @Override public int nextDoc() throws IOException { return NO_MORE_DOCS; } @@ -159,7 +164,8 @@ final class BooleanScorer extends Scorer { static final class SubScorer { public Scorer scorer; - public boolean required = false; + // TODO: re-enable this if BQ ever sends us required clauses + //public boolean required = false; public boolean prohibited = false; public Collector collector; public SubScorer next; @@ -167,8 +173,12 @@ final class BooleanScorer extends Scorer { public SubScorer(Scorer scorer, boolean required, boolean prohibited, Collector collector, SubScorer next) throws IOException { + if (required) { + throw new IllegalArgumentException("this scorer cannot handle required=true"); + } this.scorer = scorer; - this.required = required; + // TODO: re-enable this if BQ ever sends us required clauses + //this.required = required; this.prohibited = prohibited; this.collector = collector; this.next = next; @@ -178,17 +188,18 @@ final class BooleanScorer extends Scorer { private SubScorer scorers = null; private BucketTable bucketTable = new BucketTable(); private final float[] coordFactors; - private int requiredMask = 0; + // TODO: re-enable this if BQ ever sends us required clauses + //private int requiredMask = 0; private int prohibitedMask = 0; private int nextMask = 1; private final int minNrShouldMatch; private int end; private Bucket current; private int doc = -1; - - BooleanScorer(Similarity similarity, int minNrShouldMatch, + + BooleanScorer(Weight weight, Similarity similarity, int minNrShouldMatch, List optionalScorers, List prohibitedScorers, int maxCoord) throws IOException { - super(similarity); + super(similarity, weight); this.minNrShouldMatch = minNrShouldMatch; if (optionalScorers != null && optionalScorers.size() > 0) { @@ -231,8 +242,11 @@ final class BooleanScorer extends Scorer { while (current != null) { // more queued // check prohibited & required - if ((current.bits & prohibitedMask) == 0 && - (current.bits & requiredMask) == requiredMask) { + if ((current.bits & prohibitedMask) == 0) { + + // TODO: re-enable this if BQ ever sends us required + // clauses + //&& (current.bits & requiredMask) == requiredMask) { if (current.doc >= max){ tmp = current; @@ -245,6 +259,7 @@ final class BooleanScorer extends Scorer { if (current.coord >= minNrShouldMatch) { bs.score = current.score * coordFactors[current.coord]; bs.doc = current.doc; + bs.freq = current.coord; collector.collect(current.doc); } } @@ -294,8 +309,9 @@ final class BooleanScorer extends Scorer { // check prohibited & required, and minNrShouldMatch if ((current.bits & prohibitedMask) == 0 && - (current.bits & requiredMask) == requiredMask && current.coord >= minNrShouldMatch) { + // TODO: re-enable this if BQ ever sends us required clauses + // (current.bits & requiredMask) == requiredMask && return doc = current.doc; } } @@ -339,5 +355,28 @@ final class BooleanScorer extends Scorer { buffer.append(")"); return buffer.toString(); } + + @Override + protected void visitSubScorers(Query parent, Occur relationship, ScorerVisitor visitor) { + super.visitSubScorers(parent, relationship, visitor); + final Query q = weight.getQuery(); + SubScorer sub = scorers; + while(sub != null) { + // TODO: re-enable this if BQ ever sends us required + //clauses + //if (sub.required) { + //relationship = Occur.MUST; + if (!sub.prohibited) { + relationship = Occur.SHOULD; + } else { + // TODO: maybe it's pointless to do this, but, it is + // possible the doc may still be collected, eg foo + // OR (bar -fee) + relationship = Occur.MUST_NOT; + } + sub.scorer.visitSubScorers(q, relationship, visitor); + sub = sub.next; + } + } } diff --git a/lucene/src/java/org/apache/lucene/search/BooleanScorer2.java b/lucene/src/java/org/apache/lucene/search/BooleanScorer2.java index 33e911d48da..74d277d61b8 100644 --- a/lucene/src/java/org/apache/lucene/search/BooleanScorer2.java +++ b/lucene/src/java/org/apache/lucene/search/BooleanScorer2.java @@ -21,6 +21,8 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import org.apache.lucene.search.BooleanClause.Occur; + /* See the description in BooleanScorer.java, comparing * BooleanScorer & BooleanScorer2 */ @@ -59,7 +61,7 @@ class BooleanScorer2 extends Scorer { /** The number of optionalScorers that need to match (if there are any) */ private final int minNrShouldMatch; - + private int doc = -1; /** @@ -80,9 +82,9 @@ class BooleanScorer2 extends Scorer { * @param optional * the list of optional scorers. */ - public BooleanScorer2(Similarity similarity, int minNrShouldMatch, + public BooleanScorer2(Weight weight, Similarity similarity, int minNrShouldMatch, List required, List prohibited, List optional, int maxCoord) throws IOException { - super(similarity); + super(similarity, weight); if (minNrShouldMatch < 0) { throw new IllegalArgumentException("Minimum number of optional scorers should not be negative"); } @@ -301,10 +303,28 @@ class BooleanScorer2 extends Scorer { return sum * coordinator.coordFactors[coordinator.nrMatchers]; } + @Override + public float freq() { + return coordinator.nrMatchers; + } + @Override public int advance(int target) throws IOException { return doc = countingSumScorer.advance(target); } + + @Override + protected void visitSubScorers(Query parent, Occur relationship, ScorerVisitor visitor) { + super.visitSubScorers(parent, relationship, visitor); + final Query q = weight.getQuery(); + for (Scorer s : optionalScorers) { + s.visitSubScorers(q, Occur.SHOULD, visitor); + } + for (Scorer s : prohibitedScorers) { + s.visitSubScorers(q, Occur.MUST_NOT, visitor); + } + for (Scorer s : requiredScorers) { + s.visitSubScorers(q, Occur.MUST, visitor); + } + } } - - diff --git a/lucene/src/java/org/apache/lucene/search/ConstantScoreQuery.java b/lucene/src/java/org/apache/lucene/search/ConstantScoreQuery.java index af0b400ed0e..323f817e763 100644 --- a/lucene/src/java/org/apache/lucene/search/ConstantScoreQuery.java +++ b/lucene/src/java/org/apache/lucene/search/ConstantScoreQuery.java @@ -123,7 +123,7 @@ public class ConstantScoreQuery extends Query { int doc = -1; public ConstantScorer(Similarity similarity, IndexReader reader, Weight w) throws IOException { - super(similarity); + super(similarity,w); theScore = w.getValue(); DocIdSet docIdSet = filter.getDocIdSet(reader); if (docIdSet == null) { diff --git a/lucene/src/java/org/apache/lucene/search/ExactPhraseScorer.java b/lucene/src/java/org/apache/lucene/search/ExactPhraseScorer.java index 14ab22de1af..aa8014a552a 100644 --- a/lucene/src/java/org/apache/lucene/search/ExactPhraseScorer.java +++ b/lucene/src/java/org/apache/lucene/search/ExactPhraseScorer.java @@ -21,9 +21,9 @@ import java.io.IOException; import java.util.Arrays; import org.apache.lucene.index.*; +import org.apache.lucene.search.BooleanClause.Occur; final class ExactPhraseScorer extends Scorer { - private final Weight weight; private final byte[] norms; private final float value; @@ -63,8 +63,7 @@ final class ExactPhraseScorer extends Scorer { ExactPhraseScorer(Weight weight, PhraseQuery.PostingsAndFreq[] postings, Similarity similarity, byte[] norms) throws IOException { - super(similarity); - this.weight = weight; + super(similarity, weight); this.norms = norms; this.value = weight.getValue(); @@ -193,8 +192,8 @@ final class ExactPhraseScorer extends Scorer { return "ExactPhraseScorer(" + weight + ")"; } - // used by MultiPhraseQuery - float currentFreq() { + @Override + public float freq() { return freq; } diff --git a/lucene/src/java/org/apache/lucene/search/FilteredQuery.java b/lucene/src/java/org/apache/lucene/search/FilteredQuery.java index 01bea83d01c..70301d3ab25 100644 --- a/lucene/src/java/org/apache/lucene/search/FilteredQuery.java +++ b/lucene/src/java/org/apache/lucene/search/FilteredQuery.java @@ -126,7 +126,7 @@ extends Query { return null; } - return new Scorer(similarity) { + return new Scorer(similarity, this) { private int doc = -1; diff --git a/lucene/src/java/org/apache/lucene/search/MatchAllDocsQuery.java b/lucene/src/java/org/apache/lucene/search/MatchAllDocsQuery.java index bbddfa6a575..c96a8d06e03 100644 --- a/lucene/src/java/org/apache/lucene/search/MatchAllDocsQuery.java +++ b/lucene/src/java/org/apache/lucene/search/MatchAllDocsQuery.java @@ -54,7 +54,7 @@ public class MatchAllDocsQuery extends Query { MatchAllScorer(IndexReader reader, Similarity similarity, Weight w, byte[] norms) throws IOException { - super(similarity); + super(similarity,w); delDocs = MultiFields.getDeletedDocs(reader); score = w.getValue(); maxDoc = reader.maxDoc(); diff --git a/lucene/src/java/org/apache/lucene/search/MultiPhraseQuery.java b/lucene/src/java/org/apache/lucene/search/MultiPhraseQuery.java index ea774b83d43..f4c786e06ae 100644 --- a/lucene/src/java/org/apache/lucene/search/MultiPhraseQuery.java +++ b/lucene/src/java/org/apache/lucene/search/MultiPhraseQuery.java @@ -271,11 +271,7 @@ public class MultiPhraseQuery extends Query { int d = scorer.advance(doc); float phraseFreq; if (d == doc) { - if (slop == 0) { - phraseFreq = ((ExactPhraseScorer) scorer).currentFreq(); - } else { - phraseFreq = ((SloppyPhraseScorer) scorer).currentFreq(); - } + phraseFreq = scorer.freq(); } else { phraseFreq = 0.0f; } diff --git a/lucene/src/java/org/apache/lucene/search/PhraseQuery.java b/lucene/src/java/org/apache/lucene/search/PhraseQuery.java index e24178ab18c..f129d717204 100644 --- a/lucene/src/java/org/apache/lucene/search/PhraseQuery.java +++ b/lucene/src/java/org/apache/lucene/search/PhraseQuery.java @@ -275,11 +275,7 @@ public class PhraseQuery extends Query { int d = scorer.advance(doc); float phraseFreq; if (d == doc) { - if (slop == 0) { - phraseFreq = ((ExactPhraseScorer) scorer).currentFreq(); - } else { - phraseFreq = ((SloppyPhraseScorer) scorer).currentFreq(); - } + phraseFreq = scorer.freq(); } else { phraseFreq = 0.0f; } diff --git a/lucene/src/java/org/apache/lucene/search/PhraseScorer.java b/lucene/src/java/org/apache/lucene/search/PhraseScorer.java index 4dc62cdea60..a82df03d027 100644 --- a/lucene/src/java/org/apache/lucene/search/PhraseScorer.java +++ b/lucene/src/java/org/apache/lucene/search/PhraseScorer.java @@ -19,6 +19,8 @@ package org.apache.lucene.search; import java.io.IOException; +import org.apache.lucene.search.BooleanClause.Occur; + /** Expert: Scoring functionality for phrase queries. *
A document is considered matching if it contains the phrase-query terms * at "valid" positions. What "valid positions" are @@ -30,7 +32,6 @@ import java.io.IOException; * means a match. */ abstract class PhraseScorer extends Scorer { - private Weight weight; protected byte[] norms; protected float value; @@ -43,9 +44,8 @@ abstract class PhraseScorer extends Scorer { PhraseScorer(Weight weight, PhraseQuery.PostingsAndFreq[] postings, Similarity similarity, byte[] norms) { - super(similarity); + super(similarity, weight); this.norms = norms; - this.weight = weight; this.value = weight.getValue(); // convert tps to a list of phrase positions. @@ -129,8 +129,11 @@ abstract class PhraseScorer extends Scorer { /** * phrase frequency in current doc as computed by phraseFreq(). */ - public final float currentFreq() { return freq; } - + @Override + public final float freq() { + return freq; + } + /** * For a document containing all the phrase query terms, compute the * frequency of the phrase in that document. @@ -179,5 +182,5 @@ abstract class PhraseScorer extends Scorer { @Override public String toString() { return "scorer(" + weight + ")"; } - + } diff --git a/lucene/src/java/org/apache/lucene/search/ScoreCachingWrappingScorer.java b/lucene/src/java/org/apache/lucene/search/ScoreCachingWrappingScorer.java index a3e982f3efc..09a0bcd817d 100644 --- a/lucene/src/java/org/apache/lucene/search/ScoreCachingWrappingScorer.java +++ b/lucene/src/java/org/apache/lucene/search/ScoreCachingWrappingScorer.java @@ -32,7 +32,7 @@ import java.io.IOException; */ public class ScoreCachingWrappingScorer extends Scorer { - private Scorer scorer; + private final Scorer scorer; private int curDoc = -1; private float curScore; diff --git a/lucene/src/java/org/apache/lucene/search/Scorer.java b/lucene/src/java/org/apache/lucene/search/Scorer.java index 27a545432f8..84e51431e88 100644 --- a/lucene/src/java/org/apache/lucene/search/Scorer.java +++ b/lucene/src/java/org/apache/lucene/search/Scorer.java @@ -19,6 +19,8 @@ package org.apache.lucene.search; import java.io.IOException; +import org.apache.lucene.search.BooleanClause.Occur; + /** * Expert: Common scoring functionality for different types of queries. * @@ -39,12 +41,23 @@ import java.io.IOException; */ public abstract class Scorer extends DocIdSetIterator { private final Similarity similarity; + protected final Weight weight; /** Constructs a Scorer. * @param similarity The Similarity implementation used by this scorer. */ protected Scorer(Similarity similarity) { + this(similarity, null); + } + + /** + * Constructs a Scorer + * @param similarity The Similarity implementation used by this scorer. + * @param weight The scorers Weight + */ + protected Scorer(Similarity similarity, Weight weight) { this.similarity = similarity; + this.weight = weight; } /** Returns the Similarity implementation used by this scorer. */ @@ -94,4 +107,92 @@ public abstract class Scorer extends DocIdSetIterator { */ public abstract float score() throws IOException; + /** Returns number of matches for the current document. + * This returns a float (not int) because + * SloppyPhraseScorer discounts its freq according to how + * "sloppy" the match was. + * + * @lucene.experimental */ + public float freq() throws IOException { + throw new UnsupportedOperationException(this + " does not implement freq()"); + } + + /** + * A callback to gather information from a scorer and its sub-scorers. Each + * the top-level scorer as well as each of its sub-scorers are passed to + * either one of the visit methods depending on their boolean relationship in + * the query. + * @lucene.experimental + */ + public static abstract class ScorerVisitor

{ + /** + * Invoked for all optional scorer + * + * @param parent the parent query of the child query or null if the child is a top-level query + * @param child the query of the currently visited scorer + * @param scorer the current scorer + */ + public void visitOptional(P parent, C child, S scorer) {} + + /** + * Invoked for all required scorer + * + * @param parent the parent query of the child query or null if the child is a top-level query + * @param child the query of the currently visited scorer + * @param scorer the current scorer + */ + public void visitRequired(P parent, C child, S scorer) {} + + /** + * Invoked for all prohibited scorer + * + * @param parent the parent query of the child query or null if the child is a top-level query + * @param child the query of the currently visited scorer + * @param scorer the current scorer + */ + public void visitProhibited(P parent, C child, S scorer) {} + } + + /** + * Expert: call this to gather details for all sub-scorers for this query. + * This can be used, in conjunction with a custom {@link Collector} to gather + * details about how each sub-query matched the current hit. + * + * @param visitor a callback executed for each sub-scorer + * @lucene.experimental + */ + public void visitScorers(ScorerVisitor visitor) { + visitSubScorers(null, Occur.MUST/*must id default*/, visitor); + } + + /** + * {@link Scorer} subclasses should implement this method if the subclass + * itself contains multiple scorers to support gathering details for + * sub-scorers via {@link ScorerVisitor} + *

+ * Note: this method will throw {@link UnsupportedOperationException} if no + * associated {@link Weight} instance is provided to + * {@link #Scorer(Similarity, Weight)} + *

+ * + * @lucene.experimental + */ + protected void visitSubScorers(Query parent, Occur relationship, + ScorerVisitor visitor) { + if (weight == null) + throw new UnsupportedOperationException(); + + final Query q = weight.getQuery(); + switch (relationship) { + case MUST: + visitor.visitRequired(parent, q, this); + break; + case MUST_NOT: + visitor.visitProhibited(parent, q, this); + break; + case SHOULD: + visitor.visitOptional(parent, q, this); + break; + } + } } diff --git a/lucene/src/java/org/apache/lucene/search/TermScorer.java b/lucene/src/java/org/apache/lucene/search/TermScorer.java index efec1ab6a3e..13329ca910a 100644 --- a/lucene/src/java/org/apache/lucene/search/TermScorer.java +++ b/lucene/src/java/org/apache/lucene/search/TermScorer.java @@ -20,11 +20,11 @@ package org.apache.lucene.search; import java.io.IOException; import org.apache.lucene.index.DocsEnum; +import org.apache.lucene.search.BooleanClause.Occur; /** Expert: A Scorer for documents matching a Term. */ final class TermScorer extends Scorer { - private Weight weight; private DocsEnum docsEnum; private byte[] norms; private float weightValue; @@ -54,9 +54,8 @@ final class TermScorer extends Scorer { * The field norms of the document fields for the Term. */ TermScorer(Weight weight, DocsEnum td, Similarity similarity, byte[] norms) { - super(similarity); + super(similarity, weight); - this.weight = weight; this.docsEnum = td; this.norms = norms; this.weightValue = weight.getValue(); @@ -103,6 +102,11 @@ final class TermScorer extends Scorer { return doc; } + @Override + public float freq() { + return freq; + } + /** * Advances to the next document matching the query.
* The iterator over the matching documents is buffered using @@ -172,4 +176,5 @@ final class TermScorer extends Scorer { /** Returns a string representation of this TermScorer. */ @Override public String toString() { return "scorer(" + weight + ")"; } + } diff --git a/lucene/src/java/org/apache/lucene/search/function/CustomScoreQuery.java b/lucene/src/java/org/apache/lucene/search/function/CustomScoreQuery.java index 0cad1e05818..683334ce98f 100755 --- a/lucene/src/java/org/apache/lucene/search/function/CustomScoreQuery.java +++ b/lucene/src/java/org/apache/lucene/search/function/CustomScoreQuery.java @@ -300,7 +300,7 @@ public class CustomScoreQuery extends Query { // constructor private CustomScorer(Similarity similarity, IndexReader reader, CustomWeight w, Scorer subQueryScorer, Scorer[] valSrcScorers) throws IOException { - super(similarity); + super(similarity,w); this.qWeight = w.getValue(); this.subQueryScorer = subQueryScorer; this.valSrcScorers = valSrcScorers; diff --git a/lucene/src/java/org/apache/lucene/search/function/ValueSourceQuery.java b/lucene/src/java/org/apache/lucene/search/function/ValueSourceQuery.java index 404d72f55fa..b76ecf49cc9 100644 --- a/lucene/src/java/org/apache/lucene/search/function/ValueSourceQuery.java +++ b/lucene/src/java/org/apache/lucene/search/function/ValueSourceQuery.java @@ -134,7 +134,7 @@ public class ValueSourceQuery extends Query { // constructor private ValueSourceScorer(Similarity similarity, IndexReader reader, ValueSourceWeight w) throws IOException { - super(similarity); + super(similarity,w); qWeight = w.getValue(); // this is when/where the values are first created. vals = valSrc.getValues(reader); diff --git a/lucene/src/java/org/apache/lucene/search/spans/SpanScorer.java b/lucene/src/java/org/apache/lucene/search/spans/SpanScorer.java index e44fcbc836c..1d2d9f50bca 100644 --- a/lucene/src/java/org/apache/lucene/search/spans/SpanScorer.java +++ b/lucene/src/java/org/apache/lucene/search/spans/SpanScorer.java @@ -29,7 +29,6 @@ import org.apache.lucene.search.Similarity; */ public class SpanScorer extends Scorer { protected Spans spans; - protected Weight weight; protected byte[] norms; protected float value; @@ -40,10 +39,9 @@ public class SpanScorer extends Scorer { protected SpanScorer(Spans spans, Weight weight, Similarity similarity, byte[] norms) throws IOException { - super(similarity); + super(similarity, weight); this.spans = spans; this.norms = norms; - this.weight = weight; this.value = weight.getValue(); if (this.spans.next()) { doc = -1; @@ -97,6 +95,11 @@ public class SpanScorer extends Scorer { float raw = getSimilarity().tf(freq) * value; // raw score return norms == null? raw : raw * getSimilarity().decodeNormValue(norms[doc]); // normalize } + + @Override + public float freq() throws IOException { + return freq; + } /** This method is no longer an official member of {@link Scorer}, * but it is needed by SpanWeight to build an explanation. */ diff --git a/lucene/src/test/org/apache/lucene/search/TestBooleanScorer.java b/lucene/src/test/org/apache/lucene/search/TestBooleanScorer.java index 394a8bd5a9f..aea311d28f2 100644 --- a/lucene/src/test/org/apache/lucene/search/TestBooleanScorer.java +++ b/lucene/src/test/org/apache/lucene/search/TestBooleanScorer.java @@ -90,7 +90,7 @@ public class TestBooleanScorer extends LuceneTestCase } }}; - BooleanScorer bs = new BooleanScorer(sim, 1, Arrays.asList(scorers), null, scorers.length); + BooleanScorer bs = new BooleanScorer(null, sim, 1, Arrays.asList(scorers), null, scorers.length); assertEquals("should have received 3000", 3000, bs.nextDoc()); assertEquals("should have received NO_MORE_DOCS", DocIdSetIterator.NO_MORE_DOCS, bs.nextDoc()); diff --git a/lucene/src/test/org/apache/lucene/search/TestSubScorerFreqs.java b/lucene/src/test/org/apache/lucene/search/TestSubScorerFreqs.java new file mode 100644 index 00000000000..8341d21937b --- /dev/null +++ b/lucene/src/test/org/apache/lucene/search/TestSubScorerFreqs.java @@ -0,0 +1,226 @@ +package org.apache.lucene.search; + +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.document.*; +import org.apache.lucene.index.*; +import org.apache.lucene.util.*; +import org.apache.lucene.search.BooleanClause.Occur; +import org.apache.lucene.search.Scorer.ScorerVisitor; +import org.apache.lucene.store.*; + +import java.util.*; +import java.io.*; + +import org.junit.Test; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import static org.junit.Assert.*; + +public class TestSubScorerFreqs extends LuceneTestCaseJ4 { + + private static Directory dir; + private static IndexSearcher s; + + @BeforeClass + public static void makeIndex() throws Exception { + dir = new RAMDirectory(); + RandomIndexWriter w = new RandomIndexWriter( + newStaticRandom(TestSubScorerFreqs.class), dir); + // make sure we have more than one segment occationally + for (int i = 0; i < 31 * RANDOM_MULTIPLIER; i++) { + Document doc = new Document(); + doc.add(new Field("f", "a b c d b c d c d d", Field.Store.NO, + Field.Index.ANALYZED)); + w.addDocument(doc); + + doc = new Document(); + doc.add(new Field("f", "a b c d", Field.Store.NO, Field.Index.ANALYZED)); + w.addDocument(doc); + } + + s = new IndexSearcher(w.getReader()); + w.close(); + } + + @AfterClass + public static void finish() throws Exception { + s.getIndexReader().close(); + s.close(); + dir.close(); + } + + private static class CountingCollector extends Collector { + private final Collector other; + private int docBase; + + public final Map> docCounts = new HashMap>(); + + private final Map subScorers = new HashMap(); + private final ScorerVisitor visitor = new MockScorerVisitor(); + private final EnumSet collect; + + private class MockScorerVisitor extends ScorerVisitor { + + @Override + public void visitOptional(Query parent, Query child, Scorer scorer) { + if (collect.contains(Occur.SHOULD)) + subScorers.put(child, scorer); + } + + @Override + public void visitProhibited(Query parent, Query child, Scorer scorer) { + if (collect.contains(Occur.MUST_NOT)) + subScorers.put(child, scorer); + } + + @Override + public void visitRequired(Query parent, Query child, Scorer scorer) { + if (collect.contains(Occur.MUST)) + subScorers.put(child, scorer); + } + + } + + public CountingCollector(Collector other) { + this(other, EnumSet.allOf(Occur.class)); + } + + public CountingCollector(Collector other, EnumSet collect) { + this.other = other; + this.collect = collect; + } + + @Override + public void setScorer(Scorer scorer) throws IOException { + other.setScorer(scorer); + scorer.visitScorers(visitor); + } + + @Override + public void collect(int doc) throws IOException { + final Map freqs = new HashMap(); + for (Map.Entry ent : subScorers.entrySet()) { + Scorer value = ent.getValue(); + int matchId = value.docID(); + freqs.put(ent.getKey(), matchId == doc ? value.freq() : 0.0f); + } + docCounts.put(doc + docBase, freqs); + other.collect(doc); + } + + @Override + public void setNextReader(IndexReader reader, int docBase) + throws IOException { + this.docBase = docBase; + other.setNextReader(reader, docBase); + } + + @Override + public boolean acceptsDocsOutOfOrder() { + return other.acceptsDocsOutOfOrder(); + } + } + + private static final float FLOAT_TOLERANCE = 0.00001F; + + @Test + public void testTermQuery() throws Exception { + TermQuery q = new TermQuery(new Term("f", "d")); + CountingCollector c = new CountingCollector(TopScoreDocCollector.create(10, + true)); + s.search(q, null, c); + final int maxDocs = s.maxDoc(); + assertEquals(maxDocs, c.docCounts.size()); + for (int i = 0; i < maxDocs; i++) { + Map doc0 = c.docCounts.get(i); + assertEquals(1, doc0.size()); + assertEquals(4.0F, doc0.get(q), FLOAT_TOLERANCE); + + Map doc1 = c.docCounts.get(++i); + assertEquals(1, doc1.size()); + assertEquals(1.0F, doc1.get(q), FLOAT_TOLERANCE); + } + } + + @SuppressWarnings("unchecked") + @Test + public void testBooleanQuery() throws Exception { + TermQuery aQuery = new TermQuery(new Term("f", "a")); + TermQuery dQuery = new TermQuery(new Term("f", "d")); + TermQuery cQuery = new TermQuery(new Term("f", "c")); + TermQuery yQuery = new TermQuery(new Term("f", "y")); + + BooleanQuery query = new BooleanQuery(); + BooleanQuery inner = new BooleanQuery(); + + inner.add(cQuery, Occur.SHOULD); + inner.add(yQuery, Occur.MUST_NOT); + query.add(inner, Occur.MUST); + query.add(aQuery, Occur.MUST); + query.add(dQuery, Occur.MUST); + EnumSet[] occurList = new EnumSet[] {EnumSet.of(Occur.MUST), EnumSet.of(Occur.MUST, Occur.SHOULD)}; + for (EnumSet occur : occurList) { + CountingCollector c = new CountingCollector(TopScoreDocCollector.create( + 10, true), occur); + s.search(query, null, c); + final int maxDocs = s.maxDoc(); + assertEquals(maxDocs, c.docCounts.size()); + boolean includeOptional = occur.contains(Occur.SHOULD); + for (int i = 0; i < maxDocs; i++) { + Map doc0 = c.docCounts.get(i); + assertEquals(includeOptional ? 5 : 4, doc0.size()); + assertEquals(1.0F, doc0.get(aQuery), FLOAT_TOLERANCE); + assertEquals(4.0F, doc0.get(dQuery), FLOAT_TOLERANCE); + if (includeOptional) + assertEquals(3.0F, doc0.get(cQuery), FLOAT_TOLERANCE); + + Map doc1 = c.docCounts.get(++i); + assertEquals(includeOptional ? 5 : 4, doc1.size()); + assertEquals(1.0F, doc1.get(aQuery), FLOAT_TOLERANCE); + assertEquals(1.0F, doc1.get(dQuery), FLOAT_TOLERANCE); + if (includeOptional) + assertEquals(1.0F, doc1.get(cQuery), FLOAT_TOLERANCE); + + } + } + } + + @Test + public void testPhraseQuery() throws Exception { + PhraseQuery q = new PhraseQuery(); + q.add(new Term("f", "b")); + q.add(new Term("f", "c")); + CountingCollector c = new CountingCollector(TopScoreDocCollector.create(10, + true)); + s.search(q, null, c); + final int maxDocs = s.maxDoc(); + assertEquals(maxDocs, c.docCounts.size()); + for (int i = 0; i < maxDocs; i++) { + Map doc0 = c.docCounts.get(i); + assertEquals(1, doc0.size()); + assertEquals(2.0F, doc0.get(q), FLOAT_TOLERANCE); + + Map doc1 = c.docCounts.get(++i); + assertEquals(1, doc1.size()); + assertEquals(1.0F, doc1.get(q), FLOAT_TOLERANCE); + } + + } +}