function_score: match only document with score above custom score threshold

function_score matched each document regardless of the computed score.
This commit adds a query parameter `min_score` (-Float.MAX_VALUE default).
Documents that have a score lower than this threshold will not be matched.

closes #6952
This commit is contained in:
Britta Weber 2014-09-20 00:23:55 +02:00
parent 93b52c925d
commit 59507cf793
7 changed files with 260 additions and 84 deletions

View File

@ -49,7 +49,8 @@ given filter:
],
"max_boost": number,
"score_mode": "(multiply|max|...)",
"boost_mode": "(multiply|replace|...)"
"boost_mode": "(multiply|replace|...)",
"min_score" : number
}
--------------------------------------------------
@ -74,7 +75,7 @@ If weight is given without any other function declaration, `weight` acts as a fu
The new score can be restricted to not exceed a certain limit by setting
the `max_boost` parameter. The default for `max_boost` is FLT_MAX.
Finally, the newly computed score is combined with the score of the
The newly computed score is combined with the score of the
query. The parameter `boost_mode` defines how:
[horizontal]
@ -85,6 +86,10 @@ query. The parameter `boost_mode` defines how:
`max`:: max of query score and function score
`min`:: min of query score and function score
coming[1.5.0]
By default, modifying the score does not change which documents match. To exclude
documents whose score is below a certain threshold, set the `min_score` parameter to that threshold.
==== Score functions

View File

@ -0,0 +1,140 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.lucene.search.function;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import java.io.IOException;
/**
 * Base class for scorers that wrap a sub query scorer and replace its score with a
 * custom score computed by {@link #innerScore()}.
 * <p>
 * Iteration and scoring are delegated to a {@link NextDoc} strategy chosen at construction
 * time: when no minimum score is given every document of the wrapped scorer is returned,
 * otherwise documents whose custom score is lower than the minimum score are skipped.
 */
abstract class CustomBoostFactorScorer extends Scorer {

    final float subQueryBoost;
    final Scorer scorer;
    final float maxBoost;
    final CombineFunction scoreCombiner;
    /** Minimum custom score a document needs in order to match; {@code null} disables the check. */
    final Float minScore;
    /** Iteration strategy, selected once depending on whether {@link #minScore} is set. */
    final NextDoc nextDoc;

    /**
     * @param w             the weight this scorer belongs to; its query boost is stored as {@link #subQueryBoost}
     * @param scorer        the sub query scorer that is iterated and whose position is reported by {@link #docID()}
     * @param maxBoost      stored for use by subclasses when combining scores
     * @param scoreCombiner stored for use by subclasses when combining scores
     * @param minScore      score threshold below which documents are skipped, or {@code null} for no threshold
     */
    CustomBoostFactorScorer(Weight w, Scorer scorer, float maxBoost, CombineFunction scoreCombiner, Float minScore)
            throws IOException {
        super(w);
        this.subQueryBoost = w.getQuery().getBoost();
        this.scorer = scorer;
        this.maxBoost = maxBoost;
        this.scoreCombiner = scoreCombiner;
        this.minScore = minScore;
        if (minScore == null) {
            this.nextDoc = new AnyNextDoc();
        } else {
            this.nextDoc = new MinScoreNextDoc();
        }
    }

    @Override
    public int docID() {
        return scorer.docID();
    }

    @Override
    public int advance(int target) throws IOException {
        return nextDoc.advance(target);
    }

    @Override
    public int nextDoc() throws IOException {
        return nextDoc.nextDoc();
    }

    /**
     * Computes the custom score of the document the wrapped {@link #scorer} is currently positioned on.
     */
    public abstract float innerScore() throws IOException;

    @Override
    public float score() throws IOException {
        return nextDoc.score();
    }

    @Override
    public int freq() throws IOException {
        return scorer.freq();
    }

    @Override
    public long cost() {
        return scorer.cost();
    }

    /**
     * Strategy for moving the wrapped scorer forward and reporting the score of the current document.
     */
    public interface NextDoc {
        public int advance(int target) throws IOException;

