Move score script context from SearchScript to its own class (#30816)

This commit is contained in:
Martijn van Groningen 2018-05-25 07:17:50 +02:00 committed by GitHub
parent e1ffbeb824
commit ae2f021f1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 221 additions and 42 deletions

View File

@ -23,8 +23,10 @@ import org.apache.lucene.expressions.Expression;
import org.apache.lucene.expressions.SimpleBindings;
import org.apache.lucene.expressions.js.JavascriptCompiler;
import org.apache.lucene.expressions.js.VariableContext;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.SortField;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.common.Nullable;
@ -39,12 +41,14 @@ import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.script.ClassPermission;
import org.elasticsearch.script.ExecutableScript;
import org.elasticsearch.script.FilterScript;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScriptException;
import org.elasticsearch.script.SearchScript;
import org.elasticsearch.search.lookup.SearchLookup;
import java.io.IOException;
import java.security.AccessControlContext;
import java.security.AccessController;
import java.security.PrivilegedAction;
@ -111,6 +115,9 @@ public class ExpressionScriptEngine extends AbstractComponent implements ScriptE
} else if (context.instanceClazz.equals(FilterScript.class)) {
FilterScript.Factory factory = (p, lookup) -> newFilterScript(expr, lookup, p);
return context.factoryClazz.cast(factory);
} else if (context.instanceClazz.equals(ScoreScript.class)) {
ScoreScript.Factory factory = (p, lookup) -> newScoreScript(expr, lookup, p);
return context.factoryClazz.cast(factory);
}
throw new IllegalArgumentException("expression engine does not know how to handle script context [" + context.name + "]");
}
@ -260,6 +267,42 @@ public class ExpressionScriptEngine extends AbstractComponent implements ScriptE
};
};
}
private ScoreScript.LeafFactory newScoreScript(Expression expr, SearchLookup lookup, @Nullable Map<String, Object> vars) {
SearchScript.LeafFactory searchLeafFactory = newSearchScript(expr, lookup, vars);
return new ScoreScript.LeafFactory() {
@Override
public boolean needs_score() {
return searchLeafFactory.needs_score();
}
@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
SearchScript script = searchLeafFactory.newInstance(ctx);
return new ScoreScript(vars, lookup, ctx) {
@Override
public double execute() {
return script.runAsDouble();
}
@Override
public void setDocument(int docid) {
script.setDocument(docid);
}
@Override
public void setScorer(Scorer scorer) {
script.setScorer(scorer);
}
@Override
public double get_score() {
return script.getScore();
}
};
}
};
}
/**
* converts a ParseException at compile-time or link-time to a ScriptException

View File

@ -30,9 +30,9 @@ import org.apache.lucene.index.Term;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.ScriptPlugin;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.SearchScript;
/**
* An example script plugin that adds a {@link ScriptEngine} implementing expert scoring.
@ -54,12 +54,12 @@ public class ExpertScriptPlugin extends Plugin implements ScriptPlugin {
@Override
public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
if (context.equals(SearchScript.SCRIPT_SCORE_CONTEXT) == false) {
if (context.equals(ScoreScript.CONTEXT) == false) {
throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
}
// we use the script "source" as the script identifier
if ("pure_df".equals(scriptSource)) {
SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
ScoreScript.Factory factory = (p, lookup) -> new ScoreScript.LeafFactory() {
final String field;
final String term;
{
@ -74,18 +74,18 @@ public class ExpertScriptPlugin extends Plugin implements ScriptPlugin {
}
@Override
public SearchScript newInstance(LeafReaderContext context) throws IOException {
public ScoreScript newInstance(LeafReaderContext context) throws IOException {
PostingsEnum postings = context.reader().postings(new Term(field, term));
if (postings == null) {
// the field and/or term don't exist in this segment, so always return 0
return new SearchScript(p, lookup, context) {
return new ScoreScript(p, lookup, context) {
@Override
public double runAsDouble() {
public double execute() {
return 0.0d;
}
};
}
return new SearchScript(p, lookup, context) {
return new ScoreScript(p, lookup, context) {
int currentDocid = -1;
@Override
public void setDocument(int docid) {
@ -100,7 +100,7 @@ public class ExpertScriptPlugin extends Plugin implements ScriptPlugin {
currentDocid = docid;
}
@Override
public double runAsDouble() {
public double execute() {
if (postings.docID() != currentDocid) {
// advance moved past the current doc, so this doc has no occurrences of the term
return 0.0d;

View File

@ -24,8 +24,8 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorer;
import org.elasticsearch.script.ExplainableSearchScript;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.SearchScript;
import java.io.IOException;
import java.util.Objects;
@ -58,10 +58,10 @@ public class ScriptScoreFunction extends ScoreFunction {
private final Script sScript;
private final SearchScript.LeafFactory script;
private final ScoreScript.LeafFactory script;
public ScriptScoreFunction(Script sScript, SearchScript.LeafFactory script) {
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script) {
super(CombineFunction.REPLACE);
this.sScript = sScript;
this.script = script;
@ -69,7 +69,7 @@ public class ScriptScoreFunction extends ScoreFunction {
@Override
public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOException {
final SearchScript leafScript = script.newInstance(ctx);
final ScoreScript leafScript = script.newInstance(ctx);
final CannedScorer scorer = new CannedScorer();
leafScript.setScorer(scorer);
return new LeafScoreFunction() {
@ -78,7 +78,7 @@ public class ScriptScoreFunction extends ScoreFunction {
leafScript.setDocument(docId);
scorer.docid = docId;
scorer.score = subQueryScore;
double result = leafScript.runAsDouble();
double result = leafScript.execute();
return result;
}

View File

@ -28,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.index.query.QueryShardException;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.SearchScript;
@ -92,8 +93,8 @@ public class ScriptScoreFunctionBuilder extends ScoreFunctionBuilder<ScriptScore
@Override
protected ScoreFunction doToFunction(QueryShardContext context) {
try {
SearchScript.Factory factory = context.getScriptService().compile(script, SearchScript.SCRIPT_SCORE_CONTEXT);
SearchScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
ScoreScript.Factory factory = context.getScriptService().compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
return new ScriptScoreFunction(script, searchScript);
} catch (Exception e) {
throw new QueryShardException(context, "script_score: the script could not be loaded", e);

View File

@ -0,0 +1,102 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.script;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Scorer;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.search.lookup.LeafSearchLookup;
import org.elasticsearch.search.lookup.SearchLookup;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.function.DoubleSupplier;
/**
* A script used for adjusting the score on a per document basis.
*/
public abstract class ScoreScript {
public static final String[] PARAMETERS = new String[]{};
/** The generic runtime parameters for the script. */
private final Map<String, Object> params;
/** A leaf lookup for the bound segment this script will operate on. */
private final LeafSearchLookup leafLookup;
private DoubleSupplier scoreSupplier = () -> 0.0;
public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
this.params = params;
this.leafLookup = lookup.getLeafSearchLookup(leafContext);
}
public abstract double execute();
/** Return the parameters for this script. */
public Map<String, Object> getParams() {
return params;
}
/** The doc lookup for the Lucene segment this script was created for. */
public final Map<String, ScriptDocValues<?>> getDoc() {
return leafLookup.doc();
}
/** Set the current document to run the script on next. */
public void setDocument(int docid) {
leafLookup.setDocument(docid);
}
public void setScorer(Scorer scorer) {
this.scoreSupplier = () -> {
try {
return scorer.score();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
};
}
public double get_score() {
return scoreSupplier.getAsDouble();
}
/** A factory to construct {@link ScoreScript} instances. */
public interface LeafFactory {
/**
* Return {@code true} if the script needs {@code _score} calculated, or {@code false} otherwise.
*/
boolean needs_score();
ScoreScript newInstance(LeafReaderContext ctx) throws IOException;
}
/** A factory to construct stateful {@link ScoreScript} factories for a specific index. */
public interface Factory {
ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup);
}
public static final ScriptContext<ScoreScript.Factory> CONTEXT = new ScriptContext<>("score", ScoreScript.Factory.class);
}

