Merge pull request #14202 from jpountz/enhancement/min_score

Improve `min_score` implementation.
This commit is contained in:
Adrien Grand 2015-10-26 14:10:36 +01:00
commit 1804e7d9e8
16 changed files with 545 additions and 221 deletions

View File

@ -1,68 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.lucene;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.*;
import java.io.IOException;
/**
*
*/
public class MinimumScoreCollector extends SimpleCollector {
private final Collector collector;
private final float minimumScore;
private Scorer scorer;
private LeafCollector leafCollector;
public MinimumScoreCollector(Collector collector, float minimumScore) {
this.collector = collector;
this.minimumScore = minimumScore;
}
@Override
public void setScorer(Scorer scorer) throws IOException {
if (!(scorer instanceof ScoreCachingWrappingScorer)) {
scorer = new ScoreCachingWrappingScorer(scorer);
}
this.scorer = scorer;
leafCollector.setScorer(scorer);
}
@Override
public void collect(int doc) throws IOException {
if (scorer.score() >= minimumScore) {
leafCollector.collect(doc);
}
}
@Override
public void doSetNextReader(LeafReaderContext context) throws IOException {
leafCollector = collector.getLeafCollector(context);
}
@Override
public boolean needsScores() {
return true;
}
}

View File

@ -0,0 +1,285 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.lucene.search;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreCachingWrappingScorer;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import java.io.IOException;
import java.util.Objects;
import java.util.Set;
/**
* A {@link Query} wrapper that only emits as hits documents whose score is
* above a given threshold. This query only really makes sense for queries
* whose score is computed manually, like eg. function score queries.
*/
public final class MinScoreQuery extends Query {
private final Query query;
private final float minScore;
private final IndexSearcher searcher;
/** Sole constructor. */
public MinScoreQuery(Query query, float minScore) {
this(query, minScore, null);
}
MinScoreQuery(Query query, float minScore, IndexSearcher searcher) {
this.query = query;
this.minScore = minScore;
this.searcher = searcher;
}
/** Return the wrapped query. */
public Query getQuery() {
return query;
}
/** Return the minimum score. */
public float getMinScore() {
return minScore;
}
@Override
public String toString(String field) {
return getClass().getSimpleName() + "(" + query.toString(field) + ", minScore=" + minScore + ")";
}
@Override
public boolean equals(Object obj) {
if (super.equals(obj) == false) {
return false;
}
MinScoreQuery that = (MinScoreQuery) obj;
return minScore == that.minScore
&& searcher == that.searcher
&& query.equals(that.query);
}
@Override
public int hashCode() {
return 31 * super.hashCode() + Objects.hash(query, minScore, searcher);
}
@Override
public Query rewrite(IndexReader reader) throws IOException {
if (getBoost() != 1f) {
return super.rewrite(reader);
}
Query rewritten = query.rewrite(reader);
if (rewritten != query) {
return new MinScoreQuery(rewritten, minScore);
}
return super.rewrite(reader);
}
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException {
final Weight weight = searcher.createWeight(query, true);
// We specialize the query for the provided index searcher because it
// can't really be cached as the documents that match depend on the
// Similarity implementation and the top-level reader.
final Query key = new MinScoreQuery(query, minScore, searcher);
return new Weight(key) {
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
Scorer scorer = weight.scorer(context);
if (scorer == null) {
return null;
}
return new MinScoreScorer(this, scorer, minScore);
}
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
BulkScorer bulkScorer = weight.bulkScorer(context);
if (bulkScorer == null) {
return null;
}
return new MinScoreBulkScorer(bulkScorer, minScore);
}
@Override
public void normalize(float norm, float boost) {
weight.normalize(norm, boost);
}
@Override
public float getValueForNormalization() throws IOException {
return weight.getValueForNormalization();
}
@Override
public void extractTerms(Set<Term> terms) {
weight.extractTerms(terms);
}
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
Explanation expl = weight.explain(context, doc);
if (expl.isMatch() == false || expl.getValue() >= minScore) {
return expl;
} else {
return Explanation.noMatch("Min score is less than the configured min score=" + minScore, expl);
}
}
};
}
private static class MinScoreScorer extends Scorer {
private final Scorer scorer;
private final float minScore;
private float score;
protected MinScoreScorer(Weight weight, Scorer scorer, float minScore) {
super(weight);
this.scorer = scorer;
this.minScore = minScore;
}
@Override
public float score() throws IOException {
return score;
}
@Override
public int freq() throws IOException {
return scorer.freq();
}
@Override
public int docID() {
return scorer.docID();
}
@Override
public int nextDoc() throws IOException {
return doNext(scorer.nextDoc());
}
@Override
public int advance(int target) throws IOException {
return doNext(scorer.advance(target));
}
private int doNext(int doc) throws IOException {
for (; doc != NO_MORE_DOCS; doc = scorer.nextDoc()) {
final float score = scorer.score();
if (score >= minScore) {
this.score = score;
return doc;
}
}
return NO_MORE_DOCS;
}
@Override
public TwoPhaseIterator asTwoPhaseIterator() {
final TwoPhaseIterator twoPhase = scorer.asTwoPhaseIterator();
final DocIdSetIterator approximation = twoPhase == null
? scorer
: twoPhase.approximation();
return new TwoPhaseIterator(approximation) {
@Override
public boolean matches() throws IOException {
if (twoPhase != null && twoPhase.matches() == false) {
return false;
}
score = scorer.score();
return score >= minScore;
}
};
}
@Override
public long cost() {
return scorer.cost();
}
}
private static class MinScoreBulkScorer extends BulkScorer {
private final BulkScorer bulkScorer;
private final float minScore;
public MinScoreBulkScorer(BulkScorer bulkScorer, float minScore) {
this.bulkScorer = bulkScorer;
this.minScore = minScore;
}
@Override
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
return bulkScorer.score(new MinScoreLeafCollector(collector, minScore), acceptDocs, min, max);
}
@Override
public long cost() {
return bulkScorer.cost();
}
}
private static class MinScoreLeafCollector implements LeafCollector {
private final LeafCollector collector;
private final float minScore;
private Scorer scorer;
public MinScoreLeafCollector(LeafCollector collector, float minScore) {
this.collector = collector;
this.minScore = minScore;
}
@Override
public void setScorer(Scorer scorer) throws IOException {
// we will need scores at least once, maybe more due to the wrapped
// collector, so we wrap with a ScoreCachingWrappingScorer
if (scorer instanceof ScoreCachingWrappingScorer == false) {
scorer = new ScoreCachingWrappingScorer(scorer);
}
this.scorer = scorer;
collector.setScorer(scorer);
}
@Override
public void collect(int doc) throws IOException {
if (scorer.score() >= minScore) {
collector.collect(doc);
}
}
}
}

