[7.x] Add ml licence check to the pipeline inference agg. (#59213) (#59412)

Ensures the licence is sufficient for the model used in inference
This commit is contained in:
David Kyle 2020-07-14 14:03:10 +01:00 committed by GitHub
parent f651487d74
commit d86435938b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 136 additions and 28 deletions

View File

@ -5,25 +5,31 @@
*/ */
package org.elasticsearch.license; package org.elasticsearch.license;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.ingest.PutPipelineAction; import org.elasticsearch.action.ingest.PutPipelineAction;
import org.elasticsearch.action.ingest.PutPipelineRequest; import org.elasticsearch.action.ingest.PutPipelineRequest;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult; import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineAction; import org.elasticsearch.action.ingest.SimulatePipelineAction;
import org.elasticsearch.action.ingest.SimulatePipelineRequest; import org.elasticsearch.action.ingest.SimulatePipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineResponse; import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.license.License.OperationMode; import org.elasticsearch.license.License.OperationMode;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.xpack.core.TestXPackTransportClient; import org.elasticsearch.xpack.core.TestXPackTransportClient;
import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.XPackField;
@ -53,12 +59,15 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.ml.LocalStateMachineLearning; import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
import org.junit.Before; import org.junit.Before;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
@ -163,7 +172,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
} }
} }
public void testMachineLearningPutDatafeedActionRestricted() throws Exception { public void testMachineLearningPutDatafeedActionRestricted() {
String jobId = "testmachinelearningputdatafeedactionrestricted"; String jobId = "testmachinelearningputdatafeedactionrestricted";
String datafeedId = jobId + "-datafeed"; String datafeedId = jobId + "-datafeed";
assertMLAllowed(true); assertMLAllowed(true);
@ -497,7 +506,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
} }
} }
public void testMachineLearningDeleteJobActionNotRestricted() throws Exception { public void testMachineLearningDeleteJobActionNotRestricted() {
String jobId = "testmachinelearningclosejobactionnotrestricted"; String jobId = "testmachinelearningclosejobactionnotrestricted";
assertMLAllowed(true); assertMLAllowed(true);
// test that license restricted apis do now work // test that license restricted apis do now work
@ -522,7 +531,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
} }
} }
public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Exception { public void testMachineLearningDeleteDatafeedActionNotRestricted() {
String jobId = "testmachinelearningdeletedatafeedactionnotrestricted"; String jobId = "testmachinelearningdeletedatafeedactionnotrestricted";
String datafeedId = jobId + "-datafeed"; String datafeedId = jobId + "-datafeed";
assertMLAllowed(true); assertMLAllowed(true);
@ -554,7 +563,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
} }
} }
public void testMachineLearningCreateInferenceProcessorRestricted() throws Exception { public void testMachineLearningCreateInferenceProcessorRestricted() {
String modelId = "modelprocessorlicensetest"; String modelId = "modelprocessorlicensetest";
assertMLAllowed(true); assertMLAllowed(true);
putInferenceModel(modelId); putInferenceModel(modelId);
@ -686,7 +695,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
.actionGet(); .actionGet();
} }
public void testMachineLearningInferModelRestricted() throws Exception { public void testMachineLearningInferModelRestricted() {
String modelId = "modelinfermodellicensetest"; String modelId = "modelinfermodellicensetest";
assertMLAllowed(true); assertMLAllowed(true);
putInferenceModel(modelId); putInferenceModel(modelId);
@ -748,6 +757,58 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
assertThat(listener.actionGet().getInferenceResults(), is(not(empty()))); assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
} }
public void testInferenceAggRestricted() {
String modelId = "inference-agg-restricted";
assertMLAllowed(true);
putInferenceModel(modelId);
// index some data
String index = "inference-agg-licence-test";
client().admin().indices().prepareCreate(index)
.addMapping("type", "feature1", "type=double", "feature2", "type=keyword").get();
client().prepareBulk()
.add(new IndexRequest(index).source("feature1", "10.0", "feature2", "foo"))
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "foo"))
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "bar"))
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "bar"))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get();
TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("foobar").field("feature2");
AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_feature1").field("feature1");
termsAgg.subAggregation(avgAgg);
XPackLicenseState licenseState = internalCluster().getInstance(XPackLicenseState.class);
ModelLoadingService modelLoading = internalCluster().getInstance(ModelLoadingService.class);
Map<String, String> bucketPaths = new HashMap<>();
bucketPaths.put("feature1", "avg_feature1");
InferencePipelineAggregationBuilder inferenceAgg =
new InferencePipelineAggregationBuilder("infer_agg", new SetOnce<>(modelLoading), licenseState, bucketPaths);
inferenceAgg.setModelId(modelId);
termsAgg.subAggregation(inferenceAgg);
SearchRequest search = new SearchRequest(index);
search.source().aggregation(termsAgg);
client().search(search).actionGet();
// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);
// inferring against a model should now fail
SearchRequest invalidSearch = new SearchRequest(index);
invalidSearch.source().aggregation(termsAgg);
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
() -> client().search(invalidSearch).actionGet());
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("current license is non-compliant for [ml]"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
}
private void putInferenceModel(String modelId) { private void putInferenceModel(String modelId) {
TrainedModelConfig config = TrainedModelConfig.builder() TrainedModelConfig config = TrainedModelConfig.builder()
.setParsedDefinition( .setParsedDefinition(
@ -755,13 +816,13 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
.setTrainedModel( .setTrainedModel(
Tree.builder() Tree.builder()
.setTargetType(TargetType.REGRESSION) .setTargetType(TargetType.REGRESSION)
.setFeatureNames(Arrays.asList("feature1")) .setFeatureNames(Collections.singletonList("feature1"))
.setNodes(TreeNode.builder(0).setLeafValue(1.0)) .setNodes(TreeNode.builder(0).setLeafValue(1.0))
.build()) .build())
.setPreProcessors(Collections.emptyList())) .setPreProcessors(Collections.emptyList()))
.setModelId(modelId) .setModelId(modelId)
.setDescription("test model for classification") .setDescription("test model for classification")
.setInput(new TrainedModelInput(Arrays.asList("feature1"))) .setInput(new TrainedModelInput(Collections.singletonList("feature1")))
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS) .setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
.build(); .build();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet(); client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();

View File

@ -980,10 +980,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
@Override @Override
public List<PipelineAggregationSpec> getPipelineAggregations() { public List<PipelineAggregationSpec> getPipelineAggregations() {
PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME, PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
in -> new InferencePipelineAggregationBuilder(in, modelLoadingService), in -> new InferencePipelineAggregationBuilder(in, getLicenseState(), modelLoadingService),
(ContextParser<String, ? extends PipelineAggregationBuilder>) (ContextParser<String, ? extends PipelineAggregationBuilder>)
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, name, parser (parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, getLicenseState(), name, parser));
));
spec.addResultReader(InternalInferenceAggregation::new); spec.addResultReader(InternalInferenceAggregation::new);
return Collections.singletonList(spec); return Collections.singletonList(spec);

