Allow scripted metric agg to access `_score` (#24295)
* Fixes #24259 Corrects the ScriptedMetricAggregator so that the script can have access to scores during the map stage. * Restored original tests. Added seperate test. As requested, I've restored the non-score dependant tests, and added the score dependent metric as a seperate test.
This commit is contained in:
parent
ad3c042fc4
commit
c1ba4fdcb4
|
@ -63,7 +63,7 @@ public class ScriptedMetricAggregator extends MetricsAggregator {
|
||||||
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
|
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
|
||||||
final LeafBucketCollector sub) throws IOException {
|
final LeafBucketCollector sub) throws IOException {
|
||||||
final LeafSearchScript leafMapScript = mapScript.getLeafSearchScript(ctx);
|
final LeafSearchScript leafMapScript = mapScript.getLeafSearchScript(ctx);
|
||||||
return new LeafBucketCollectorBase(sub, mapScript) {
|
return new LeafBucketCollectorBase(sub, leafMapScript) {
|
||||||
@Override
|
@Override
|
||||||
public void collect(int doc, long bucket) throws IOException {
|
public void collect(int doc, long bucket) throws IOException {
|
||||||
assert bucket == 0 : bucket;
|
assert bucket == 0 : bucket;
|
||||||
|
|
|
@ -33,6 +33,7 @@ import org.elasticsearch.index.mapper.MappedFieldType;
|
||||||
import org.elasticsearch.index.query.QueryShardContext;
|
import org.elasticsearch.index.query.QueryShardContext;
|
||||||
import org.elasticsearch.indices.breaker.CircuitBreakerService;
|
import org.elasticsearch.indices.breaker.CircuitBreakerService;
|
||||||
import org.elasticsearch.script.MockScriptEngine;
|
import org.elasticsearch.script.MockScriptEngine;
|
||||||
|
import org.elasticsearch.script.ScoreAccessor;
|
||||||
import org.elasticsearch.script.Script;
|
import org.elasticsearch.script.Script;
|
||||||
import org.elasticsearch.script.ScriptContextRegistry;
|
import org.elasticsearch.script.ScriptContextRegistry;
|
||||||
import org.elasticsearch.script.ScriptEngineRegistry;
|
import org.elasticsearch.script.ScriptEngineRegistry;
|
||||||
|
@ -59,6 +60,11 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
|
||||||
private static final Script MAP_SCRIPT = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScript", Collections.emptyMap());
|
private static final Script MAP_SCRIPT = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScript", Collections.emptyMap());
|
||||||
private static final Script COMBINE_SCRIPT = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScript",
|
private static final Script COMBINE_SCRIPT = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScript",
|
||||||
Collections.emptyMap());
|
Collections.emptyMap());
|
||||||
|
|
||||||
|
private static final Script INIT_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptScore", Collections.emptyMap());
|
||||||
|
private static final Script MAP_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptScore", Collections.emptyMap());
|
||||||
|
private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore",
|
||||||
|
Collections.emptyMap());
|
||||||
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();
|
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,6 +85,21 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
|
||||||
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
|
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
|
||||||
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).sum();
|
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).sum();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
SCRIPTS.put("initScriptScore", params -> {
|
||||||
|
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
|
||||||
|
agg.put("collector", new ArrayList<Double>());
|
||||||
|
return agg;
|
||||||
|
});
|
||||||
|
SCRIPTS.put("mapScriptScore", params -> {
|
||||||
|
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
|
||||||
|
((List<Double>) agg.get("collector")).add(((ScoreAccessor) params.get("_score")).doubleValue());
|
||||||
|
return agg;
|
||||||
|
});
|
||||||
|
SCRIPTS.put("combineScriptScore", params -> {
|
||||||
|
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
|
||||||
|
return ((List<Double>) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
|
@ -144,6 +165,29 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* test that uses the score of the documents
|
||||||
|
*/
|
||||||
|
public void testScriptedMetricWithCombineAccessesScores() throws IOException {
|
||||||
|
try (Directory directory = newDirectory()) {
|
||||||
|
Integer numDocs = randomInt(100);
|
||||||
|
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
|
||||||
|
for (int i = 0; i < numDocs; i++) {
|
||||||
|
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
try (IndexReader indexReader = DirectoryReader.open(directory)) {
|
||||||
|
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
|
||||||
|
aggregationBuilder.initScript(INIT_SCRIPT_SCORE).mapScript(MAP_SCRIPT_SCORE).combineScript(COMBINE_SCRIPT_SCORE);
|
||||||
|
ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder);
|
||||||
|
assertEquals(AGG_NAME, scriptedMetric.getName());
|
||||||
|
assertNotNull(scriptedMetric.aggregation());
|
||||||
|
// all documents have score of 1.0
|
||||||
|
assertEquals((double) numDocs, scriptedMetric.aggregation());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* We cannot use Mockito for mocking QueryShardContext in this case because
|
* We cannot use Mockito for mocking QueryShardContext in this case because
|
||||||
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)
|
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)
|
||||||
|
|
Loading…
Reference in New Issue