diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScoreTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScoreTests.java index d02d2957ab5..deaabfcebc6 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScoreTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScoreTests.java @@ -19,50 +19,65 @@ package org.elasticsearch.painless; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.script.Script; -import org.elasticsearch.script.ScriptService; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.test.ESSingleNodeTestCase; +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicBoolean; -import java.util.Collection; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; -public class ScoreTests extends ESSingleNodeTestCase { - @Override - protected Collection> getPlugins() { - return pluginList(PainlessPlugin.class); +public class ScoreTests extends ScriptTestCase { + + /** Most of a dummy scorer impl that requires overriding just score(). */ + abstract class MockScorer extends Scorer { + MockScorer() { + super(null); + } + @Override + public int docID() { + return 0; + } + @Override + public int freq() throws IOException { + throw new UnsupportedOperationException(); + } + @Override + public DocIdSetIterator iterator() { + throw new UnsupportedOperationException(); + } } - public void testScore() { - createIndex("test", Settings.EMPTY, "type", "t", "type=text"); - ensureGreen("test"); + public void testScoreWorks() { + assertEquals(2.5, exec("_score", Collections.emptyMap(), Collections.emptyMap(), + new MockScorer() { + @Override + public float score() throws IOException { + return 2.5f; + } + })); + } - client().prepareIndex("test", "type", "1").setSource("t", "a").get(); - client().prepareIndex("test", "type", "2").setSource("t", "a a b").get(); - client().prepareIndex("test", "type", "3").setSource("t", "a a a b c").get(); - client().prepareIndex("test", "type", "4").setSource("t", "a b c d").get(); - client().prepareIndex("test", "type", "5").setSource("t", "a a b c d e").get(); - client().admin().indices().prepareRefresh("test").get(); + public void testScoreNotUsed() { + assertEquals(3.5, exec("3.5", Collections.emptyMap(), Collections.emptyMap(), + new MockScorer() { + @Override + public float score() throws IOException { + throw new AssertionError("score() should not be called"); + } + })); + } - final Script script = new Script("_score + 1", ScriptService.ScriptType.INLINE, "painless", null); - - final SearchResponse sr = client().prepareSearch("test").setQuery( - QueryBuilders.functionScoreQuery(QueryBuilders.matchQuery("t", "a"), - ScoreFunctionBuilders.scriptFunction(script))).get(); - final SearchHit[] hits = sr.getHits().getHits(); - - for (final SearchHit hit : hits) { - assertTrue(hit.score() > 0.9999F && hit.score() < 2.0001F); - } - - assertEquals("1", hits[0].getId()); - assertEquals("3", hits[1].getId()); - assertEquals("2", hits[2].getId()); - assertEquals("5", hits[3].getId()); - assertEquals("4", hits[4].getId()); + public void testScoreCached() { + assertEquals(9.0, exec("_score + _score", Collections.emptyMap(), Collections.emptyMap(), + new MockScorer() { + private boolean used = false; + @Override + public float score() throws IOException { + if (used == false) { + return 4.5f; + } + throw new AssertionError("score() should not be called twice"); + } + })); } } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptTestCase.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptTestCase.java index d95fa3897da..27558dc745e 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptTestCase.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptTestCase.java @@ -19,8 +19,11 @@ package org.elasticsearch.painless; +import org.apache.lucene.search.Scorer; +import org.elasticsearch.common.lucene.ScorerAware; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.script.CompiledScript; +import org.elasticsearch.script.ExecutableScript; import org.elasticsearch.script.ScriptService; import org.elasticsearch.test.ESTestCase; import org.junit.Before; @@ -51,14 +54,18 @@ public abstract class ScriptTestCase extends ESTestCase { public Object exec(String script, Map vars) { Map compilerSettings = new HashMap<>(); compilerSettings.put(CompilerSettings.PICKY, "true"); - return exec(script, vars, compilerSettings); + return exec(script, vars, compilerSettings, null); } /** Compiles and returns the result of {@code script} with access to {@code vars} and compile-time parameters */ - public Object exec(String script, Map vars, Map compileParams) { + public Object exec(String script, Map vars, Map compileParams, Scorer scorer) { Object object = scriptEngine.compile(null, script, compileParams); CompiledScript compiled = new CompiledScript(ScriptService.ScriptType.INLINE, getTestName(), "painless", object); - return scriptEngine.executable(compiled, vars).run(); + ExecutableScript executableScript = scriptEngine.executable(compiled, vars); + if (scorer != null) { + ((ScorerAware)executableScript).setScorer(scorer); + } + return executableScript.run(); } /** diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/WhenThingsGoWrongTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/WhenThingsGoWrongTests.java index 9cdce7583f5..c2b40df8c9d 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/WhenThingsGoWrongTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/WhenThingsGoWrongTests.java @@ -77,7 +77,7 @@ public class WhenThingsGoWrongTests extends ScriptTestCase { public void testBogusParameter() { IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> { - exec("return 5;", null, Collections.singletonMap("bogusParameterKey", "bogusParameterValue")); + exec("return 5;", null, Collections.singletonMap("bogusParameterKey", "bogusParameterValue"), null); }); assertTrue(expected.getMessage().contains("Unrecognized compile-time parameter")); }