View File

@ -30,21 +30,12 @@ abstract class CustomBoostFactorScorer extends Scorer {
final float maxBoost;
final CombineFunction scoreCombiner;
Float minScore;
NextDoc nextDoc;
CustomBoostFactorScorer(Weight w, Scorer scorer, float maxBoost, CombineFunction scoreCombiner, Float minScore)
CustomBoostFactorScorer(Weight w, Scorer scorer, float maxBoost, CombineFunction scoreCombiner)
throws IOException {
super(w);
if (minScore == null) {
nextDoc = new AnyNextDoc();
} else {
nextDoc = new MinScoreNextDoc();
}
this.scorer = scorer;
this.maxBoost = maxBoost;
this.scoreCombiner = scoreCombiner;
this.minScore = minScore;
}
@Override
@ -54,20 +45,16 @@ abstract class CustomBoostFactorScorer extends Scorer {
@Override
public int advance(int target) throws IOException {
return nextDoc.advance(target);
return scorer.advance(target);
}
@Override
public int nextDoc() throws IOException {
return nextDoc.nextDoc();
return scorer.nextDoc();
}
public abstract float innerScore() throws IOException;
@Override
public float score() throws IOException {
return nextDoc.score();
}
public abstract float score() throws IOException;
@Override
public int freq() throws IOException {
@ -79,64 +66,4 @@ abstract class CustomBoostFactorScorer extends Scorer {
return scorer.cost();
}
public interface NextDoc {
public int advance(int target) throws IOException;
public int nextDoc() throws IOException;
public float score() throws IOException;
}
public class MinScoreNextDoc implements NextDoc {
float currentScore = Float.MAX_VALUE * -1.0f;
@Override
public int nextDoc() throws IOException {
int doc;
do {
doc = scorer.nextDoc();
if (doc == NO_MORE_DOCS) {
return doc;
}
currentScore = innerScore();
} while (currentScore < minScore);
return doc;
}
@Override
public float score() throws IOException {
return currentScore;
}
@Override
public int advance(int target) throws IOException {
int doc = scorer.advance(target);
if (doc == NO_MORE_DOCS) {
return doc;
}
currentScore = innerScore();
if (currentScore < minScore) {
return scorer.nextDoc();
}
return doc;
}
}
public class AnyNextDoc implements NextDoc {
@Override
public int nextDoc() throws IOException {
return scorer.nextDoc();
}
@Override
public float score() throws IOException {
return innerScore();
}
@Override
public int advance(int target) throws IOException {
return scorer.advance(target);
}
}
}

View File

@ -100,17 +100,15 @@ public class FiltersFunctionScoreQuery extends Query {
final FilterFunction[] filterFunctions;
final ScoreMode scoreMode;
final float maxBoost;
private final Float minScore;
final protected CombineFunction combineFunction;
public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions, float maxBoost, Float minScore, CombineFunction combineFunction) {
public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions, float maxBoost, CombineFunction combineFunction) {
this.subQuery = subQuery;
this.scoreMode = scoreMode;
this.filterFunctions = filterFunctions;
this.maxBoost = maxBoost;
this.combineFunction = combineFunction;
this.minScore = minScore;
}
public Query getSubQuery() {
@ -195,7 +193,7 @@ public class FiltersFunctionScoreQuery extends Query {
Scorer filterScorer = filterWeights[i].scorer(context);
docSets[i] = Lucene.asSequentialAccessBits(context.reader().maxDoc(), filterScorer);
}
return new FiltersFunctionFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, functions, docSets, combineFunction, minScore, needsScores);
return new FiltersFunctionFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, functions, docSets, combineFunction, needsScores);
}
@Override
@ -244,8 +242,8 @@ public class FiltersFunctionScoreQuery extends Query {
private final boolean needsScores;
private FiltersFunctionFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreMode scoreMode, FilterFunction[] filterFunctions,
float maxBoost, LeafScoreFunction[] functions, Bits[] docSets, CombineFunction scoreCombiner, Float minScore, boolean needsScores) throws IOException {
super(w, scorer, maxBoost, scoreCombiner, minScore);
float maxBoost, LeafScoreFunction[] functions, Bits[] docSets, CombineFunction scoreCombiner, boolean needsScores) throws IOException {
super(w, scorer, maxBoost, scoreCombiner);
this.scoreMode = scoreMode;
this.filterFunctions = filterFunctions;
this.functions = functions;
@ -254,7 +252,7 @@ public class FiltersFunctionScoreQuery extends Query {
}
@Override
public float innerScore() throws IOException {
public float score() throws IOException {
int docId = scorer.docID();
// Even if the weight is created with needsScores=false, it might
// be costly to call score(), so we explicitly check if scores
@ -351,12 +349,12 @@ public class FiltersFunctionScoreQuery extends Query {
}
FiltersFunctionScoreQuery other = (FiltersFunctionScoreQuery) o;
return Objects.equals(this.subQuery, other.subQuery) && this.maxBoost == other.maxBoost &&
Objects.equals(this.combineFunction, other.combineFunction) && Objects.equals(this.minScore, other.minScore) &&
Objects.equals(this.combineFunction, other.combineFunction) &&
Arrays.equals(this.filterFunctions, other.filterFunctions);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), subQuery, maxBoost, combineFunction, minScore, filterFunctions);
return Objects.hash(super.hashCode(), subQuery, maxBoost, combineFunction, filterFunctions);
}
}