View File

@ -10,15 +10,17 @@ import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
@ -44,10 +46,10 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
static final String AGGREGATIONS_RESULTS_FIELD = "value"; static final String AGGREGATIONS_RESULTS_FIELD = "value";
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, ParserSupplement> PARSER =
Tuple<SetOnce<ModelLoadingService>, String>> PARSER = new ConstructingObjectParser<>( new ConstructingObjectParser<>(NAME, false,
NAME, false, (args, context) -> new InferencePipelineAggregationBuilder(context.name, context.modelLoadingService,
(args, context) -> new InferencePipelineAggregationBuilder(context.v2(), context.v1(), (Map<String, String>) args[0]) context.licenseState, (Map<String, String>) args[0])
); );
static { static {
@ -60,34 +62,52 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
private final Map<String, String> bucketPathMap; private final Map<String, String> bucketPathMap;
private String modelId; private String modelId;
private InferenceConfigUpdate inferenceConfig; private InferenceConfigUpdate inferenceConfig;
private final XPackLicenseState licenseState;
private final SetOnce<ModelLoadingService> modelLoadingService; private final SetOnce<ModelLoadingService> modelLoadingService;
/** /**
* The model. Set to a non-null value during the rewrite phase. * The model. Set to a non-null value during the rewrite phase.
*/ */
private final Supplier<LocalModel> model; private final Supplier<LocalModel> model;
private static class ParserSupplement {
final XPackLicenseState licenseState;
final SetOnce<ModelLoadingService> modelLoadingService;
final String name;
ParserSupplement(String name, XPackLicenseState licenseState, SetOnce<ModelLoadingService> modelLoadingService) {
this.name = name;
this.licenseState = licenseState;
this.modelLoadingService = modelLoadingService;
}
}
public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService, public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService,
XPackLicenseState licenseState,
String pipelineAggregatorName, String pipelineAggregatorName,
XContentParser parser) { XContentParser parser) {
Tuple<SetOnce<ModelLoadingService>, String> context = new Tuple<>(modelLoadingService, pipelineAggregatorName); return PARSER.apply(parser, new ParserSupplement(pipelineAggregatorName, licenseState, modelLoadingService));
return PARSER.apply(parser, context);
} }
public InferencePipelineAggregationBuilder(String name, SetOnce<ModelLoadingService> modelLoadingService, public InferencePipelineAggregationBuilder(String name,
SetOnce<ModelLoadingService> modelLoadingService,
XPackLicenseState licenseState,
Map<String, String> bucketsPath) { Map<String, String> bucketsPath) {
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {})); super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
this.modelLoadingService = modelLoadingService; this.modelLoadingService = modelLoadingService;
this.bucketPathMap = bucketsPath; this.bucketPathMap = bucketsPath;
this.model = null; this.model = null;
this.licenseState = licenseState;
} }
public InferencePipelineAggregationBuilder(StreamInput in, SetOnce<ModelLoadingService> modelLoadingService) throws IOException { public InferencePipelineAggregationBuilder(StreamInput in,
XPackLicenseState licenseState,
SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
super(in, NAME); super(in, NAME);
modelId = in.readString(); modelId = in.readString();
bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString); bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class); inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
this.modelLoadingService = modelLoadingService; this.modelLoadingService = modelLoadingService;
this.model = null; this.model = null;
this.licenseState = licenseState;
} }
/** /**
@ -98,7 +118,8 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
Map<String, String> bucketsPath, Map<String, String> bucketsPath,
Supplier<LocalModel> model, Supplier<LocalModel> model,
String modelId, String modelId,
InferenceConfigUpdate inferenceConfig InferenceConfigUpdate inferenceConfig,
XPackLicenseState licenseState
) { ) {
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {})); super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
modelLoadingService = null; modelLoadingService = null;
@ -113,13 +134,14 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
*/ */
this.modelId = modelId; this.modelId = modelId;
this.inferenceConfig = inferenceConfig; this.inferenceConfig = inferenceConfig;
this.licenseState = licenseState;
} }
void setModelId(String modelId) { public void setModelId(String modelId) {
this.modelId = modelId; this.modelId = modelId;
} }
void setInferenceConfig(InferenceConfigUpdate inferenceConfig) { public void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
this.inferenceConfig = inferenceConfig; this.inferenceConfig = inferenceConfig;
} }
@ -160,7 +182,7 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
} }
@Override @Override
public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) throws IOException { public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) {
if (model != null) { if (model != null) {
return this; return this;
} }
@ -168,10 +190,17 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
context.registerAsyncAction((client, listener) -> { context.registerAsyncAction((client, listener) -> {
modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> { modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
loadedModel.set(model); loadedModel.set(model);
delegate.onResponse(null);
boolean isLicensed = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) ||
licenseState.isAllowedByLicense(model.getLicenseLevel());
if (isLicensed) {
delegate.onResponse(null);
} else {
delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
}
})); }));
}); });
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig); return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState);
} }
@Override @Override

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.inference.loadingservice; package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@ -38,6 +39,7 @@ public class LocalModel {
private volatile long persistenceQuotient = 100; private volatile long persistenceQuotient = 100;
private final LongAdder currentInferenceCount; private final LongAdder currentInferenceCount;
private final InferenceConfig inferenceConfig; private final InferenceConfig inferenceConfig;
private final License.OperationMode licenseLevel;
public LocalModel(String modelId, public LocalModel(String modelId,
String nodeId, String nodeId,
@ -45,6 +47,7 @@ public class LocalModel {
TrainedModelInput input, TrainedModelInput input,
Map<String, String> defaultFieldMap, Map<String, String> defaultFieldMap,
InferenceConfig modelInferenceConfig, InferenceConfig modelInferenceConfig,
License.OperationMode licenseLevel,
TrainedModelStatsService trainedModelStatsService) { TrainedModelStatsService trainedModelStatsService) {
this.trainedModelDefinition = trainedModelDefinition; this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId; this.modelId = modelId;
@ -56,6 +59,7 @@ public class LocalModel {
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap); this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
this.currentInferenceCount = new LongAdder(); this.currentInferenceCount = new LongAdder();
this.inferenceConfig = modelInferenceConfig; this.inferenceConfig = modelInferenceConfig;
this.licenseLevel = licenseLevel;
} }
long ramBytesUsed() { long ramBytesUsed() {
@ -66,6 +70,10 @@ public class LocalModel {
return modelId; return modelId;
} }
public License.OperationMode getLicenseLevel() {
return licenseLevel;
}
public InferenceStats getLatestStatsAndReset() { public InferenceStats getLatestStatsAndReset() {
return statsAccumulator.currentStatsAndReset(); return statsAccumulator.currentStatsAndReset();
} }

View File

@ -309,6 +309,7 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelConfig.getInput(), trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(), trainedModelConfig.getDefaultFieldMap(),
inferenceConfig, inferenceConfig,
trainedModelConfig.getLicenseLevel(),
modelStatsService)); modelStatsService));
}, },
// Failure getting the definition, remove the initial estimation value // Failure getting the definition, remove the initial estimation value
@ -337,6 +338,7 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelConfig.getInput(), trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(), trainedModelConfig.getDefaultFieldMap(),
inferenceConfig, inferenceConfig,
trainedModelConfig.getLicenseLevel(),
modelStatsService); modelStatsService);
synchronized (loadingListeners) { synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId); listeners = loadingListeners.remove(modelId);

