script with _score: remove dependency of DocLookup and scorer

As pointed out in #7487 DocLookup is a variable that is accessible by all scripts
for one doc while the query is executed. But the _score and therfore the scorer
depends on the current context, that is, which part of query is currently executed.
Instead of setting the scorer for DocLookup
and have Script access the DocLookup for getting the score, the Scorer should just
be explicitely set for each script.
DocLookup should not have any reference to a scorer.
This was similarly discussed in #7043.

This dependency caused a stackoverflow when running script score in combination with an
aggregation on _score. Also the wrong scorer was called when nesting several script scores.

closes #7487
closes #7819
This commit is contained in:
Britta Weber 2014-09-19 18:03:02 +02:00
parent 9c9cd01854
commit 7feb742a9b
9 changed files with 54 additions and 53 deletions

View File

@ -95,7 +95,7 @@ public abstract class AbstractSearchScript extends AbstractExecutableScript impl
@Override @Override
public void setScorer(Scorer scorer) { public void setScorer(Scorer scorer) {
lookup.setScorer(scorer); throw new UnsupportedOperationException();
} }
@Override @Override

View File

@ -19,6 +19,7 @@
package org.elasticsearch.script; package org.elasticsearch.script;
import org.apache.lucene.search.Scorer;
import org.elasticsearch.search.lookup.DocLookup; import org.elasticsearch.search.lookup.DocLookup;
import java.io.IOException; import java.io.IOException;
@ -31,15 +32,15 @@ import java.io.IOException;
*/ */
public final class ScoreAccessor extends Number { public final class ScoreAccessor extends Number {
final DocLookup doc; Scorer scorer;
public ScoreAccessor(DocLookup d) { public ScoreAccessor(Scorer scorer) {
doc = d; this.scorer = scorer;
} }
float score() { float score() {
try { try {
return doc.score(); return scorer.score();
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException("Could not get score", e); throw new RuntimeException("Could not get score", e);
} }

View File

@ -230,9 +230,6 @@ public class ScriptService extends AbstractComponent {
} }
this.scriptEngines = builder.build(); this.scriptEngines = builder.build();
// put some default optimized scripts
staticCache.put("doc.score", new CompiledScript("native", new DocScoreNativeScriptFactory()));
// add file watcher for static scripts // add file watcher for static scripts
scriptsDirectory = new File(env.configFile(), "scripts"); scriptsDirectory = new File(env.configFile(), "scripts");
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
@ -574,22 +571,4 @@ public class ScriptService extends AbstractComponent {
return lang.hashCode() + 31 * script.hashCode(); return lang.hashCode() + 31 * script.hashCode();
} }
} }
public static class DocScoreNativeScriptFactory implements NativeScriptFactory {
@Override
public ExecutableScript newScript(@Nullable Map<String, Object> params) {
return new DocScoreSearchScript();
}
}
public static class DocScoreSearchScript extends AbstractFloatSearchScript {
@Override
public float runAsFloat() {
try {
return doc().score();
} catch (IOException e) {
return 0;
}
}
}
} }

View File

@ -43,6 +43,7 @@ import org.elasticsearch.common.logging.ESLogger;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.script.*; import org.elasticsearch.script.*;
import org.elasticsearch.search.lookup.SearchLookup; import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.search.suggest.term.TermSuggestion;
import java.io.IOException; import java.io.IOException;
import java.math.BigDecimal; import java.math.BigDecimal;
@ -186,6 +187,7 @@ public class GroovyScriptEngineService extends AbstractComponent implements Scri
private final SearchLookup lookup; private final SearchLookup lookup;
private final Map<String, Object> variables; private final Map<String, Object> variables;
private final ESLogger logger; private final ESLogger logger;
private Scorer scorer;
public GroovyScript(Script script, ESLogger logger) { public GroovyScript(Script script, ESLogger logger) {
this(script, null, logger); this(script, null, logger);
@ -196,17 +198,12 @@ public class GroovyScriptEngineService extends AbstractComponent implements Scri
this.lookup = lookup; this.lookup = lookup;
this.logger = logger; this.logger = logger;
this.variables = script.getBinding().getVariables(); this.variables = script.getBinding().getVariables();
if (lookup != null) {
// Add the _score variable, which will access score from lookup.doc()
this.variables.put("_score", new ScoreAccessor(lookup.doc()));
}
} }
@Override @Override
public void setScorer(Scorer scorer) { public void setScorer(Scorer scorer) {
if (lookup != null) { this.scorer = scorer;
lookup.setScorer(scorer); this.variables.put("_score", new ScoreAccessor(scorer));
}
} }
@Override @Override

View File

