From 59507cf793c17ca2ea99c15230c126b88bebd49e Mon Sep 17 00:00:00 2001 From: Britta Weber Date: Sat, 20 Sep 2014 00:23:55 +0200 Subject: [PATCH] function_score: match only document with score above custom score threshold function_score matched each document regardless of the computed score. This commit adds a query parameter `min_score` (-Float.MAX_VALUE default). Documents that have a score lower than this threshold will not be matched. closes #6952 --- .../queries/function-score-query.asciidoc | 9 +- .../function/CustomBoostFactorScorer.java | 140 ++++++++++++++++++ .../function/FiltersFunctionScoreQuery.java | 50 ++----- .../search/function/FunctionScoreQuery.java | 53 ++----- .../FunctionScoreQueryBuilder.java | 10 ++ .../FunctionScoreQueryParser.java | 7 +- .../functionscore/FunctionScoreTests.java | 75 ++++++++++ 7 files changed, 260 insertions(+), 84 deletions(-) create mode 100644 src/main/java/org/elasticsearch/common/lucene/search/function/CustomBoostFactorScorer.java diff --git a/docs/reference/query-dsl/queries/function-score-query.asciidoc b/docs/reference/query-dsl/queries/function-score-query.asciidoc index eb995c4b795..e96c1660e5a 100644 --- a/docs/reference/query-dsl/queries/function-score-query.asciidoc +++ b/docs/reference/query-dsl/queries/function-score-query.asciidoc @@ -49,7 +49,8 @@ given filter: ], "max_boost": number, "score_mode": "(multiply|max|...)", - "boost_mode": "(multiply|replace|...)" + "boost_mode": "(multiply|replace|...)", + "min_score" : number } -------------------------------------------------- @@ -74,7 +75,7 @@ If weight is given without any other function declaration, `weight` acts as a fu The new score can be restricted to not exceed a certain limit by setting the `max_boost` parameter. The default for `max_boost` is FLT_MAX. -Finally, the newly computed score is combined with the score of the +The newly computed score is combined with the score of the query. 
The parameter `boost_mode` defines how: [horizontal] @@ -85,6 +86,10 @@ query. The parameter `boost_mode` defines how: `max`:: max of query score and function score `min`:: min of query score and function score +coming[1.5.0] + +By default, modifying the score does not change which documents match. To exclude +documents that do not meet a certain score threshold the `min_score` parameter can be set to the desired score threshold. ==== Score functions diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/CustomBoostFactorScorer.java b/src/main/java/org/elasticsearch/common/lucene/search/function/CustomBoostFactorScorer.java new file mode 100644 index 00000000000..bcc785aeebc --- /dev/null +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/CustomBoostFactorScorer.java @@ -0,0 +1,140 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.common.lucene.search.function; + +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +import java.io.IOException; + +abstract class CustomBoostFactorScorer extends Scorer { + + final float subQueryBoost; + final Scorer scorer; + final float maxBoost; + final CombineFunction scoreCombiner; + + Float minScore; + NextDoc nextDoc; + + CustomBoostFactorScorer(Weight w, Scorer scorer, float maxBoost, CombineFunction scoreCombiner, Float minScore) + throws IOException { + super(w); + if (minScore == null) { + nextDoc = new AnyNextDoc(); + } else { + nextDoc = new MinScoreNextDoc(); + } + this.subQueryBoost = w.getQuery().getBoost(); + this.scorer = scorer; + this.maxBoost = maxBoost; + this.scoreCombiner = scoreCombiner; + this.minScore = minScore; + } + + @Override + public int docID() { + return scorer.docID(); + } + + @Override + public int advance(int target) throws IOException { + return nextDoc.advance(target); + } + + @Override + public int nextDoc() throws IOException { + return nextDoc.nextDoc(); + } + + public abstract float innerScore() throws IOException; + + @Override + public float score() throws IOException { + return nextDoc.score(); + } + + @Override + public int freq() throws IOException { + return scorer.freq(); + } + + @Override + public long cost() { + return scorer.cost(); + } + + public interface NextDoc { + public int advance(int target) throws IOException; + + public int nextDoc() throws IOException; + + public float score() throws IOException; + } + + public class MinScoreNextDoc implements NextDoc { + float currentScore = Float.MAX_VALUE * -1.0f; + + public int nextDoc() throws IOException { + int doc; + do { + doc = scorer.nextDoc(); + if (doc == NO_MORE_DOCS) { + return doc; + } + currentScore = innerScore(); + } while (currentScore < minScore); + return doc; + } + + @Override + public float score() throws IOException { + return currentScore; + } + + public int advance(int 
target) throws IOException { + int doc = scorer.advance(target); + if (doc == NO_MORE_DOCS) { + return doc; + } + currentScore = innerScore(); + if (currentScore < minScore) { + return nextDoc(); // keep scanning with the min_score-aware nextDoc() so the returned doc meets the threshold and currentScore is refreshed + } + return doc; + } + } + + public class AnyNextDoc implements NextDoc { + + public int nextDoc() throws IOException { + return scorer.nextDoc(); + } + + @Override + public float score() throws IOException { + return innerScore(); + } + + public int advance(int target) throws IOException { + return scorer.advance(target); + } + } +} diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java b/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java index 8cedd928926..15ceca06434 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java @@ -78,15 +78,17 @@ public class FiltersFunctionScoreQuery extends Query { final FilterFunction[] filterFunctions; final ScoreMode scoreMode; final float maxBoost; + private Float minScore; protected CombineFunction combineFunction; - public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions, float maxBoost) { + public FiltersFunctionScoreQuery(Query subQuery, ScoreMode scoreMode, FilterFunction[] filterFunctions, float maxBoost, Float minScore) { this.subQuery = subQuery; this.scoreMode = scoreMode; this.filterFunctions = filterFunctions; this.maxBoost = maxBoost; combineFunction = CombineFunction.MULT; + this.minScore = minScore; } public FiltersFunctionScoreQuery setCombineFunction(CombineFunction combineFunction) { @@ -163,7 +165,7 @@ public class FiltersFunctionScoreQuery extends Query { filterFunction.function.setNextReader(context); docSets[i] = DocIdSets.toSafeBits(context.reader(), filterFunction.filter.getDocIdSet(context, acceptDocs)); } - return new 
CustomBoostFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, docSets, combineFunction); + return new FiltersFunctionFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, docSets, combineFunction, minScore); } @Override @@ -245,45 +247,21 @@ public class FiltersFunctionScoreQuery extends Query { } } - static class CustomBoostFactorScorer extends Scorer { - - private final float subQueryBoost; - private final Scorer scorer; + static class FiltersFunctionFactorScorer extends CustomBoostFactorScorer { private final FilterFunction[] filterFunctions; private final ScoreMode scoreMode; - private final float maxBoost; private final Bits[] docSets; - private final CombineFunction scoreCombiner; - private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreMode scoreMode, FilterFunction[] filterFunctions, - float maxBoost, Bits[] docSets, CombineFunction scoreCombiner) throws IOException { - super(w); - this.subQueryBoost = w.getQuery().getBoost(); - this.scorer = scorer; + private FiltersFunctionFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreMode scoreMode, FilterFunction[] filterFunctions, + float maxBoost, Bits[] docSets, CombineFunction scoreCombiner, Float minScore) throws IOException { + super(w, scorer, maxBoost, scoreCombiner, minScore); this.scoreMode = scoreMode; this.filterFunctions = filterFunctions; - this.maxBoost = maxBoost; this.docSets = docSets; - this.scoreCombiner = scoreCombiner; } @Override - public int docID() { - return scorer.docID(); - } - - @Override - public int advance(int target) throws IOException { - return scorer.advance(target); - } - - @Override - public int nextDoc() throws IOException { - return scorer.nextDoc(); - } - - @Override - public float score() throws IOException { + public float innerScore() throws IOException { int docId = scorer.docID(); double factor = 1.0f; float subQueryScore = scorer.score(); @@ -338,16 +316,6 @@ public class 
FiltersFunctionScoreQuery extends Query { } return scoreCombiner.combine(subQueryBoost, subQueryScore, factor, maxBoost); } - - @Override - public int freq() throws IOException { - return scorer.freq(); - } - - @Override - public long cost() { - return scorer.cost(); - } } public String toString(String field) { diff --git a/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java b/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java index 5f730fc7fc3..2a161580f9c 100644 --- a/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java +++ b/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java @@ -38,7 +38,15 @@ public class FunctionScoreQuery extends Query { final ScoreFunction function; float maxBoost = Float.MAX_VALUE; CombineFunction combineFunction; - + private Float minScore = null; + + public FunctionScoreQuery(Query subQuery, ScoreFunction function, Float minScore) { + this.subQuery = subQuery; + this.function = function; + this.combineFunction = function.getDefaultScoreCombiner(); + this.minScore = minScore; + } + public FunctionScoreQuery(Query subQuery, ScoreFunction function) { this.subQuery = subQuery; this.function = function; @@ -121,7 +129,7 @@ public class FunctionScoreQuery extends Query { return null; } function.setNextReader(context); - return new CustomBoostFactorScorer(this, subQueryScorer, function, maxBoost, combineFunction); + return new FunctionFactorScorer(this, subQueryScorer, function, maxBoost, combineFunction, minScore); } @Override @@ -136,55 +144,22 @@ public class FunctionScoreQuery extends Query { } } - static class CustomBoostFactorScorer extends Scorer { + static class FunctionFactorScorer extends CustomBoostFactorScorer { - private final float subQueryBoost; - private final Scorer scorer; private final ScoreFunction function; - private final float maxBoost; - private final CombineFunction scoreCombiner; - 
private CustomBoostFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreFunction function, float maxBoost, CombineFunction scoreCombiner) + private FunctionFactorScorer(CustomBoostFactorWeight w, Scorer scorer, ScoreFunction function, float maxBoost, CombineFunction scoreCombiner, Float minScore) throws IOException { - super(w); - this.subQueryBoost = w.getQuery().getBoost(); - this.scorer = scorer; + super(w, scorer, maxBoost, scoreCombiner, minScore); this.function = function; - this.maxBoost = maxBoost; - this.scoreCombiner = scoreCombiner; } @Override - public int docID() { - return scorer.docID(); - } - - @Override - public int advance(int target) throws IOException { - return scorer.advance(target); - } - - @Override - public int nextDoc() throws IOException { - return scorer.nextDoc(); - } - - @Override - public float score() throws IOException { + public float innerScore() throws IOException { float score = scorer.score(); return scoreCombiner.combine(subQueryBoost, score, function.score(scorer.docID(), score), maxBoost); } - - @Override - public int freq() throws IOException { - return scorer.freq(); - } - - @Override - public long cost() { - return scorer.cost(); - } } public String toString(String field) { diff --git a/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilder.java b/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilder.java index 0ea796fe40f..7d7e3d3c34d 100644 --- a/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilder.java +++ b/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryBuilder.java @@ -21,6 +21,7 @@ package org.elasticsearch.index.query.functionscore; import org.elasticsearch.ElasticsearchIllegalArgumentException; import org.elasticsearch.common.lucene.search.function.CombineFunction; +import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery; import 
org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.query.BaseQueryBuilder; import org.elasticsearch.index.query.BoostableQueryBuilder; @@ -50,6 +51,7 @@ public class FunctionScoreQueryBuilder extends BaseQueryBuilder implements Boost private ArrayList filters = new ArrayList<>(); private ArrayList scoreFunctions = new ArrayList<>(); + private Float minScore = null; public FunctionScoreQueryBuilder(QueryBuilder queryBuilder) { this.queryBuilder = queryBuilder; @@ -158,7 +160,15 @@ public class FunctionScoreQueryBuilder extends BaseQueryBuilder implements Boost if (boost != null) { builder.field("boost", boost); } + if (minScore != null) { + builder.field("min_score", minScore); + } builder.endObject(); } + + public FunctionScoreQueryBuilder setMinScore(float minScore) { + this.minScore = minScore; + return this; + } } \ No newline at end of file diff --git a/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryParser.java b/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryParser.java index 64508e21fc3..eebf0b69979 100644 --- a/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryParser.java +++ b/src/main/java/org/elasticsearch/index/query/functionscore/FunctionScoreQueryParser.java @@ -89,6 +89,7 @@ public class FunctionScoreQueryParser implements QueryParser { FiltersFunctionScoreQuery.ScoreMode scoreMode = FiltersFunctionScoreQuery.ScoreMode.Multiply; ArrayList filterFunctions = new ArrayList<>(); float maxBoost = Float.MAX_VALUE; + Float minScore = null; String currentFieldName = null; XContentParser.Token token; @@ -113,6 +114,8 @@ public class FunctionScoreQueryParser implements QueryParser { maxBoost = parser.floatValue(); } else if ("boost".equals(currentFieldName)) { boost = parser.floatValue(); + } else if ("min_score".equals(currentFieldName) || "minScore".equals(currentFieldName)) { + minScore = parser.floatValue(); } else if 
("functions".equals(currentFieldName)) { if (singleFunctionFound) { String errorString = "Found \"" + singleFunctionName + "\" already, now encountering \"functions\": [...]."; @@ -154,7 +157,7 @@ public class FunctionScoreQueryParser implements QueryParser { // handle cases where only one score function and no filter was // provided. In this case we create a FunctionScoreQuery. if (filterFunctions.size() == 1 && (filterFunctions.get(0).filter == null || filterFunctions.get(0).filter instanceof MatchAllDocsFilter)) { - FunctionScoreQuery theQuery = new FunctionScoreQuery(query, filterFunctions.get(0).function); + FunctionScoreQuery theQuery = new FunctionScoreQuery(query, filterFunctions.get(0).function, minScore); if (combineFunction != null) { theQuery.setCombineFunction(combineFunction); } @@ -164,7 +167,7 @@ public class FunctionScoreQueryParser implements QueryParser { // in all other cases we create a FiltersFunctionScoreQuery. } else { FiltersFunctionScoreQuery functionScoreQuery = new FiltersFunctionScoreQuery(query, scoreMode, - filterFunctions.toArray(new FiltersFunctionScoreQuery.FilterFunction[filterFunctions.size()]), maxBoost); + filterFunctions.toArray(new FiltersFunctionScoreQuery.FilterFunction[filterFunctions.size()]), maxBoost, minScore); if (combineFunction != null) { functionScoreQuery.setCombineFunction(combineFunction); } diff --git a/src/test/java/org/elasticsearch/search/functionscore/FunctionScoreTests.java b/src/test/java/org/elasticsearch/search/functionscore/FunctionScoreTests.java index 053caa257d2..93e9d5017bb 100644 --- a/src/test/java/org/elasticsearch/search/functionscore/FunctionScoreTests.java +++ b/src/test/java/org/elasticsearch/search/functionscore/FunctionScoreTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.search.functionscore; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchResponse; import 
org.elasticsearch.action.search.SearchType; import org.elasticsearch.common.geo.GeoPoint; @@ -33,6 +34,8 @@ import org.elasticsearch.test.ElasticsearchIntegrationTest; import org.junit.Test; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.ExecutionException; import static org.elasticsearch.client.Requests.searchRequest; @@ -46,6 +49,7 @@ import static org.elasticsearch.search.builder.SearchSourceBuilder.searchSource; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; import static org.hamcrest.Matchers.*; +import static org.hamcrest.Matchers.is; public class FunctionScoreTests extends ElasticsearchIntegrationTest { @@ -204,6 +208,7 @@ public class FunctionScoreTests extends ElasticsearchIntegrationTest { .add(scriptFunction("_index['" + TEXT_FIELD + "']['value'].tf()").setWeight(2)) ))).actionGet(); + assertSearchResponse(response); assertThat(response.getHits().getAt(0).getScore(), is(1.0f)); assertThat(responseWithWeights.getHits().getAt(0).getScore(), is(8.0f)); } @@ -433,5 +438,75 @@ public class FunctionScoreTests extends ElasticsearchIntegrationTest { assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getKeyAsNumber().floatValue(), is(1f)); assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1l)); } + + public void testMinScoreFunctionScoreBasic() throws IOException { + index(INDEX, TYPE, jsonBuilder().startObject().field("num", 2).endObject()); + refresh(); + float score = randomFloat(); + float minScore = randomFloat(); + SearchResponse searchResponse = client().search( + searchRequest().source(searchSource().query(functionScoreQuery().add(scriptFunction(Float.toString(score))).setMinScore(minScore))) + ).actionGet(); + if (score < minScore) { + 
assertThat(searchResponse.getHits().getTotalHits(), is(0l)); + } else { + assertThat(searchResponse.getHits().getTotalHits(), is(1l)); + } + + searchResponse = client().search( + searchRequest().source(searchSource().query(functionScoreQuery() + .add(scriptFunction(Float.toString(score))) + .add(scriptFunction(Float.toString(score))) + .scoreMode("avg").setMinScore(minScore))) + ).actionGet(); + if (score < minScore) { + assertThat(searchResponse.getHits().getTotalHits(), is(0l)); + } else { + assertThat(searchResponse.getHits().getTotalHits(), is(1l)); + } + } + + @Test + public void testMinScoreFunctionScoreManyDocsAndRandomMinScore() throws IOException, ExecutionException, InterruptedException { + List<IndexRequestBuilder> docs = new ArrayList<>(); + int numDocs = randomIntBetween(1, 100); + int scoreOffset = randomIntBetween(-2 * numDocs, 2 * numDocs); + int minScore = randomIntBetween(-2 * numDocs, 2 * numDocs); + for (int i = 0; i < numDocs; i++) { + docs.add(client().prepareIndex(INDEX, TYPE, Integer.toString(i)).setSource("num", i + scoreOffset)); + } + indexRandom(true, docs); + String script = "return (doc['num'].value)"; + int numMatchingDocs = numDocs + scoreOffset - minScore; + if (numMatchingDocs < 0) { + numMatchingDocs = 0; + } + if (numMatchingDocs > numDocs) { + numMatchingDocs = numDocs; + } + + SearchResponse searchResponse = client().search( + searchRequest().source(searchSource().query(functionScoreQuery() + .add(scriptFunction(script)) + .setMinScore(minScore)).size(numDocs))).actionGet(); + assertMinScoreSearchResponses(numDocs, searchResponse, numMatchingDocs); + + searchResponse = client().search( + searchRequest().source(searchSource().query(functionScoreQuery() + .add(scriptFunction(script)) + .add(scriptFunction(script)) + .scoreMode("avg").setMinScore(minScore)).size(numDocs))).actionGet(); + assertMinScoreSearchResponses(numDocs, searchResponse, numMatchingDocs); + } + + protected void assertMinScoreSearchResponses(int numDocs, SearchResponse 
searchResponse, int numMatchingDocs) { + assertSearchResponse(searchResponse); + assertThat((int) searchResponse.getHits().totalHits(), is(numMatchingDocs)); + int pos = 0; + for (int hitId = numDocs - 1; (numDocs - hitId) <= searchResponse.getHits().totalHits(); hitId--) { // <= so the last (lowest-scoring) hit is verified too + assertThat(searchResponse.getHits().getAt(pos).getId(), equalTo(Integer.toString(hitId))); + pos++; + } + } }