View File

@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase; import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
@ -61,7 +62,8 @@ public class InferencePipelineAggregationBuilderTests extends BasePipelineAggreg
.collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5))); .collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5)));
InferencePipelineAggregationBuilder builder = InferencePipelineAggregationBuilder builder =
new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)), bucketPaths); new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)),
mock(XPackLicenseState.class), bucketPaths);
builder.setModelId(randomAlphaOfLength(6)); builder.setModelId(randomAlphaOfLength(6));
if (randomBoolean()) { if (randomBoolean()) {

View File

@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.license.License;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
@ -73,6 +74,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields), new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"), Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS, ClassificationConfig.EMPTY_PARAMS,
randomFrom(License.OperationMode.values()),
modelStatsService); modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{ Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0); put("field.foo", 1.0);
@ -102,6 +104,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields), new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"), Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS, ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService); modelStatsService);
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS); result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
assertThat(result.value(), equalTo(0.0)); assertThat(result.value(), equalTo(0.0));
@ -144,6 +147,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields), new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"), Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS, ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService); modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{ Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0); put("field.foo", 1.0);
@ -199,6 +203,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields), new TrainedModelInput(inputFields),
Collections.singletonMap("bar", "bar.keyword"), Collections.singletonMap("bar", "bar.keyword"),
RegressionConfig.EMPTY_PARAMS, RegressionConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService); modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{ Map<String, Object> fields = new HashMap<String, Object>() {{
@ -226,6 +231,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields), new TrainedModelInput(inputFields),
null, null,
RegressionConfig.EMPTY_PARAMS, RegressionConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService); modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{ Map<String, Object> fields = new HashMap<String, Object>() {{
@ -256,6 +262,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields), new TrainedModelInput(inputFields),
null, null,
ClassificationConfig.EMPTY_PARAMS, ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService modelStatsService
); );
Map<String, Object> fields = new HashMap<String, Object>() {{ Map<String, Object> fields = new HashMap<String, Object>() {{