Correct boost calculation in script_score query (#52478) (#52724)

Previously, the boost in a script_score query was wrongly applied only to the subquery.
This commit makes sure that the boost is applied to the whole score
that comes out of the script.

Closes #48465
This commit is contained in:
Mayya Sharipova 2020-02-24 13:48:21 -05:00 committed by GitHub
parent a7bdb0b456
commit 034b1c0ba3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 119 additions and 21 deletions

View File

@ -48,9 +48,12 @@ scores be positive or `0`.
-- --
`min_score`:: `min_score`::
(Optional, float) Documents with a <<relevance-scores,relevance score>> lower (Optional, float) Documents with a score lower
than this floating point number are excluded from the search results. than this floating point number are excluded from the search results.
`boost`::
(Optional, float) Documents' scores produced by `script` are
multiplied by `boost` to produce the documents' final scores. Defaults to `1.0`.
[[script-score-query-notes]] [[script-score-query-notes]]
==== Notes ==== Notes

View File

@ -0,0 +1,86 @@
# Integration tests for ScriptScoreQuery using Painless
# Shared fixture for every test in this file: creates `test_index` and loads
# three documents whose integer field `i` (1, 2, 3) drives the script scores.
setup:
- skip:
# Skipped on versions up to and including 7.6.99 (see `reason` below).
version: " - 7.6.99"
reason: "boost was corrected in script_score query from 7.7"
- do:
indices.create:
index: test_index
body:
settings:
index:
number_of_shards: 1
number_of_replicas: 0
mappings:
properties:
# `k` is the keyword field targeted by the term query in the explain test.
k:
type: keyword
# `i` is the integer field read by the scripts via doc['i'].value.
i:
type: integer
- do:
bulk:
index: test_index
# Refresh so the documents are visible to the searches that follow.
refresh: true
body:
- '{"index": {"_id": "1"}}'
- '{"k": "k", "i" : 1}'
- '{"index": {"_id": "2"}}'
- '{"k": "kk", "i" : 2}'
- '{"index": {"_id": "3"}}'
- '{"k": "kkk", "i" : 3}'
---
# Checks that the outer `boost` multiplies the score returned by the script.
# Expected score per document = i * _score * boost; the asserted values
# (30, 20, 10 for i = 3, 2, 1 with boost 10) imply match_all yields _score 1.
"Boost script_score":
- do:
search:
index: test_index
body:
query:
script_score:
query: {match_all: {}}
script:
source: "doc['i'].value * _score"
boost: 10
- match: { hits.total.value: 3 }
# Hits are ordered by descending score, so i = 3 comes first.
- match: { hits.hits.0._score: 30 }
- match: { hits.hits.1._score: 20 }
- match: { hits.hits.2._score: 10 }
---
# Checks that a boost on the inner query and the outer script_score boost
# compound: the inner boost (5) reaches the script through _score, and the
# outer boost (10) multiplies the script result, giving i * 5 * 10.
"Boost script_score and boost internal query":
- do:
search:
index: test_index
body:
query:
script_score:
query: {match_all: {boost: 5}}
script:
source: "doc['i'].value * _score"
boost: 10
- match: { hits.total.value: 3 }
# 3 * 5 * 10, 2 * 5 * 10, 1 * 5 * 10
- match: { hits.hits.0._score: 150 }
- match: { hits.hits.1._score: 100 }
- match: { hits.hits.2._score: 50 }
---
# Checks that the boost is applied and reported when `explain` is enabled.
# The script ignores _score, so the single matching document (k = "kkk",
# i = 3) scores 3 * 10 = 30.
"Boost script_score with explain":
- do:
search:
index: test_index
body:
explain: true
query:
script_score:
query: {term: {"k": "kkk"}}
script:
source: "doc['i'].value"
boost: 10
- match: { hits.total.value: 1 }
- match: { hits.hits.0._score: 30 }
# The top-level explanation value must equal the boosted score.
- match: { hits.hits.0._explanation.value: 30 }
# The first detail of the explanation is the boost factor itself.
- match: { hits.hits.0._explanation.details.0.description: "boost" }
- match: { hits.hits.0._explanation.details.0.value: 10}

View File

@ -36,7 +36,6 @@ import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.ScoreScript.ExplanationHolder; import org.elasticsearch.script.ScoreScript.ExplanationHolder;
@ -85,7 +84,7 @@ public class ScriptScoreQuery extends Query {
} }
boolean needsScore = scriptBuilder.needs_score(); boolean needsScore = scriptBuilder.needs_score();
ScoreMode subQueryScoreMode = needsScore ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES; ScoreMode subQueryScoreMode = needsScore ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, boost); Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, 1.0f);
return new Weight(this){ return new Weight(this){
@Override @Override
@ -95,7 +94,7 @@ public class ScriptScoreQuery extends Query {
if (subQueryBulkScorer == null) { if (subQueryBulkScorer == null) {
return null; return null;
} }
return new ScriptScoreBulkScorer(subQueryBulkScorer, subQueryScoreMode, makeScoreScript(context)); return new ScriptScoreBulkScorer(subQueryBulkScorer, subQueryScoreMode, makeScoreScript(context), boost);
} else { } else {
return super.bulkScorer(context); return super.bulkScorer(context);
} }
@ -112,7 +111,7 @@ public class ScriptScoreQuery extends Query {
if (subQueryScorer == null) { if (subQueryScorer == null) {
return null; return null;
} }
Scorer scriptScorer = new ScriptScorer(this, makeScoreScript(context), subQueryScorer, subQueryScoreMode, null); Scorer scriptScorer = new ScriptScorer(this, makeScoreScript(context), subQueryScorer, subQueryScoreMode, boost, null);
if (minScore != null) { if (minScore != null) {
scriptScorer = new MinScoreScorer(this, scriptScorer, minScore); scriptScorer = new MinScoreScorer(this, scriptScorer, minScore);
} }
@ -127,11 +126,11 @@ public class ScriptScoreQuery extends Query {
} }
ExplanationHolder explanationHolder = new ExplanationHolder(); ExplanationHolder explanationHolder = new ExplanationHolder();
Scorer scorer = new ScriptScorer(this, makeScoreScript(context), Scorer scorer = new ScriptScorer(this, makeScoreScript(context),
subQueryWeight.scorer(context), subQueryScoreMode, explanationHolder); subQueryWeight.scorer(context), subQueryScoreMode, 1f, explanationHolder);
int newDoc = scorer.iterator().advance(doc); int newDoc = scorer.iterator().advance(doc);
assert doc == newDoc; // subquery should have already matched above assert doc == newDoc; // subquery should have already matched above
float score = scorer.score(); float score = scorer.score(); // score without boost
Explanation explanation = explanationHolder.get(score, needsScore ? subQueryExplanation : null); Explanation explanation = explanationHolder.get(score, needsScore ? subQueryExplanation : null);
if (explanation == null) { if (explanation == null) {
// no explanation provided by user; give a simple one // no explanation provided by user; give a simple one
@ -143,7 +142,10 @@ public class ScriptScoreQuery extends Query {
explanation = Explanation.match(score, desc); explanation = Explanation.match(score, desc);
} }
} }
if (boost != 1f) {
explanation = Explanation.match(boost * explanation.getValue().floatValue(), "Boosted score, product of:",
Explanation.match(boost, "boost"), explanation);
}
if (minScore != null && minScore > explanation.getValue().floatValue()) { if (minScore != null && minScore > explanation.getValue().floatValue()) {
explanation = Explanation.noMatch("Score value is too low, expected at least " + minScore + explanation = Explanation.noMatch("Score value is too low, expected at least " + minScore +
" but got " + explanation.getValue(), explanation); " but got " + explanation.getValue(), explanation);
@ -203,16 +205,18 @@ public class ScriptScoreQuery extends Query {
private static class ScriptScorer extends Scorer { private static class ScriptScorer extends Scorer {
private final ScoreScript scoreScript; private final ScoreScript scoreScript;
private final Scorer subQueryScorer; private final Scorer subQueryScorer;
private final float boost;
private final ExplanationHolder explanation; private final ExplanationHolder explanation;
ScriptScorer(Weight weight, ScoreScript scoreScript, Scorer subQueryScorer, ScriptScorer(Weight weight, ScoreScript scoreScript, Scorer subQueryScorer,
ScoreMode subQueryScoreMode, ExplanationHolder explanation) { ScoreMode subQueryScoreMode, float boost, ExplanationHolder explanation) {
super(weight); super(weight);
this.scoreScript = scoreScript; this.scoreScript = scoreScript;
if (subQueryScoreMode == ScoreMode.COMPLETE) { if (subQueryScoreMode == ScoreMode.COMPLETE) {
scoreScript.setScorer(subQueryScorer); scoreScript.setScorer(subQueryScorer);
} }
this.subQueryScorer = subQueryScorer; this.subQueryScorer = subQueryScorer;
this.boost = boost;
this.explanation = explanation; this.explanation = explanation;
} }
@ -221,12 +225,13 @@ public class ScriptScoreQuery extends Query {
int docId = docID(); int docId = docID();
scoreScript.setDocument(docId); scoreScript.setDocument(docId);
float score = (float) scoreScript.execute(explanation); float score = (float) scoreScript.execute(explanation);
if (score == Float.NEGATIVE_INFINITY || Float.isNaN(score)) { if (score < 0f || Float.isNaN(score)) {
throw new ElasticsearchException( throw new IllegalArgumentException("script_score script returned an invalid score [" + score + "] " +
"script_score query returned an invalid score [" + score + "] for doc [" + docId + "]."); "for doc [" + docId + "]. Must be a non-negative score!");
} }
return score; return score * boost;
} }
@Override @Override
public int docID() { public int docID() {
return subQueryScorer.docID(); return subQueryScorer.docID();
@ -247,15 +252,17 @@ public class ScriptScoreQuery extends Query {
private static class ScriptScorable extends Scorable { private static class ScriptScorable extends Scorable {
private final ScoreScript scoreScript; private final ScoreScript scoreScript;
private final Scorable subQueryScorer; private final Scorable subQueryScorer;
private final float boost;
private final ExplanationHolder explanation; private final ExplanationHolder explanation;
ScriptScorable(ScoreScript scoreScript, Scorable subQueryScorer, ScriptScorable(ScoreScript scoreScript, Scorable subQueryScorer,
ScoreMode subQueryScoreMode, ExplanationHolder explanation) { ScoreMode subQueryScoreMode, float boost, ExplanationHolder explanation) {
this.scoreScript = scoreScript; this.scoreScript = scoreScript;
if (subQueryScoreMode == ScoreMode.COMPLETE) { if (subQueryScoreMode == ScoreMode.COMPLETE) {
scoreScript.setScorer(subQueryScorer); scoreScript.setScorer(subQueryScorer);
} }
this.subQueryScorer = subQueryScorer; this.subQueryScorer = subQueryScorer;
this.boost = boost;
this.explanation = explanation; this.explanation = explanation;
} }
@ -264,11 +271,11 @@ public class ScriptScoreQuery extends Query {
int docId = docID(); int docId = docID();
scoreScript.setDocument(docId); scoreScript.setDocument(docId);
float score = (float) scoreScript.execute(explanation); float score = (float) scoreScript.execute(explanation);
if (score == Float.NEGATIVE_INFINITY || Float.isNaN(score)) { if (score < 0f || Float.isNaN(score)) {
throw new ElasticsearchException( throw new IllegalArgumentException("script_score script returned an invalid score [" + score + "] " +
"script_score query returned an invalid score [" + score + "] for doc [" + docId + "]."); "for doc [" + docId + "]. Must be a non-negative score!");
} }
return score; return score * boost;
} }
@Override @Override
public int docID() { public int docID() {
@ -284,11 +291,13 @@ public class ScriptScoreQuery extends Query {
private final BulkScorer subQueryBulkScorer; private final BulkScorer subQueryBulkScorer;
private final ScoreMode subQueryScoreMode; private final ScoreMode subQueryScoreMode;
private final ScoreScript scoreScript; private final ScoreScript scoreScript;
private final float boost;
ScriptScoreBulkScorer(BulkScorer subQueryBulkScorer, ScoreMode subQueryScoreMode, ScoreScript scoreScript) { ScriptScoreBulkScorer(BulkScorer subQueryBulkScorer, ScoreMode subQueryScoreMode, ScoreScript scoreScript, float boost) {
this.subQueryBulkScorer = subQueryBulkScorer; this.subQueryBulkScorer = subQueryBulkScorer;
this.subQueryScoreMode = subQueryScoreMode; this.subQueryScoreMode = subQueryScoreMode;
this.scoreScript = scoreScript; this.scoreScript = scoreScript;
this.boost = boost;
} }
@Override @Override
@ -300,7 +309,7 @@ public class ScriptScoreQuery extends Query {
return new FilterLeafCollector(collector) { return new FilterLeafCollector(collector) {
@Override @Override
public void setScorer(Scorable scorer) throws IOException { public void setScorer(Scorable scorer) throws IOException {
in.setScorer(new ScriptScorable(scoreScript, scorer, subQueryScoreMode, null)); in.setScorer(new ScriptScorable(scoreScript, scorer, subQueryScoreMode, boost, null));
} }
}; };
} }