        public int nextDoc() throws IOException;

        public float score() throws IOException;
    }

    /**
     * Skips every document whose custom score is lower than {@link #minScore}. The score of the
     * document that was accepted last is cached so that {@link #score()} does not compute it again.
     */
    public class MinScoreNextDoc implements NextDoc {
        float currentScore = -Float.MAX_VALUE;

        @Override
        public int nextDoc() throws IOException {
            int doc;
            do {
                doc = scorer.nextDoc();
                if (doc == NO_MORE_DOCS) {
                    return doc;
                }
                currentScore = innerScore();
            } while (currentScore < minScore);
            return doc;
        }

        @Override
        public float score() throws IOException {
            return currentScore;
        }

        @Override
        public int advance(int target) throws IOException {
            int doc = scorer.advance(target);
            if (doc == NO_MORE_DOCS) {
                return doc;
            }
            currentScore = innerScore();
            if (currentScore < minScore) {
                // The document we landed on is below the threshold. Continue with the filtering
                // nextDoc() of this class rather than scorer.nextDoc(): the latter would return
                // the following document without checking its score and would leave currentScore
                // holding the score of the rejected document.
                return this.nextDoc();
            }
            return doc;
        }
    }

    /**
     * Returns every document of the wrapped scorer and computes the custom score on demand.
     */
    public class AnyNextDoc implements NextDoc {

        @Override
        public int nextDoc() throws IOException {
            return scorer.nextDoc();
        }

        @Override
        public float score() throws IOException {
            return innerScore();
        }

        @Override
        public int advance(int target) throws IOException {
            return scorer.advance(target);
        }
    }
}

View File