View File

@ -40,13 +40,11 @@ public class FunctionScoreQuery extends Query {
final ScoreFunction function;
final float maxBoost;
final CombineFunction combineFunction;
private Float minScore;
public FunctionScoreQuery(Query subQuery, ScoreFunction function, Float minScore, CombineFunction combineFunction, float maxBoost) {
public FunctionScoreQuery(Query subQuery, ScoreFunction function, CombineFunction combineFunction, float maxBoost) {
this.subQuery = subQuery;
this.function = function;
this.combineFunction = combineFunction;
this.minScore = minScore;
this.maxBoost = maxBoost;
}
@ -133,7 +131,7 @@ public class FunctionScoreQuery extends Query {
if (function != null) {
leafFunction = function.getLeafScoreFunction(context);
}
return new FunctionFactorScorer(this, subQueryScorer, leafFunction, maxBoost, combineFunction, minScore, needsScores);
return new FunctionFactorScorer(this, subQueryScorer, leafFunction, maxBoost, combineFunction, needsScores);
}
@Override
@ -156,15 +154,15 @@ public class FunctionScoreQuery extends Query {
private final LeafScoreFunction function;
private final boolean needsScores;
private FunctionFactorScorer(CustomBoostFactorWeight w, Scorer scorer, LeafScoreFunction function, float maxBoost, CombineFunction scoreCombiner, Float minScore, boolean needsScores)
private FunctionFactorScorer(CustomBoostFactorWeight w, Scorer scorer, LeafScoreFunction function, float maxBoost, CombineFunction scoreCombiner, boolean needsScores)
throws IOException {
super(w, scorer, maxBoost, scoreCombiner, minScore);
super(w, scorer, maxBoost, scoreCombiner);
this.function = function;
this.needsScores = needsScores;
}
@Override
public float innerScore() throws IOException {
public float score() throws IOException {
// Even if the weight is created with needsScores=false, it might
// be costly to call score(), so we explicitly check if scores
// are needed
@ -197,11 +195,11 @@ public class FunctionScoreQuery extends Query {
FunctionScoreQuery other = (FunctionScoreQuery) o;
return Objects.equals(this.subQuery, other.subQuery) && Objects.equals(this.function, other.function)
&& Objects.equals(this.combineFunction, other.combineFunction)
&& Objects.equals(this.minScore, other.minScore) && this.maxBoost == other.maxBoost;
&& this.maxBoost == other.maxBoost;
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), subQuery.hashCode(), function, combineFunction, minScore, maxBoost);
return Objects.hash(super.hashCode(), subQuery.hashCode(), function, combineFunction, maxBoost);
}
}

View File

@ -24,6 +24,7 @@ import org.apache.lucene.search.Query;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.lucene.search.MinScoreQuery;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
@ -312,10 +313,17 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
combineFunction = DEFAULT_BOOST_MODE;
}
}
return new FunctionScoreQuery(query, function, minScore, combineFunction, maxBoost);
query = new FunctionScoreQuery(query, function, combineFunction, maxBoost);
} else {
// in all other cases we create a FiltersFunctionScoreQuery
query = new FiltersFunctionScoreQuery(query, scoreMode, filterFunctions, maxBoost, boostMode == null ? DEFAULT_BOOST_MODE : boostMode);
}
// in all other cases we create a FiltersFunctionScoreQuery
return new FiltersFunctionScoreQuery(query, scoreMode, filterFunctions, maxBoost, minScore, boostMode == null ? DEFAULT_BOOST_MODE : boostMode);
if (minScore != null) {
query = new MinScoreQuery(query, minScore);
}
return query;
}
/**

View File

@ -478,11 +478,6 @@ public class PercolateContext extends SearchContext {
throw new UnsupportedOperationException();
}
@Override
public Float minimumScore() {
return null;
}
@Override
public SearchContext sort(Sort sort) {
this.sort = sort;

View File

@ -27,6 +27,7 @@ import org.elasticsearch.cache.recycler.PageCacheRecycler;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.lucene.search.MinScoreQuery;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.WeightFactorFunction;
@ -215,6 +216,9 @@ public class DefaultSearchContext extends SearchContext {
parsedQuery(new ParsedQuery(filtered, parsedQuery()));
}
}
if (minimumScore != null) {
this.query = new MinScoreQuery(query, minimumScore);
}
try {
this.query = searcher().rewrite(this.query);
} catch (IOException e) {
@ -491,11 +495,6 @@ public class DefaultSearchContext extends SearchContext {
return this;
}
@Override
public Float minimumScore() {
return this.minimumScore;
}
@Override
public SearchContext sort(Sort sort) {
this.sort = sort;

View File

@ -317,11 +317,6 @@ public abstract class FilteredSearchContext extends SearchContext {
return in.minimumScore(minimumScore);
}
@Override
public Float minimumScore() {
return in.minimumScore();
}
@Override
public SearchContext sort(Sort sort) {
return in.sort(sort);

View File

@ -230,8 +230,6 @@ public abstract class SearchContext extends DelegatingHasContextAndHeaders imple
public abstract SearchContext minimumScore(float minimumScore);
public abstract Float minimumScore();
public abstract SearchContext sort(Sort sort);
public abstract Sort sort();

View File

@ -44,7 +44,6 @@ import org.apache.lucene.search.Weight;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.lucene.MinimumScoreCollector;
import org.elasticsearch.common.lucene.search.FilteredCollector;
import org.elasticsearch.search.SearchParseElement;
import org.elasticsearch.search.SearchPhase;
@ -277,11 +276,6 @@ public class QueryPhase implements SearchPhase {
allCollectors.addAll(searchContext.queryCollectors().values());
collector = MultiCollector.wrap(allCollectors);
// apply the minimum score after multi collector so we filter aggs as well
if (searchContext.minimumScore() != null) {
collector = new MinimumScoreCollector(collector, searchContext.minimumScore());
}
if (collector.getClass() == TotalHitCountCollector.class) {
// Optimize counts in simple cases to return in constant time
// instead of using a collector

View File

@ -0,0 +1,195 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.lucene.search;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.*;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.TestUtil;
import org.elasticsearch.test.ESTestCase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
public class MinScoreQueryTests extends ESTestCase {
private static String[] terms;
private static Directory dir;
private static IndexReader r;
private static IndexSearcher s;
@BeforeClass
public static void before() throws IOException {
dir = newDirectory();
terms = new String[TestUtil.nextInt(random(), 2, 15)];
for (int i = 0; i < terms.length; ++i) {
terms[i] = TestUtil.randomSimpleString(random());
}
final int numDocs = TestUtil.nextInt(random(), 1, 200);
Analyzer analyzer = new MockAnalyzer(random(), MockTokenizer.WHITESPACE, false);
RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer));
for (int i = 0; i < numDocs; ++i) {
StringBuilder value = new StringBuilder();
final int numTerms = random().nextInt(10);
for (int j = 0; j < numTerms; ++j) {
if (j > 0) {
value.append(' ');
}
// simulate zipf distribution
String term = terms[TestUtil.nextInt(random(), 0, TestUtil.nextInt(random(), 0, terms.length - 1))];
value.append(term);
}
Document doc = new Document();
doc.add(new TextField("field", value.toString(), Store.NO));
w.addDocument(doc);
}
r = w.getReader();
s = newSearcher(r);
w.close();
}
@AfterClass
public static void after() throws IOException {
IOUtils.close(r, dir);
terms = null;
r = null;
s = null;
dir = null;
}
private static Term randomTerm() {
return new Term("field", terms[random().nextInt(terms.length)]);
}
public void testEquals() throws IOException {
QueryUtils.checkEqual(new MinScoreQuery(new MatchAllDocsQuery(), 0.5f), new MinScoreQuery(new MatchAllDocsQuery(), 0.5f));
QueryUtils.checkUnequal(new MinScoreQuery(new MatchAllDocsQuery(), 0.5f), new MinScoreQuery(new MatchAllDocsQuery(), 0.3f));
QueryUtils.checkUnequal(new MinScoreQuery(new MatchAllDocsQuery(), 0.5f), new MinScoreQuery(new MatchNoDocsQuery(), 0.5f));
IndexSearcher s1 = new IndexSearcher(new MultiReader());
IndexSearcher s2 = new IndexSearcher(new MultiReader());
QueryUtils.checkEqual(new MinScoreQuery(new MatchAllDocsQuery(), 0.5f, s1), new MinScoreQuery(new MatchAllDocsQuery(), 0.5f, s1));
QueryUtils.checkUnequal(new MinScoreQuery(new MatchAllDocsQuery(), 0.5f, s1), new MinScoreQuery(new MatchAllDocsQuery(), 0.5f, s2));
}
/** pick a min score which is in the range of scores produced by the query */
private float randomMinScore(Query query) throws IOException {
TopDocs topDocs = s.search(query, 2);
float base = 0;
switch (topDocs.totalHits) {
case 0:
break;
case 1:
base = topDocs.scoreDocs[0].score;
break;
default:
base = (topDocs.scoreDocs[0].score + topDocs.scoreDocs[1].score) / 2;
break;
}
float delta = random().nextFloat() - 0.5f;
return base + delta;
}
private void assertMinScoreEquivalence(Query query, Query minScoreQuery, float minScore) throws IOException {
final TopDocs topDocs = s.search(query, s.getIndexReader().maxDoc());
final TopDocs minScoreTopDocs = s.search(minScoreQuery, s.getIndexReader().maxDoc());
int j = 0;
for (int i = 0; i < topDocs.totalHits; ++i) {
if (topDocs.scoreDocs[i].score >= minScore) {
assertEquals(topDocs.scoreDocs[i].doc, minScoreTopDocs.scoreDocs[j].doc);
assertEquals(topDocs.scoreDocs[i].score, minScoreTopDocs.scoreDocs[j].score, 1e-5f);
j++;
}
}
assertEquals(minScoreTopDocs.totalHits, j);
}
public void testBasics() throws Exception {
final int iters = 5;
for (int iter = 0; iter < iters; ++iter) {
Term term = randomTerm();
Query query = new TermQuery(term);
float minScore = randomMinScore(query);
Query minScoreQuery = new MinScoreQuery(query, minScore);
assertMinScoreEquivalence(query, minScoreQuery, minScore);
}
}
public void testFilteredApproxQuery() throws Exception {
// same as testBasics but with a query that exposes approximations
final int iters = 5;
for (int iter = 0; iter < iters; ++iter) {
Term term = randomTerm();
Query query = new TermQuery(term);
float minScore = randomMinScore(query);
Query minScoreQuery = new MinScoreQuery(new RandomApproximationQuery(query, random()), minScore);
assertMinScoreEquivalence(query, minScoreQuery, minScore);
}
}
public void testNestedInConjunction() throws Exception {
// To test scorers as well, not only bulk scorers
Term t1 = randomTerm();
Term t2 = randomTerm();
Query tq1 = new TermQuery(t1);
Query tq2 = new TermQuery(t2);
float minScore = randomMinScore(tq1);
BooleanQuery bq1 = new BooleanQuery.Builder()
.add(new MinScoreQuery(tq1, minScore), Occur.MUST)
.add(tq2, Occur.FILTER)
.build();
BooleanQuery bq2 = new BooleanQuery.Builder()
.add(tq1, Occur.MUST)
.add(tq2, Occur.FILTER)
.build();
assertMinScoreEquivalence(bq2, bq1, minScore);
}
public void testNestedInConjunctionWithApprox() throws Exception {
// same, but with approximations
Term t1 = randomTerm();
Term t2 = randomTerm();
Query tq1 = new TermQuery(t1);
Query tq2 = new TermQuery(t2);
float minScore = randomMinScore(tq1);
BooleanQuery bq1 = new BooleanQuery.Builder()
.add(new MinScoreQuery(new RandomApproximationQuery(tq1, random()), minScore), Occur.MUST)
.add(tq2, Occur.FILTER)
.build();
BooleanQuery bq2 = new BooleanQuery.Builder()
.add(tq1, Occur.MUST)
.add(tq2, Occur.FILTER)
.build();
assertMinScoreEquivalence(bq2, bq1, minScore);
}
}

