diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java index 457b5222394..05f564c1334 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java @@ -5,25 +5,31 @@ */ package org.elasticsearch.license; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.ingest.PutPipelineAction; import org.elasticsearch.action.ingest.PutPipelineRequest; import org.elasticsearch.action.ingest.SimulateDocumentBaseResult; import org.elasticsearch.action.ingest.SimulatePipelineAction; import org.elasticsearch.action.ingest.SimulatePipelineRequest; import org.elasticsearch.action.ingest.SimulatePipelineResponse; +import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.license.License.OperationMode; import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder; import org.elasticsearch.transport.Transport; import org.elasticsearch.xpack.core.TestXPackTransportClient; import org.elasticsearch.xpack.core.XPackField; @@ -53,12 +59,15 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.ml.LocalStateMachineLearning; +import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; import org.junit.Before; import java.nio.charset.StandardCharsets; -import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; @@ -163,7 +172,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase { } } - public void testMachineLearningPutDatafeedActionRestricted() throws Exception { + public void testMachineLearningPutDatafeedActionRestricted() { String jobId = "testmachinelearningputdatafeedactionrestricted"; String datafeedId = jobId + "-datafeed"; assertMLAllowed(true); @@ -497,7 +506,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase { } } - public void testMachineLearningDeleteJobActionNotRestricted() throws Exception { + public void testMachineLearningDeleteJobActionNotRestricted() { String jobId = "testmachinelearningclosejobactionnotrestricted"; assertMLAllowed(true); // test that license restricted apis do now work @@ -522,7 +531,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase { } } - public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Exception { + public void testMachineLearningDeleteDatafeedActionNotRestricted() { String jobId = "testmachinelearningdeletedatafeedactionnotrestricted"; String datafeedId = jobId + "-datafeed"; assertMLAllowed(true); @@ -554,7 +563,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase { } } - public void testMachineLearningCreateInferenceProcessorRestricted() throws Exception { + public void testMachineLearningCreateInferenceProcessorRestricted() { String modelId = "modelprocessorlicensetest"; assertMLAllowed(true); putInferenceModel(modelId); @@ -686,7 +695,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase { .actionGet(); } - public void testMachineLearningInferModelRestricted() throws Exception { + public void testMachineLearningInferModelRestricted() { String modelId = "modelinfermodellicensetest"; assertMLAllowed(true); putInferenceModel(modelId); @@ -748,6 +757,58 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase { assertThat(listener.actionGet().getInferenceResults(), is(not(empty()))); } + public void testInferenceAggRestricted() { + String modelId = "inference-agg-restricted"; + assertMLAllowed(true); + putInferenceModel(modelId); + + // index some data + String index = "inference-agg-licence-test"; + client().admin().indices().prepareCreate(index) + .addMapping("type", "feature1", "type=double", "feature2", "type=keyword").get(); + client().prepareBulk() + .add(new IndexRequest(index).source("feature1", "10.0", "feature2", "foo")) + .add(new IndexRequest(index).source("feature1", "20.0", "feature2", "foo")) + .add(new IndexRequest(index).source("feature1", "20.0", "feature2", "bar")) + .add(new IndexRequest(index).source("feature1", "20.0", "feature2", "bar")) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get(); + + TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("foobar").field("feature2"); + AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_feature1").field("feature1"); + termsAgg.subAggregation(avgAgg); + + XPackLicenseState licenseState = internalCluster().getInstance(XPackLicenseState.class); + ModelLoadingService modelLoading = internalCluster().getInstance(ModelLoadingService.class); + + Map bucketPaths = new HashMap<>(); + bucketPaths.put("feature1", "avg_feature1"); + InferencePipelineAggregationBuilder inferenceAgg = + new InferencePipelineAggregationBuilder("infer_agg", new SetOnce<>(modelLoading), licenseState, bucketPaths); + inferenceAgg.setModelId(modelId); + + termsAgg.subAggregation(inferenceAgg); + + SearchRequest search = new SearchRequest(index); + search.source().aggregation(termsAgg); + client().search(search).actionGet(); + + // Pick a license that does not allow machine learning + License.OperationMode mode = randomInvalidLicenseType(); + enableLicensing(mode); + assertMLAllowed(false); + + // inferring against a model should now fail + SearchRequest invalidSearch = new SearchRequest(index); + invalidSearch.source().aggregation(termsAgg); + ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, + () -> client().search(invalidSearch).actionGet()); + + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("current license is non-compliant for [ml]")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING)); + } + private void putInferenceModel(String modelId) { TrainedModelConfig config = TrainedModelConfig.builder() .setParsedDefinition( @@ -755,13 +816,13 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase { .setTrainedModel( Tree.builder() .setTargetType(TargetType.REGRESSION) - .setFeatureNames(Arrays.asList("feature1")) + .setFeatureNames(Collections.singletonList("feature1")) .setNodes(TreeNode.builder(0).setLeafValue(1.0)) .build()) .setPreProcessors(Collections.emptyList())) .setModelId(modelId) .setDescription("test model for classification") - .setInput(new TrainedModelInput(Arrays.asList("feature1"))) + .setInput(new TrainedModelInput(Collections.singletonList("feature1"))) .setInferenceConfig(RegressionConfig.EMPTY_PARAMS) .build(); client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index ff7a82cde02..2a6dbecf76e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -980,10 +980,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, @Override public List getPipelineAggregations() { PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME, - in -> new InferencePipelineAggregationBuilder(in, modelLoadingService), + in -> new InferencePipelineAggregationBuilder(in, getLicenseState(), modelLoadingService), (ContextParser) - (parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, name, parser - )); + (parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, getLicenseState(), name, parser)); spec.addResultReader(InternalInferenceAggregation::new); return Collections.singletonList(spec); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilder.java index 9ca041c6b73..940b6a7178a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilder.java @@ -10,15 +10,17 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; @@ -44,10 +46,10 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega static final String AGGREGATIONS_RESULTS_FIELD = "value"; @SuppressWarnings("unchecked") - private static final ConstructingObjectParser, String>> PARSER = new ConstructingObjectParser<>( - NAME, false, - (args, context) -> new InferencePipelineAggregationBuilder(context.v2(), context.v1(), (Map) args[0]) + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, false, + (args, context) -> new InferencePipelineAggregationBuilder(context.name, context.modelLoadingService, + context.licenseState, (Map) args[0]) ); static { @@ -60,34 +62,52 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega private final Map bucketPathMap; private String modelId; private InferenceConfigUpdate inferenceConfig; + private final XPackLicenseState licenseState; private final SetOnce modelLoadingService; /** * The model. Set to a non-null value during the rewrite phase. */ private final Supplier model; + private static class ParserSupplement { + final XPackLicenseState licenseState; + final SetOnce modelLoadingService; + final String name; + + ParserSupplement(String name, XPackLicenseState licenseState, SetOnce modelLoadingService) { + this.name = name; + this.licenseState = licenseState; + this.modelLoadingService = modelLoadingService; + } + } public static InferencePipelineAggregationBuilder parse(SetOnce modelLoadingService, + XPackLicenseState licenseState, String pipelineAggregatorName, XContentParser parser) { - Tuple, String> context = new Tuple<>(modelLoadingService, pipelineAggregatorName); - return PARSER.apply(parser, context); + return PARSER.apply(parser, new ParserSupplement(pipelineAggregatorName, licenseState, modelLoadingService)); } - public InferencePipelineAggregationBuilder(String name, SetOnce modelLoadingService, + public InferencePipelineAggregationBuilder(String name, + SetOnce modelLoadingService, + XPackLicenseState licenseState, Map bucketsPath) { super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {})); this.modelLoadingService = modelLoadingService; this.bucketPathMap = bucketsPath; this.model = null; + this.licenseState = licenseState; } - public InferencePipelineAggregationBuilder(StreamInput in, SetOnce modelLoadingService) throws IOException { + public InferencePipelineAggregationBuilder(StreamInput in, + XPackLicenseState licenseState, + SetOnce modelLoadingService) throws IOException { super(in, NAME); modelId = in.readString(); bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString); inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class); this.modelLoadingService = modelLoadingService; this.model = null; + this.licenseState = licenseState; } /** @@ -98,7 +118,8 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega Map bucketsPath, Supplier model, String modelId, - InferenceConfigUpdate inferenceConfig + InferenceConfigUpdate inferenceConfig, + XPackLicenseState licenseState ) { super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {})); modelLoadingService = null; @@ -113,13 +134,14 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega */ this.modelId = modelId; this.inferenceConfig = inferenceConfig; + this.licenseState = licenseState; } - void setModelId(String modelId) { + public void setModelId(String modelId) { this.modelId = modelId; } - void setInferenceConfig(InferenceConfigUpdate inferenceConfig) { + public void setInferenceConfig(InferenceConfigUpdate inferenceConfig) { this.inferenceConfig = inferenceConfig; } @@ -160,7 +182,7 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega } @Override - public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) throws IOException { + public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) { if (model != null) { return this; } @@ -168,10 +190,17 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega context.registerAsyncAction((client, listener) -> { modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> { loadedModel.set(model); - delegate.onResponse(null); + + boolean isLicensed = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) || + licenseState.isAllowedByLicense(model.getLicenseLevel()); + if (isLicensed) { + delegate.onResponse(null); + } else { + delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + } })); }); - return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig); + return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 88e92dfa929..bf87427efee 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.license.License; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; @@ -38,6 +39,7 @@ public class LocalModel { private volatile long persistenceQuotient = 100; private final LongAdder currentInferenceCount; private final InferenceConfig inferenceConfig; + private final License.OperationMode licenseLevel; public LocalModel(String modelId, String nodeId, @@ -45,6 +47,7 @@ public class LocalModel { TrainedModelInput input, Map defaultFieldMap, InferenceConfig modelInferenceConfig, + License.OperationMode licenseLevel, TrainedModelStatsService trainedModelStatsService) { this.trainedModelDefinition = trainedModelDefinition; this.modelId = modelId; @@ -56,6 +59,7 @@ public class LocalModel { this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap); this.currentInferenceCount = new LongAdder(); this.inferenceConfig = modelInferenceConfig; + this.licenseLevel = licenseLevel; } long ramBytesUsed() { @@ -66,6 +70,10 @@ public class LocalModel { return modelId; } + public License.OperationMode getLicenseLevel() { + return licenseLevel; + } + public InferenceStats getLatestStatsAndReset() { return statsAccumulator.currentStatsAndReset(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 590240556ce..672fd67d2ce 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -309,6 +309,7 @@ public class ModelLoadingService implements ClusterStateListener { trainedModelConfig.getInput(), trainedModelConfig.getDefaultFieldMap(), inferenceConfig, + trainedModelConfig.getLicenseLevel(), modelStatsService)); }, // Failure getting the definition, remove the initial estimation value @@ -337,6 +338,7 @@ public class ModelLoadingService implements ClusterStateListener { trainedModelConfig.getInput(), trainedModelConfig.getDefaultFieldMap(), inferenceConfig, + trainedModelConfig.getLicenseLevel(), modelStatsService); synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilderTests.java index d1ea66dfefc..40b50488f35 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilderTests.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase; @@ -61,7 +62,8 @@ public class InferencePipelineAggregationBuilderTests extends BasePipelineAggreg .collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5))); InferencePipelineAggregationBuilder builder = - new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)), bucketPaths); + new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)), + mock(XPackLicenseState.class), bucketPaths); builder.setModelId(randomAlphaOfLength(6)); if (randomBoolean()) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 36f05f0b0a9..10eb7e7161a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.license.License; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; @@ -73,6 +74,7 @@ public class LocalModelTests extends ESTestCase { new TrainedModelInput(inputFields), Collections.singletonMap("field.foo", "field.foo.keyword"), ClassificationConfig.EMPTY_PARAMS, + randomFrom(License.OperationMode.values()), modelStatsService); Map fields = new HashMap() {{ put("field.foo", 1.0); @@ -102,6 +104,7 @@ public class LocalModelTests extends ESTestCase { new TrainedModelInput(inputFields), Collections.singletonMap("field.foo", "field.foo.keyword"), ClassificationConfig.EMPTY_PARAMS, + License.OperationMode.PLATINUM, modelStatsService); result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS); assertThat(result.value(), equalTo(0.0)); @@ -144,6 +147,7 @@ public class LocalModelTests extends ESTestCase { new TrainedModelInput(inputFields), Collections.singletonMap("field.foo", "field.foo.keyword"), ClassificationConfig.EMPTY_PARAMS, + License.OperationMode.PLATINUM, modelStatsService); Map fields = new HashMap() {{ put("field.foo", 1.0); @@ -199,6 +203,7 @@ public class LocalModelTests extends ESTestCase { new TrainedModelInput(inputFields), Collections.singletonMap("bar", "bar.keyword"), RegressionConfig.EMPTY_PARAMS, + License.OperationMode.PLATINUM, modelStatsService); Map fields = new HashMap() {{ @@ -226,6 +231,7 @@ public class LocalModelTests extends ESTestCase { new TrainedModelInput(inputFields), null, RegressionConfig.EMPTY_PARAMS, + License.OperationMode.PLATINUM, modelStatsService); Map fields = new HashMap() {{ @@ -256,6 +262,7 @@ public class LocalModelTests extends ESTestCase { new TrainedModelInput(inputFields), null, ClassificationConfig.EMPTY_PARAMS, + License.OperationMode.PLATINUM, modelStatsService ); Map fields = new HashMap() {{