@ -49,8 +49,6 @@ public class DocLookup implements Map {
private AtomicReaderContext reader; private AtomicReaderContext reader;
private Scorer scorer;
private int docId = -1; private int docId = -1;
DocLookup(MapperService mapperService, IndexFieldDataService fieldDataService, @Nullable String[] types) { DocLookup(MapperService mapperService, IndexFieldDataService fieldDataService, @Nullable String[] types) {
@ -76,22 +74,10 @@ public class DocLookup implements Map {
localCacheFieldData.clear(); localCacheFieldData.clear();
} }
public void setScorer(Scorer scorer) {
this.scorer = scorer;
}
public void setNextDocId(int docId) { public void setNextDocId(int docId) {
this.docId = docId; this.docId = docId;
} }
public float score() throws IOException {
return scorer.score();
}
public float getScore() throws IOException {
return scorer.score();
}
@Override @Override
public Object get(Object key) { public Object get(Object key) {
// assume its a string... // assume its a string...

View File

@ -76,10 +76,6 @@ public class SearchLookup {
return this.docMap; return this.docMap;
} }
public void setScorer(Scorer scorer) {
docMap.setScorer(scorer);
}
public void setNextReader(AtomicReaderContext context) { public void setNextReader(AtomicReaderContext context) {
docMap.setNextReader(context); docMap.setNextReader(context);
sourceLookup.setNextReader(context); sourceLookup.setNextReader(context);

View File

@ -1172,7 +1172,7 @@ public class DoubleTermsTests extends ElasticsearchIntegrationTest {
.setQuery(functionScoreQuery(matchAllQuery()).add(ScoreFunctionBuilders.scriptFunction("doc['" + SINGLE_VALUED_FIELD_NAME + "'].value"))) .setQuery(functionScoreQuery(matchAllQuery()).add(ScoreFunctionBuilders.scriptFunction("doc['" + SINGLE_VALUED_FIELD_NAME + "'].value")))
.addAggregation(terms("terms") .addAggregation(terms("terms")
.collectMode(randomFrom(SubAggCollectionMode.values())) .collectMode(randomFrom(SubAggCollectionMode.values()))
.script("ceil(_doc.score()/3)") .script("ceil(_score.doubleValue()/3)")
).execute().actionGet(); ).execute().actionGet();
assertSearchResponse(response); assertSearchResponse(response);

View File

@ -270,7 +270,7 @@ public class TopHitsTests extends ElasticsearchIntegrationTest {
topHits("hits").setSize(1) topHits("hits").setSize(1)
) )
.subAggregation( .subAggregation(
max("max_score").script("_doc.score()") max("max_score").script("_score.doubleValue()")
) )
) )
.get(); .get();

View File

@ -28,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder; import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.index.query.functionscore.weight.WeightBuilder; import org.elasticsearch.index.query.functionscore.weight.WeightBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.test.ElasticsearchIntegrationTest; import org.elasticsearch.test.ElasticsearchIntegrationTest;
import org.junit.Test; import org.junit.Test;
@ -40,6 +41,7 @@ import static org.elasticsearch.index.query.FilterBuilders.termFilter;
import static org.elasticsearch.index.query.QueryBuilders.functionScoreQuery; import static org.elasticsearch.index.query.QueryBuilders.functionScoreQuery;
import static org.elasticsearch.index.query.QueryBuilders.termQuery; import static org.elasticsearch.index.query.QueryBuilders.termQuery;
import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.*; import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.*;
import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
import static org.elasticsearch.search.builder.SearchSourceBuilder.searchSource; import static org.elasticsearch.search.builder.SearchSourceBuilder.searchSource;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
@ -388,4 +390,44 @@ public class FunctionScoreTests extends ElasticsearchIntegrationTest {
assertSearchResponse(response); assertSearchResponse(response);
assertThat(response.getHits().getAt(0).score(), equalTo(2.0f)); assertThat(response.getHits().getAt(0).score(), equalTo(2.0f));
} }
@Test
public void testScriptScoresNested() throws IOException {
index(INDEX, TYPE, "1", jsonBuilder().startObject().field("dummy_field", 1).endObject());
refresh();
SearchResponse response = client().search(
searchRequest().source(
searchSource().query(
functionScoreQuery(
functionScoreQuery(
functionScoreQuery().add(scriptFunction("1")))
.add(scriptFunction("_score.doubleValue()")))
.add(scriptFunction("_score.doubleValue()")
)
)
)
).actionGet();
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).score(), equalTo(1.0f));
}
@Test
public void testScriptScoresWithAgg() throws IOException {
index(INDEX, TYPE, "1", jsonBuilder().startObject().field("dummy_field", 1).endObject());
refresh();
SearchResponse response = client().search(
searchRequest().source(
searchSource().query(
functionScoreQuery()
.add(scriptFunction("_score.doubleValue()")
)
).aggregation(terms("score_agg").script("_score.doubleValue()"))
)
).actionGet();
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).score(), equalTo(1.0f));
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getKeyAsNumber().floatValue(), is(1f));
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1l));
}
} }