View File

@ -26,11 +26,8 @@ import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.FieldValueFactorFunction;
import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.WeightFactorFunction;
import org.elasticsearch.common.lucene.search.MinScoreQuery;
import org.elasticsearch.common.lucene.search.function.*;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.AbstractQueryTestCase;
@ -182,6 +179,12 @@ public class FunctionScoreQueryBuilderTests extends AbstractQueryTestCase<Functi
@Override
protected void doAssertLuceneQuery(FunctionScoreQueryBuilder queryBuilder, Query query, QueryShardContext context) throws IOException {
if (queryBuilder.getMinScore() != null) {
assertThat(query, instanceOf(MinScoreQuery.class));
MinScoreQuery msq = (MinScoreQuery) query;
assertEquals(queryBuilder.getMinScore(), msq.getMinScore(), 0f);
query = msq.getQuery();
}
assertThat(query, either(instanceOf(FunctionScoreQuery.class)).or(instanceOf(FiltersFunctionScoreQuery.class)));
}

View File

@ -24,7 +24,6 @@ import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
@ -32,6 +31,7 @@ import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
@ -39,14 +39,8 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.FieldValueFactorFunction;
import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.common.lucene.search.function.LeafScoreFunction;
import org.elasticsearch.common.lucene.search.function.RandomScoreFunction;
import org.elasticsearch.common.lucene.search.function.ScoreFunction;
import org.elasticsearch.common.lucene.search.function.WeightFactorFunction;
import org.elasticsearch.common.lucene.search.MinScoreQuery;
import org.elasticsearch.common.lucene.search.function.*;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.fielddata.AtomicFieldData;
import org.elasticsearch.index.fielddata.AtomicNumericFieldData;
@ -318,7 +312,7 @@ public class FunctionScoreTests extends ESTestCase {
}
public Explanation getFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction scoreFunction) throws IOException {
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new TermQuery(TERM), scoreFunction, 0.0f, CombineFunction.AVG, 100);
Query functionScoreQuery = new MinScoreQuery(new FunctionScoreQuery(new TermQuery(TERM), scoreFunction, CombineFunction.AVG, 100), 0f);
Weight weight = searcher.createNormalizedWeight(functionScoreQuery, true);
Explanation explanation = weight.explain(searcher.getIndexReader().leaves().get(0), 0);
return explanation.getDetails()[1];
@ -382,22 +376,22 @@ public class FunctionScoreTests extends ESTestCase {
}
public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException {
FiltersFunctionScoreQuery filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions);
Query filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions);
return getExplanation(searcher, filtersFunctionScoreQuery).getDetails()[1];
}
protected Explanation getExplanation(IndexSearcher searcher, FiltersFunctionScoreQuery filtersFunctionScoreQuery) throws IOException {
protected Explanation getExplanation(IndexSearcher searcher, Query filtersFunctionScoreQuery) throws IOException {
Weight weight = searcher.createNormalizedWeight(filtersFunctionScoreQuery, true);
return weight.explain(searcher.getIndexReader().leaves().get(0), 0);
}
public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) {
public Query getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) {
FiltersFunctionScoreQuery.FilterFunction[] filterFunctions = new FiltersFunctionScoreQuery.FilterFunction[scoreFunctions.length];
for (int i = 0; i < scoreFunctions.length; i++) {
filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(
new TermQuery(TERM), scoreFunctions[i]);
}
return new FiltersFunctionScoreQuery(new TermQuery(TERM), scoreMode, filterFunctions, Float.MAX_VALUE, Float.MAX_VALUE * -1, combineFunction);
return new MinScoreQuery(new FiltersFunctionScoreQuery(new TermQuery(TERM), scoreMode, filterFunctions, Float.MAX_VALUE, combineFunction), Float.MAX_VALUE * -1);
}
public void checkFiltersFunctionScoreExplanation(Explanation randomExplanation, String functionExpl, int whichFunction) {
@ -471,7 +465,7 @@ public class FunctionScoreTests extends ESTestCase {
weightFunctionStubs[i] = new WeightFactorFunction(weights[i], scoreFunctionStubs[i]);
}
FiltersFunctionScoreQuery filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
Query filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MULTIPLY
, CombineFunction.REPLACE
, weightFunctionStubs
@ -554,10 +548,10 @@ public class FunctionScoreTests extends ESTestCase {
assertThat(explainedScore / scoreWithWeight, is(1f));
}
public void testWeightOnlyCreatesBoostFunction() throws IOException {
FunctionScoreQuery filtersFunctionScoreQueryWithWeights = new FunctionScoreQuery(new MatchAllDocsQuery(), new WeightFactorFunction(2), 0.0f, CombineFunction.MULTIPLY, 100);
public void checkWeightOnlyCreatesBoostFunction() throws IOException {
Query filtersFunctionScoreQueryWithWeights = new MinScoreQuery(new FunctionScoreQuery(new MatchAllDocsQuery(), new WeightFactorFunction(2), CombineFunction.MULTIPLY, 100), 0f);
TopDocs topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1);
float score = topDocsWithWeights.scoreDocs[0].score;
assertThat(score, equalTo(2.0f));
}
}
}