View File

@ -42,7 +42,7 @@ public class ScriptModule {
CORE_CONTEXTS = Stream.of(
SearchScript.CONTEXT,
SearchScript.AGGS_CONTEXT,
SearchScript.SCRIPT_SCORE_CONTEXT,
ScoreScript.CONTEXT,
SearchScript.SCRIPT_SORT_CONTEXT,
SearchScript.TERMS_SET_QUERY_CONTEXT,
ExecutableScript.CONTEXT,

View File

@ -162,8 +162,6 @@ public abstract class SearchScript implements ScorerAware, ExecutableScript {
public static final ScriptContext<Factory> AGGS_CONTEXT = new ScriptContext<>("aggs", Factory.class);
// Can return a double. (For ScriptSortType#NUMBER only, for ScriptSortType#STRING normal CONTEXT should be used)
public static final ScriptContext<Factory> SCRIPT_SORT_CONTEXT = new ScriptContext<>("sort", Factory.class);
// Can return a float
public static final ScriptContext<Factory> SCRIPT_SCORE_CONTEXT = new ScriptContext<>("score", Factory.class);
// Can return a long
public static final ScriptContext<Factory> TERMS_SET_QUERY_CONTEXT = new ScriptContext<>("terms_set", Factory.class);
}

View File

@ -30,14 +30,14 @@ import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.ScriptPlugin;
import org.elasticsearch.script.ExplainableSearchScript;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.script.SearchScript;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.lookup.LeafDocLookup;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
import org.elasticsearch.test.ESIntegTestCase.Scope;
@ -76,16 +76,17 @@ public class ExplainableScriptIT extends ESIntegTestCase {
@Override
public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
assert scriptSource.equals("explainable_script");
assert context == SearchScript.SCRIPT_SCORE_CONTEXT;
SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
@Override
public SearchScript newInstance(LeafReaderContext context) throws IOException {
return new MyScript(lookup.doc().getLeafDocLookup(context));
}
assert context == ScoreScript.CONTEXT;
ScoreScript.Factory factory = (params1, lookup) -> new ScoreScript.LeafFactory() {
@Override
public boolean needs_score() {
return false;
}
@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
return new MyScript(params1, lookup, ctx);
}
};
return context.factoryClazz.cast(factory);
}
@ -93,28 +94,21 @@ public class ExplainableScriptIT extends ESIntegTestCase {
}
}
static class MyScript extends SearchScript implements ExplainableSearchScript {
LeafDocLookup docLookup;
static class MyScript extends ScoreScript implements ExplainableSearchScript {
MyScript(LeafDocLookup docLookup) {
super(null, null, null);
this.docLookup = docLookup;
MyScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
super(params, lookup, leafContext);
}
@Override
public void setDocument(int doc) {
docLookup.setDocument(doc);
}
@Override
public Explanation explain(Explanation subQueryScore) throws IOException {
Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore);
return Explanation.match((float) (runAsDouble()), "This script returned " + runAsDouble(), scoreExp);
return Explanation.match((float) (execute()), "This script returned " + execute(), scoreExp);
}
@Override
public double runAsDouble() {
return ((Number) ((ScriptDocValues) docLookup.get("number_field")).getValues().get(0)).doubleValue();
public double execute() {
return ((Number) ((ScriptDocValues) getDoc().get("number_field")).getValues().get(0)).doubleValue();
}
}

View File

@ -25,7 +25,6 @@ import org.elasticsearch.index.similarity.ScriptedSimilarity.Doc;
import org.elasticsearch.index.similarity.ScriptedSimilarity.Field;
import org.elasticsearch.index.similarity.ScriptedSimilarity.Query;
import org.elasticsearch.index.similarity.ScriptedSimilarity.Term;
import org.elasticsearch.index.similarity.SimilarityService;
import org.elasticsearch.search.aggregations.pipeline.movfn.MovingFunctionScript;
import org.elasticsearch.search.aggregations.pipeline.movfn.MovingFunctions;
import org.elasticsearch.search.lookup.LeafSearchLookup;
@ -36,7 +35,6 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Predicate;
import static java.util.Collections.emptyMap;
@ -114,6 +112,9 @@ public class MockScriptEngine implements ScriptEngine {
} else if (context.instanceClazz.equals(MovingFunctionScript.class)) {
MovingFunctionScript.Factory factory = mockCompiled::createMovingFunctionScript;
return context.factoryClazz.cast(factory);
} else if (context.instanceClazz.equals(ScoreScript.class)) {
ScoreScript.Factory factory = new MockScoreScript(script);
return context.factoryClazz.cast(factory);
}
throw new IllegalArgumentException("mock script engine does not know how to handle context [" + context.name + "]");
}
@ -342,5 +343,45 @@ public class MockScriptEngine implements ScriptEngine {
return MovingFunctions.unweightedAvg(values);
}
}
public class MockScoreScript implements ScoreScript.Factory {
private final Function<Map<String, Object>, Object> scripts;
MockScoreScript(Function<Map<String, Object>, Object> scripts) {
this.scripts = scripts;
}
@Override
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) {
return new ScoreScript.LeafFactory() {
@Override
public boolean needs_score() {
return true;
}
@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
Scorer[] scorerHolder = new Scorer[1];
return new ScoreScript(params, lookup, ctx) {
@Override
public double execute() {
Map<String, Object> vars = new HashMap<>(getParams());
vars.put("doc", getDoc());
if (scorerHolder[0] != null) {
vars.put("_score", new ScoreAccessor(scorerHolder[0]));
}
return ((Number) scripts.apply(vars)).doubleValue();
}
@Override
public void setScorer(Scorer scorer) {
scorerHolder[0] = scorer;
}
};
}
};
}
}
}