@ -78,15 +78,17 @@ public class FiltersFunctionScoreQuery extends Query {
final FilterFunction[] filterFunctions;
final ScoreMode scoreMode;
final float maxBoost;
private Float minScore;
protected CombineFunction combineFunction;
public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions, float maxBoost) {
public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions, float maxBoost, Float minScore) {
this.subQuery = subQuery;
this.scoreMode = scoreMode;
this.filterFunctions = filterFunctions;
this.maxBoost = maxBoost;
combineFunction = CombineFunction.MULT;
this.minScore = minScore;
}
public FiltersFunctionScoreQuery setCombineFunction(CombineFunction combineFunction) {
@ -163,7 +165,7 @@ public class FiltersFunctionScoreQuery extends Query {
filterFunction.function.setNextReader(context);
docSets[i] = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, acceptDocs));
}
return new CustomBoostFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, docSets, combineFunction);
return new FiltersFunctionFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, docSets, combineFunction, minScore);
}
@Override
@ -245,45 +247,21 @@ public class FiltersFunctionScoreQuery extends Query {
}
}
static class CustomBoostFactorScorer extends Scorer {
private final float subQueryBoost;
private final Scorer scorer;
static class FiltersFunctionFactorScorer extends CustomBoostFactorScorer {
private final FilterFunction[] filterFunctions;
private final ScoreMode scoreMode;
private final float maxBoost;
private final Bits[] docSets;
private final CombineFunction scoreCombiner;
private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreMode scoreMode, FilterFunction[] filterFunctions,
float maxBoost, Bits[] docSets, CombineFunction scoreCombiner) throws IOException {
super(w);
this.subQueryBoost = w.getQuery().getBoost();
this.scorer = scorer;
private FiltersFunctionFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreMode scoreMode, FilterFunction[] filterFunctions,
float maxBoost, Bits[] docSets, CombineFunction scoreCombiner, Float minScore) throws IOException {
super(w, scorer, maxBoost, scoreCombiner, minScore);
this.scoreMode = scoreMode;
this.filterFunctions = filterFunctions;
this.maxBoost = maxBoost;
this.docSets = docSets;
this.scoreCombiner = scoreCombiner;
}
@Override
public int docID() {
return scorer.docID();
}
@Override
public int advance(int target) throws IOException {
return scorer.advance(target);
}
@Override
public int nextDoc() throws IOException {
return scorer.nextDoc();
}
@Override
public float score() throws IOException {
public float innerScore() throws IOException {
int docId = scorer.docID();
double factor = 1.0f;
float subQueryScore = scorer.score();
@ -338,16 +316,6 @@ public class FiltersFunctionScoreQuery extends Query {
}
return scoreCombiner.combine(subQueryBoost, subQueryScore, factor, maxBoost);
}
@Override
public int freq() throws IOException {
return scorer.freq();
}
@Override
public long cost() {
return scorer.cost();
}
}
public String toString(String field) {

View File

@ -38,6 +38,14 @@ public class FunctionScoreQuery extends Query {
final ScoreFunction function;
float maxBoost = Float.MAX_VALUE;
CombineFunction combineFunction;
private Float minScore = null;
/**
 * Creates a function score query that carries an optional minimum score.
 *
 * @param subQuery the query whose matching documents are scored
 * @param function the score function; it also supplies the default combine function
 * @param minScore minimum score threshold handed to the scorer, may be {@code null}
 */
public FunctionScoreQuery(Query subQuery, ScoreFunction function, Float minScore) {
this.subQuery = subQuery;
this.function = function;
this.combineFunction = function.getDefaultScoreCombiner();
this.minScore = minScore;
}
public FunctionScoreQuery(Query subQuery, ScoreFunction function) {
this.subQuery = subQuery;
@ -121,7 +129,7 @@ public class FunctionScoreQuery extends Query {
return null;
}
function.setNextReader(context);
return new CustomBoostFactorScorer(this, subQueryScorer, function, maxBoost, combineFunction);
return new FunctionFactorScorer(this, subQueryScorer, function, maxBoost, combineFunction, minScore);
}
@Override
@ -136,55 +144,22 @@ public class FunctionScoreQuery extends Query {
}
}
static class CustomBoostFactorScorer extends Scorer {
static class FunctionFactorScorer extends CustomBoostFactorScorer {
private final float subQueryBoost;
private final Scorer scorer;
private final ScoreFunction function;
private final float maxBoost;
private final CombineFunction scoreCombiner;
private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreFunction function, float maxBoost, CombineFunction scoreCombiner)
private FunctionFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreFunction function, float maxBoost, CombineFunction scoreCombiner, Float minScore)
throws IOException {
super(w);
this.subQueryBoost = w.getQuery().getBoost();
this.scorer = scorer;
super(w, scorer, maxBoost, scoreCombiner, minScore);
this.function = function;
this.maxBoost = maxBoost;
this.scoreCombiner = scoreCombiner;
}
@Override
public int docID() {
return scorer.docID();
}
@Override
public int advance(int target) throws IOException {
return scorer.advance(target);
}
@Override
public int nextDoc() throws IOException {
return scorer.nextDoc();
}
@Override
public float score() throws IOException {
public float innerScore() throws IOException {
float score = scorer.score();
return scoreCombiner.combine(subQueryBoost, score,
function.score(scorer.docID(), score), maxBoost);
}
@Override
public int freq() throws IOException {
return scorer.freq();
}
@Override
public long cost() {
return scorer.cost();
}
}
public String toString(String field) {

View File

@ -21,6 +21,7 @@ package org.elasticsearch.index.query.functionscore;
import org.elasticsearch.ElasticsearchIllegalArgumentException;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.BaseQueryBuilder;
import org.elasticsearch.index.query.BoostableQueryBuilder;
@ -50,6 +51,7 @@ public class FunctionScoreQueryBuilder extends BaseQueryBuilder implements Boost
private ArrayList<FilterBuilder> filters = new ArrayList<>();
private ArrayList<ScoreFunctionBuilder> scoreFunctions = new ArrayList<>();
private Float minScore = null;
public FunctionScoreQueryBuilder(QueryBuilder queryBuilder) {
this.queryBuilder = queryBuilder;
@ -158,7 +160,15 @@ public class FunctionScoreQueryBuilder extends BaseQueryBuilder implements Boost
if (boost != null) {
builder.field("boost", boost);
}
if (minScore != null) {
builder.field("min_score", minScore);
}
builder.endObject();
}
/**
 * Sets the minimum score for this query. When set, the value is written to the
 * request as the {@code min_score} field.
 *
 * @param minScore the score threshold
 * @return this builder, to allow chaining
 */
public FunctionScoreQueryBuilder setMinScore(float minScore) {
this.minScore = minScore;
return this;
}
}

View File

@ -89,6 +89,7 @@ public class FunctionScoreQueryParser implements QueryParser {
FiltersFunctionScoreQuery.ScoreMode scoreMode = FiltersFunctionScoreQuery.ScoreMode.Multiply;
ArrayList<FiltersFunctionScoreQuery.FilterFunction> filterFunctions = new ArrayList<>();
float maxBoost = Float.MAX_VALUE;
Float minScore = null;
String currentFieldName = null;
XContentParser.Token token;
@ -113,6 +114,8 @@ public class FunctionScoreQueryParser implements QueryParser {
maxBoost = parser.floatValue();
} else if ("boost".equals(currentFieldName)) {
boost = parser.floatValue();
} else if ("min_score".equals(currentFieldName) || "minScore".equals(currentFieldName)) {
minScore = parser.floatValue();
} else if ("functions".equals(currentFieldName)) {
if (singleFunctionFound) {
String errorString = "Found \"" + singleFunctionName + "\" already, now encountering \"functions\": [...].";
@ -154,7 +157,7 @@ public class FunctionScoreQueryParser implements QueryParser {
// handle cases where only one score function and no filter was
// provided. In this case we create a FunctionScoreQuery.
if (filterFunctions.size() == 1 && (filterFunctions.get(0).filter == null || filterFunctions.get(0).filter instanceof MatchAllDocsFilter)) {
FunctionScoreQuery theQuery = new FunctionScoreQuery(query, filterFunctions.get(0).function);
FunctionScoreQuery theQuery = new FunctionScoreQuery(query, filterFunctions.get(0).function, minScore);
if (combineFunction != null) {
theQuery.setCombineFunction(combineFunction);
}
@ -164,7 +167,7 @@ public class FunctionScoreQueryParser implements QueryParser {
// in all other cases we create a FiltersFunctionScoreQuery.
} else {
FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, scoreMode,
filterFunctions.toArray(new FiltersFunctionScoreQuery.FilterFunction[filterFunctions.size()]), maxBoost);
filterFunctions.toArray(new FiltersFunctionScoreQuery.FilterFunction[filterFunctions.size()]), maxBoost, minScore);
if (combineFunction != null) {
functionScoreQuery.setCombineFunction(combineFunction);
}

View File

@ -20,6 +20,7 @@
package org.elasticsearch.search.functionscore;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.common.geo.GeoPoint;
@ -33,6 +34,8 @@ import org.elasticsearch.test.ElasticsearchIntegrationTest;
import org.junit.Test;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import static org.elasticsearch.client.Requests.searchRequest;
@ -46,6 +49,7 @@ import static org.elasticsearch.search.builder.SearchSourceBuilder.searchSource;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.*;
import static org.hamcrest.Matchers.is;
public class FunctionScoreTests extends ElasticsearchIntegrationTest {
@ -204,6 +208,7 @@ public class FunctionScoreTests extends ElasticsearchIntegrationTest {
.add(scriptFunction("_index['" + TEXT_FIELD + "']['value'].tf()").setWeight(2))
))).actionGet();
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).getScore(), is(1.0f));
assertThat(responseWithWeights.getHits().getAt(0).getScore(), is(8.0f));
}
@ -433,5 +438,75 @@ public class FunctionScoreTests extends ElasticsearchIntegrationTest {
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getKeyAsNumber().floatValue(), is(1f));
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1l));
}
/**
 * Indexes a single document and checks that it is only returned when the constant script
 * score is not lower than a random {@code min_score}. The check is run once with a single
 * function and once with two functions averaged together.
 */