View File

@ -143,10 +143,6 @@ public class QueryPhaseTests extends ESTestCase {
}
public void testMinScoreDisablesCountOptimization() throws Exception {
TestSearchContext context = new TestSearchContext();
context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
context.setSize(0);
final AtomicBoolean collected = new AtomicBoolean();
IndexSearcher contextSearcher = new IndexSearcher(new MultiReader()) {
protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector) throws IOException {
@ -155,11 +151,19 @@ public class QueryPhaseTests extends ESTestCase {
}
};
TestSearchContext context = new TestSearchContext();
context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
context.setSize(0);
context.preProcess();
QueryPhase.execute(context, contextSearcher);
assertEquals(0, context.queryResult().topDocs().totalHits);
assertFalse(collected.get());
context = new TestSearchContext();
context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
context.setSize(0);
context.minimumScore(1);
context.preProcess();
QueryPhase.execute(context, contextSearcher);
assertEquals(0, context.queryResult().topDocs().totalHits);
assertTrue(collected.get());

View File

@ -31,6 +31,7 @@ import org.elasticsearch.common.HasContextAndHeaders;
import org.elasticsearch.common.HasHeaders;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.collect.ImmutableOpenMap;
import org.elasticsearch.common.lucene.search.MinScoreQuery;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.analysis.AnalysisService;
@ -129,6 +130,9 @@ public class TestSearchContext extends SearchContext {
@Override
public void preProcess() {
if (minScore != null) {
this.query = new MinScoreQuery(query, minScore);
}
}
@Override
@ -375,11 +379,6 @@ public class TestSearchContext extends SearchContext {
return this;
}
@Override
public Float minimumScore() {
return minScore;
}
@Override
public SearchContext sort(Sort sort) {
return null;