Add scripting support to AggregatorTestCase (#43494)

This refactors AggregatorTestCase to allow testing mock scripts.
The main change is to QueryShardContext.  This was previously mocked,
but to get the ScriptService you have to invoke a final method
which can't be mocked.

Instead, we just create a mostly-empty QueryShardContext and populate
the fields that are needed for testing.  It also introduces a few
new helper methods that can be overridden to change the default
behavior a bit.

Most tests should be able to override getMockScriptService() to supply
a ScriptService to the context, which is later used by the aggs.
More complicated tests can override queryShardContextMock() as before.

Adds a test to MaxAggregatorTests to test out the new functionality.
This commit is contained in:
Zachary Tong 2019-06-25 11:50:19 -04:00
parent 2beb193311
commit 63fef5a31e
3 changed files with 75 additions and 19 deletions

View File

@ -46,8 +46,15 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FutureArrays;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScriptModule;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.AggregatorTestCase;
import org.elasticsearch.search.aggregations.support.AggregationInspectionHelper;
@ -57,6 +64,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
@ -65,6 +73,19 @@ import static java.util.Collections.singleton;
import static org.hamcrest.Matchers.equalTo;
public class MaxAggregatorTests extends AggregatorTestCase {
private final String SCRIPT_NAME = "script_name";
private final long SCRIPT_VALUE = 19L;
@Override
protected ScriptService getMockScriptService() {
MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
Collections.singletonMap(SCRIPT_NAME, script -> SCRIPT_VALUE), // return 19 from script
Collections.emptyMap());
Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
}
public void testNoDocs() throws IOException {
testCase(new MatchAllDocsQuery(), iw -> {
// Intentionally not writing any docs
@ -147,6 +168,23 @@ public class MaxAggregatorTests extends AggregatorTestCase {
}, null);
}
public void testScript() throws IOException {
MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
fieldType.setName("number");
MaxAggregationBuilder aggregationBuilder = new MaxAggregationBuilder("_name")
.field("number")
.script(new Script(ScriptType.INLINE, MockScriptEngine.NAME, SCRIPT_NAME, Collections.emptyMap()));
testCase(aggregationBuilder, new DocValuesFieldExistsQuery("number"), iw -> {
iw.addDocument(singleton(new NumericDocValuesField("number", 7)));
iw.addDocument(singleton(new NumericDocValuesField("number", 1)));
}, max -> {
assertEquals(max.getValue(), SCRIPT_VALUE, 0); // Note this is the script value (19L), not the doc values above
assertTrue(AggregationInspectionHelper.hasValue(max));
}, fieldType);
}
private void testCase(Query query,
CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
Consumer<InternalMax> verify) throws IOException {
@ -282,4 +320,5 @@ public class MaxAggregatorTests extends AggregatorTestCase {
});
assertTrue(seen[0]);
}
}

View File

@ -26,8 +26,10 @@ import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.store.Directory;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptEngine;
@ -403,11 +405,12 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
* is final and cannot be mocked
*/
@Override
protected QueryShardContext queryShardContextMock(MapperService mapperService) {
protected QueryShardContext queryShardContextMock(MapperService mapperService, IndexSettings indexSettings,
CircuitBreakerService circuitBreakerService) {
MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, SCRIPTS, Collections.emptyMap());
Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
ScriptService scriptService = new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
return new QueryShardContext(0, mapperService.getIndexSettings(), null, null, null, mapperService, null, scriptService,
return new QueryShardContext(0, indexSettings, null, null, null, mapperService, null, scriptService,
xContentRegistry(), writableRegistry(), null, null, System::currentTimeMillis, null);
}
}

View File

@ -48,6 +48,7 @@ import org.elasticsearch.index.cache.bitset.BitsetFilterCache;
import org.elasticsearch.index.cache.bitset.BitsetFilterCache.Listener;
import org.elasticsearch.index.cache.query.DisabledQueryCache;
import org.elasticsearch.index.engine.Engine;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.fielddata.IndexFieldDataCache;
import org.elasticsearch.index.fielddata.IndexFieldDataService;
import org.elasticsearch.index.mapper.ContentPath;
@ -58,7 +59,6 @@ import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.ObjectMapper;
import org.elasticsearch.index.mapper.ObjectMapper.Nested;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.index.query.support.NestedScope;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
@ -83,10 +83,10 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
@ -157,8 +157,8 @@ public abstract class AggregatorTestCase extends ESTestCase {
SearchLookup searchLookup = new SearchLookup(mapperService, ifds::getForField, new String[]{TYPE_NAME});
when(searchContext.lookup()).thenReturn(searchLookup);
QueryShardContext queryShardContext = queryShardContextMock(mapperService);
when(queryShardContext.getIndexSettings()).thenReturn(indexSettings);
QueryShardContext queryShardContext = queryShardContextMock(mapperService, indexSettings, circuitBreakerService);
when(searchContext.getQueryShardContext()).thenReturn(queryShardContext);
Map<String, MappedFieldType> fieldNameToType = new HashMap<>();
fieldNameToType.putAll(Arrays.stream(fieldTypes)
@ -189,16 +189,11 @@ public abstract class AggregatorTestCase extends ESTestCase {
String fieldName = entry.getKey();
MappedFieldType fieldType = entry.getValue();
when(queryShardContext.fieldMapper(fieldName)).thenReturn(fieldType);
when(mapperService.fullName(fieldName)).thenReturn(fieldType);
when(searchContext.smartNameFieldType(fieldName)).thenReturn(fieldType);
}
for (MappedFieldType fieldType : new HashSet<>(fieldNameToType.values())) {
when(queryShardContext.getForField(fieldType)).then(invocation ->
fieldType.fielddataBuilder(mapperService.getIndexSettings().getIndex().getName())
.build(mapperService.getIndexSettings(), fieldType,
new IndexFieldDataCache.None(), circuitBreakerService, mapperService));
}
}
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
@ -304,12 +299,31 @@ public abstract class AggregatorTestCase extends ESTestCase {
/**
* sub-tests that need a more complex mock can overwrite this
*/
protected QueryShardContext queryShardContextMock(MapperService mapperService) {
QueryShardContext queryShardContext = mock(QueryShardContext.class);
when(queryShardContext.getMapperService()).thenReturn(mapperService);
NestedScope nestedScope = new NestedScope();
when(queryShardContext.nestedScope()).thenReturn(nestedScope);
return queryShardContext;
protected QueryShardContext queryShardContextMock(MapperService mapperService, IndexSettings indexSettings,
CircuitBreakerService circuitBreakerService) {
return new QueryShardContext(0, indexSettings, null, null,
getIndexFieldDataLookup(mapperService, circuitBreakerService),
mapperService, null, getMockScriptService(), xContentRegistry(),
writableRegistry(), null, null, System::currentTimeMillis, null);
}
/**
* Sub-tests that need a more complex index field data provider can override this
*/
protected BiFunction<MappedFieldType, String, IndexFieldData<?>> getIndexFieldDataLookup(MapperService mapperService,
CircuitBreakerService circuitBreakerService) {
return (fieldType, s) -> fieldType.fielddataBuilder(mapperService.getIndexSettings().getIndex().getName())
.build(mapperService.getIndexSettings(), fieldType,
new IndexFieldDataCache.None(), circuitBreakerService, mapperService);
}
/**
* Sub-tests that need scripting can override this method to provide a script service and pre-baked scripts
*/
protected ScriptService getMockScriptService() {
return null;
}
protected <A extends InternalAggregation, C extends Aggregator> A search(IndexSearcher searcher,