Add bulkScorer to script score query (#46336) (#49734)

Some queries return bulk scorers that can be significantly faster than
naively iterating over the scorer. This change gives script_score a BulkScorer
that delegates to the wrapped query's bulk scorer, which makes it faster in some cases.

Closes #40837
This commit is contained in:
Mayya Sharipova 2019-11-29 16:51:50 -05:00 committed by GitHub
parent 1d745f1e5c
commit 62a891bfa3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 170 additions and 38 deletions

View File

@ -25,12 +25,17 @@ import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.FilterLeafCollector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.util.Bits;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.script.ScoreScript;
@ -83,6 +88,19 @@ public class ScriptScoreQuery extends Query {
Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, boost);
return new Weight(this){
/**
 * Builds the bulk scorer for a segment.
 *
 * When no minimum score is configured, the bulk scorer of the sub-query weight is wrapped in a
 * {@link ScriptScoreBulkScorer}; if the sub-query weight provides no bulk scorer, {@code null} is returned.
 * When a minimum score is configured, the inherited implementation is used instead.
 */
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
    if (minScore != null) {
        // a minimum score is set: rely on the superclass implementation
        return super.bulkScorer(context);
    }
    final BulkScorer delegate = subQueryWeight.bulkScorer(context);
    // the score script is only instantiated when there is a bulk scorer to wrap
    return delegate == null
        ? null
        : new ScriptScoreBulkScorer(delegate, subQueryScoreMode, makeScoreScript(context));
}
@Override
public void extractTerms(Set<Term> terms) {
subQueryWeight.extractTerms(terms);
@ -94,8 +112,7 @@ public class ScriptScoreQuery extends Query {
if (subQueryScorer == null) {
return null;
}
Scorer scriptScorer = makeScriptScorer(subQueryScorer, context, null);
Scorer scriptScorer = new ScriptScorer(this, makeScoreScript(context), subQueryScorer, subQueryScoreMode, null);
if (minScore != null) {
scriptScorer = new MinScoreScorer(this, scriptScorer, minScore);
}
@ -109,7 +126,8 @@ public class ScriptScoreQuery extends Query {
return subQueryExplanation;
}
ExplanationHolder explanationHolder = new ExplanationHolder();
Scorer scorer = makeScriptScorer(subQueryWeight.scorer(context), context, explanationHolder);
Scorer scorer = new ScriptScorer(this, makeScoreScript(context),
subQueryWeight.scorer(context), subQueryScoreMode, explanationHolder);
int newDoc = scorer.iterator().advance(doc);
assert doc == newDoc; // subquery should have already matched above
float score = scorer.score();
@ -132,42 +150,13 @@ public class ScriptScoreQuery extends Query {
}
return explanation;
}
private Scorer makeScriptScorer(Scorer subQueryScorer, LeafReaderContext context,
ExplanationHolder explanation) throws IOException {
private ScoreScript makeScoreScript(LeafReaderContext context) throws IOException {
final ScoreScript scoreScript = scriptBuilder.newInstance(context);
scoreScript.setScorer(subQueryScorer);
scoreScript._setIndexName(indexName);
scoreScript._setShard(shardId);
scoreScript._setIndexVersion(indexVersion);
return new Scorer(this) {
@Override
public float score() throws IOException {
int docId = docID();
scoreScript.setDocument(docId);
float score = (float) scoreScript.execute(explanation);
if (score == Float.NEGATIVE_INFINITY || Float.isNaN(score)) {
throw new ElasticsearchException(
"script score query returned an invalid score: " + score + " for doc: " + docId);
}
return score;
}
@Override
public int docID() {
return subQueryScorer.docID();
}
@Override
public DocIdSetIterator iterator() {
return subQueryScorer.iterator();
}
@Override
public float getMaxScore(int upTo) {
return Float.MAX_VALUE; // TODO: what would be a good upper bound?
}
};
return scoreScript;
}
@Override
@ -187,7 +176,7 @@ public class ScriptScoreQuery extends Query {
@Override
public String toString(String field) {
StringBuilder sb = new StringBuilder();
sb.append("script score (").append(subQuery.toString(field)).append(", script: ");
sb.append("script_score (").append(subQuery.toString(field)).append(", script: ");
sb.append("{" + script.toString() + "}");
return sb.toString();
}
@ -209,4 +198,118 @@ public class ScriptScoreQuery extends Query {
public int hashCode() {
    // fold the sub-query, the script, the minimum score and the index/shard identifiers into one hash
    final int combined = Objects.hash(subQuery, script, minScore, indexName, shardId, indexVersion);
    return combined;
}
/**
 * A {@link Scorer} whose score comes from executing a {@link ScoreScript},
 * while the current document and the iterator are taken from the sub-query scorer.
 */
private static class ScriptScorer extends Scorer {
    private final ScoreScript scoreScript;
    private final Scorer subQueryScorer;
    private final ExplanationHolder explanation;

    ScriptScorer(Weight weight, ScoreScript scoreScript, Scorer subQueryScorer,
                 ScoreMode subQueryScoreMode, ExplanationHolder explanation) {
        super(weight);
        this.scoreScript = scoreScript;
        this.subQueryScorer = subQueryScorer;
        this.explanation = explanation;
        // the sub-query scorer is handed to the script only in COMPLETE score mode
        if (ScoreMode.COMPLETE == subQueryScoreMode) {
            scoreScript.setScorer(subQueryScorer);
        }
    }

    /**
     * Executes the script for the current document.
     *
     * @return the script result cast to a float
     * @throws ElasticsearchException if the result is negative infinity or NaN
     */
    @Override
    public float score() throws IOException {
        final int currentDoc = docID();
        scoreScript.setDocument(currentDoc);
        final float result = (float) scoreScript.execute(explanation);
        final boolean valid = result != Float.NEGATIVE_INFINITY && Float.isNaN(result) == false;
        if (valid) {
            return result;
        }
        throw new ElasticsearchException(
            "script_score query returned an invalid score [" + result + "] for doc [" + currentDoc + "].");
    }

    /** Returns the document the sub-query scorer is currently positioned on. */
    @Override
    public int docID() {
        return subQueryScorer.docID();
    }

    /** Returns the iterator of the sub-query scorer. */
    @Override
    public DocIdSetIterator iterator() {
        return subQueryScorer.iterator();
    }

    @Override
    public float getMaxScore(int upTo) {
        // TODO: what would be a good upper bound?
        return Float.MAX_VALUE;
    }
}
/**
 * A {@link Scorable} that replaces the score of a wrapped scorable with the result of a
 * {@link ScoreScript}; the document id is still read from the wrapped scorable.
 */
private static class ScriptScorable extends Scorable {
    private final ScoreScript scoreScript;
    private final Scorable subQueryScorer;
    private final ExplanationHolder explanation;

    ScriptScorable(ScoreScript scoreScript, Scorable subQueryScorer,
                   ScoreMode subQueryScoreMode, ExplanationHolder explanation) {
        this.scoreScript = scoreScript;
        this.subQueryScorer = subQueryScorer;
        this.explanation = explanation;
        // the wrapped scorable is handed to the script only in COMPLETE score mode
        if (ScoreMode.COMPLETE == subQueryScoreMode) {
            scoreScript.setScorer(subQueryScorer);
        }
    }

    /**
     * Executes the script for the current document.
     *
     * @return the script result cast to a float
     * @throws ElasticsearchException if the result is negative infinity or NaN
     */
    @Override
    public float score() throws IOException {
        final int currentDoc = docID();
        scoreScript.setDocument(currentDoc);
        final float result = (float) scoreScript.execute(explanation);
        final boolean valid = result != Float.NEGATIVE_INFINITY && Float.isNaN(result) == false;
        if (valid) {
            return result;
        }
        throw new ElasticsearchException(
            "script_score query returned an invalid score [" + result + "] for doc [" + currentDoc + "].");
    }

    /** Returns the document id reported by the wrapped scorable. */
    @Override
    public int docID() {
        return subQueryScorer.docID();
    }
}
/**
 * Bulk scorer that lets the sub-query's {@link BulkScorer} drive collection,
 * since that may be significantly faster (e.g. BooleanScorer) than iterating over the scorer.
 * The scorer handed to the collector is replaced with one that scores through the script.
 */
private static class ScriptScoreBulkScorer extends BulkScorer {
    private final BulkScorer subQueryBulkScorer;
    private final ScoreMode subQueryScoreMode;
    private final ScoreScript scoreScript;

    ScriptScoreBulkScorer(BulkScorer subQueryBulkScorer, ScoreMode subQueryScoreMode, ScoreScript scoreScript) {
        this.subQueryBulkScorer = subQueryBulkScorer;
        this.subQueryScoreMode = subQueryScoreMode;
        this.scoreScript = scoreScript;
    }

    /**
     * Delegates scoring of the requested range to the sub-query bulk scorer, passing it a
     * collector that swaps every scorer it receives for a {@link ScriptScorable}.
     */
    @Override
    public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
        final LeafCollector scriptAwareCollector = new FilterLeafCollector(collector) {
            @Override
            public void setScorer(Scorable scorer) throws IOException {
                // no explanation holder is supplied on this path
                final Scorable scriptScorable = new ScriptScorable(scoreScript, scorer, subQueryScoreMode, null);
                in.setScorer(scriptScorable);
            }
        };
        return subQueryBulkScorer.score(scriptAwareCollector, acceptDocs, min, max);
    }

    /** Reports the cost of the sub-query bulk scorer. */
    @Override
    public long cost() {
        return subQueryBulkScorer.cost();
    }
}
}

View File

@ -22,6 +22,7 @@ package org.elasticsearch.search.query;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.MockScriptPlugin;
@ -35,6 +36,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import static org.elasticsearch.index.query.QueryBuilders.boolQuery;
import static org.elasticsearch.index.query.QueryBuilders.matchQuery;
import static org.elasticsearch.index.query.QueryBuilders.scriptScoreQuery;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
@ -104,6 +106,33 @@ public class ScriptScoreQueryIT extends ESIntegTestCase {
assertOrderedSearchHits(resp, "10", "8", "6");
}
// test that script_score works when the inner query is a bool query of two should clauses
public void testScriptScoreBoolQuery() {
    assertAcked(
        prepareCreate("test-index").addMapping("_doc", "field1", "type=text", "field2", "type=double")
    );

    // index documents 1..10: field1 = "text<i>", field2 = i
    final int totalDocs = 10;
    for (int id = 1; id <= totalDocs; id++) {
        client().prepareIndex("test-index", "_doc", String.valueOf(id))
            .setSource("field1", "text" + id, "field2", id)
            .get();
    }
    refresh();

    Map<String, Object> scriptParams = new HashMap<>();
    scriptParams.put("param1", 0.1);
    Script scoreScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "doc['field2'].value * param1", scriptParams);

    QueryBuilder matchFirst = matchQuery("field1", "text1");
    QueryBuilder matchLast = matchQuery("field1", "text10");
    QueryBuilder eitherText = boolQuery().should(matchFirst).should(matchLast);

    SearchResponse response = client()
        .prepareSearch("test-index")
        .setQuery(scriptScoreQuery(eitherText, scoreScript))
        .get();

    assertNoFailures(response);
    // expected scores are field2 * param1: 10 * 0.1 for doc "10" and 1 * 0.1 for doc "1"
    assertOrderedSearchHits(response, "10", "1");
    assertFirstHit(response, hasScore(1.0f));
    assertSecondHit(response, hasScore(0.1f));
}
// test that when the internal query is rewritten script_score works well
public void testRewrittenQuery() {
assertAcked(