@Test
public void testMinScoreFunctionScoreBasic() throws IOException {
    index(INDEX, TYPE, jsonBuilder().startObject().field("num", 2).endObject());
    refresh();
    float score = randomFloat();
    float minScore = randomFloat();
    // The only indexed document matches unless its score is below the threshold.
    long expectedHits = score < minScore ? 0L : 1L;

    // A single function without a filter.
    SearchResponse searchResponse = client().search(
            searchRequest().source(searchSource().query(functionScoreQuery().add(scriptFunction(Float.toString(score))).setMinScore(minScore)))
    ).actionGet();
    assertThat(searchResponse.getHits().getTotalHits(), is(expectedHits));

    // Two identical functions combined with "avg" yield the same score.
    searchResponse = client().search(
            searchRequest().source(searchSource().query(functionScoreQuery()
                    .add(scriptFunction(Float.toString(score)))
                    .add(scriptFunction(Float.toString(score)))
                    .scoreMode("avg").setMinScore(minScore)))
    ).actionGet();
    assertThat(searchResponse.getHits().getTotalHits(), is(expectedHits));
}
/**
 * Indexes a random number of documents whose {@code num} field is {@code id + scoreOffset},
 * scores them by that field and checks that a random {@code min_score} leaves exactly the
 * expected documents, both for a single function and for two averaged functions.
 */
@Test
public void testMinScoreFunctionScoreManyDocsAndRandomMinScore() throws IOException, ExecutionException, InterruptedException {
    List<IndexRequestBuilder> indexRequests = new ArrayList<>();
    int numDocs = randomIntBetween(1, 100);
    int scoreOffset = randomIntBetween(-2 * numDocs, 2 * numDocs);
    int minScore = randomIntBetween(-2 * numDocs, 2 * numDocs);
    for (int docId = 0; docId < numDocs; docId++) {
        indexRequests.add(client().prepareIndex(INDEX, TYPE, Integer.toString(docId)).setSource("num", docId + scoreOffset));
    }
    indexRandom(true, indexRequests);

    String script = "return (doc['num'].value)";
    // Expected number of hits, clamped to the range [0, numDocs].
    int numMatchingDocs = Math.max(0, Math.min(numDocs, numDocs + scoreOffset - minScore));

    SearchResponse singleFunctionResponse = client().search(
            searchRequest().source(searchSource().query(functionScoreQuery()
                    .add(scriptFunction(script))
                    .setMinScore(minScore)).size(numDocs))).actionGet();
    assertMinScoreSearchResponses(numDocs, singleFunctionResponse, numMatchingDocs);

    SearchResponse averagedFunctionsResponse = client().search(
            searchRequest().source(searchSource().query(functionScoreQuery()
                    .add(scriptFunction(script))
                    .add(scriptFunction(script))
                    .scoreMode("avg").setMinScore(minScore)).size(numDocs))).actionGet();
    assertMinScoreSearchResponses(numDocs, averagedFunctionsResponse, numMatchingDocs);
}
/**
 * Asserts that the response contains exactly {@code numMatchingDocs} hits and that they are
 * the documents with the highest ids, in descending id order starting at {@code numDocs - 1}.
 *
 * @param numDocs         number of indexed documents; ids run from 0 to {@code numDocs - 1}
 * @param searchResponse  the response to check; it must contain all matching hits
 * @param numMatchingDocs expected total number of hits
 */
protected void assertMinScoreSearchResponses(int numDocs, SearchResponse searchResponse, int numMatchingDocs) {
    assertSearchResponse(searchResponse);
    assertThat((int) searchResponse.getHits().totalHits(), is(numMatchingDocs));
    // Check every returned hit. The previous loop condition stopped one hit early, so the
    // lowest-scoring match (the one closest to the threshold) was never verified.
    for (int pos = 0; pos < numMatchingDocs; pos++) {
        String expectedId = Integer.toString(numDocs - 1 - pos);
        assertThat(searchResponse.getHits().getAt(pos).getId(), equalTo(expectedId));
    }
}
}