[7.x] Add ml licence check to the pipeline inference agg. (#59213) (#59412)

Ensures the licence is sufficient for the model used in inference.
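The gate is applied when the aggregation is rewritten for execution: once the model has been fetched for the search, the builder checks that either the cluster licence enables machine learning or the model's own licence level is allowed, and otherwise fails the request with a licence-compliance error (surfaced as a 403, as the new testInferenceAggRestricted test asserts). A minimal sketch of that check, lifted from the diff below — the standalone helper and its name checkModelLicence are illustrative; in the change itself this logic sits inside the async model-loading listener in InferencePipelineAggregationBuilder#rewrite:

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;

// Illustrative helper: the same two-pronged licence check the commit adds.
static void checkModelLicence(XPackLicenseState licenseState, LocalModel model, ActionListener<Void> delegate) {
    // Allowed if the cluster licence enables ML as a whole, or if the model's own
    // licence level (now carried on LocalModel via getLicenseLevel()) is permitted.
    boolean isLicensed = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING)
        || licenseState.isAllowedByLicense(model.getLicenseLevel());
    if (isLicensed) {
        delegate.onResponse(null);   // licence is sufficient: let the search proceed
    } else {
        delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
    }
}

The second clause of the check means a model whose own licence level is permitted by the current licence can still be used even when the ML feature as a whole is not licensed.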
David Kyle 2020-07-14 14:03:10 +01:00 committed by GitHub
parent f651487d74
commit d86435938b
7 changed files with 136 additions and 28 deletions

View File

@@ -5,25 +5,31 @@
*/
package org.elasticsearch.license;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.ingest.PutPipelineAction;
import org.elasticsearch.action.ingest.PutPipelineRequest;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineAction;
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.license.License.OperationMode;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.xpack.core.TestXPackTransportClient;
import org.elasticsearch.xpack.core.XPackField;
@@ -53,12 +59,15 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
import org.junit.Before;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
@@ -163,7 +172,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
}
}
public void testMachineLearningPutDatafeedActionRestricted() throws Exception {
public void testMachineLearningPutDatafeedActionRestricted() {
String jobId = "testmachinelearningputdatafeedactionrestricted";
String datafeedId = jobId + "-datafeed";
assertMLAllowed(true);
@@ -497,7 +506,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
}
}
public void testMachineLearningDeleteJobActionNotRestricted() throws Exception {
public void testMachineLearningDeleteJobActionNotRestricted() {
String jobId = "testmachinelearningclosejobactionnotrestricted";
assertMLAllowed(true);
// test that license restricted apis do now work
@@ -522,7 +531,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
}
}
public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Exception {
public void testMachineLearningDeleteDatafeedActionNotRestricted() {
String jobId = "testmachinelearningdeletedatafeedactionnotrestricted";
String datafeedId = jobId + "-datafeed";
assertMLAllowed(true);
@@ -554,7 +563,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
}
}
public void testMachineLearningCreateInferenceProcessorRestricted() throws Exception {
public void testMachineLearningCreateInferenceProcessorRestricted() {
String modelId = "modelprocessorlicensetest";
assertMLAllowed(true);
putInferenceModel(modelId);
@@ -686,7 +695,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
.actionGet();
}
public void testMachineLearningInferModelRestricted() throws Exception {
public void testMachineLearningInferModelRestricted() {
String modelId = "modelinfermodellicensetest";
assertMLAllowed(true);
putInferenceModel(modelId);
@@ -748,6 +757,58 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
}
public void testInferenceAggRestricted() {
String modelId = "inference-agg-restricted";
assertMLAllowed(true);
putInferenceModel(modelId);
// index some data
String index = "inference-agg-licence-test";
client().admin().indices().prepareCreate(index)
.addMapping("type", "feature1", "type=double", "feature2", "type=keyword").get();
client().prepareBulk()
.add(new IndexRequest(index).source("feature1", "10.0", "feature2", "foo"))
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "foo"))
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "bar"))
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "bar"))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get();
TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("foobar").field("feature2");
AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_feature1").field("feature1");
termsAgg.subAggregation(avgAgg);
XPackLicenseState licenseState = internalCluster().getInstance(XPackLicenseState.class);
ModelLoadingService modelLoading = internalCluster().getInstance(ModelLoadingService.class);
Map<String, String> bucketPaths = new HashMap<>();
bucketPaths.put("feature1", "avg_feature1");
InferencePipelineAggregationBuilder inferenceAgg =
new InferencePipelineAggregationBuilder("infer_agg", new SetOnce<>(modelLoading), licenseState, bucketPaths);
inferenceAgg.setModelId(modelId);
termsAgg.subAggregation(inferenceAgg);
SearchRequest search = new SearchRequest(index);
search.source().aggregation(termsAgg);
client().search(search).actionGet();
// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);
// inferring against a model should now fail
SearchRequest invalidSearch = new SearchRequest(index);
invalidSearch.source().aggregation(termsAgg);
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
() -> client().search(invalidSearch).actionGet());
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("current license is non-compliant for [ml]"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
}
private void putInferenceModel(String modelId) {
TrainedModelConfig config = TrainedModelConfig.builder()
.setParsedDefinition(
@@ -755,13 +816,13 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
.setTrainedModel(
Tree.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(Arrays.asList("feature1"))
.setFeatureNames(Collections.singletonList("feature1"))
.setNodes(TreeNode.builder(0).setLeafValue(1.0))
.build())
.setPreProcessors(Collections.emptyList()))
.setModelId(modelId)
.setDescription("test model for classification")
.setInput(new TrainedModelInput(Arrays.asList("feature1")))
.setInput(new TrainedModelInput(Collections.singletonList("feature1")))
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
.build();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();

View File

@@ -980,10 +980,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
@Override
public List<PipelineAggregationSpec> getPipelineAggregations() {
PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
in -> new InferencePipelineAggregationBuilder(in, modelLoadingService),
in -> new InferencePipelineAggregationBuilder(in, getLicenseState(), modelLoadingService),
(ContextParser<String, ? extends PipelineAggregationBuilder>)
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, name, parser
));
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, getLicenseState(), name, parser));
spec.addResultReader(InternalInferenceAggregation::new);
return Collections.singletonList(spec);

View File

@@ -10,15 +10,17 @@ import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
@@ -44,10 +46,10 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
static final String AGGREGATIONS_RESULTS_FIELD = "value";
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder,
Tuple<SetOnce<ModelLoadingService>, String>> PARSER = new ConstructingObjectParser<>(
NAME, false,
(args, context) -> new InferencePipelineAggregationBuilder(context.v2(), context.v1(), (Map<String, String>) args[0])
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, ParserSupplement> PARSER =
new ConstructingObjectParser<>(NAME, false,
(args, context) -> new InferencePipelineAggregationBuilder(context.name, context.modelLoadingService,
context.licenseState, (Map<String, String>) args[0])
);
static {
@@ -60,34 +62,52 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
private final Map<String, String> bucketPathMap;
private String modelId;
private InferenceConfigUpdate inferenceConfig;
private final XPackLicenseState licenseState;
private final SetOnce<ModelLoadingService> modelLoadingService;
/**
* The model. Set to a non-null value during the rewrite phase.
*/
private final Supplier<LocalModel> model;
private static class ParserSupplement {
final XPackLicenseState licenseState;
final SetOnce<ModelLoadingService> modelLoadingService;
final String name;
ParserSupplement(String name, XPackLicenseState licenseState, SetOnce<ModelLoadingService> modelLoadingService) {
this.name = name;
this.licenseState = licenseState;
this.modelLoadingService = modelLoadingService;
}
}
public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService,
XPackLicenseState licenseState,
String pipelineAggregatorName,
XContentParser parser) {
Tuple<SetOnce<ModelLoadingService>, String> context = new Tuple<>(modelLoadingService, pipelineAggregatorName);
return PARSER.apply(parser, context);
return PARSER.apply(parser, new ParserSupplement(pipelineAggregatorName, licenseState, modelLoadingService));
}
public InferencePipelineAggregationBuilder(String name, SetOnce<ModelLoadingService> modelLoadingService,
public InferencePipelineAggregationBuilder(String name,
SetOnce<ModelLoadingService> modelLoadingService,
XPackLicenseState licenseState,
Map<String, String> bucketsPath) {
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
this.modelLoadingService = modelLoadingService;
this.bucketPathMap = bucketsPath;
this.model = null;
this.licenseState = licenseState;
}
public InferencePipelineAggregationBuilder(StreamInput in, SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
public InferencePipelineAggregationBuilder(StreamInput in,
XPackLicenseState licenseState,
SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
super(in, NAME);
modelId = in.readString();
bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
this.modelLoadingService = modelLoadingService;
this.model = null;
this.licenseState = licenseState;
}
/**
@@ -98,7 +118,8 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
Map<String, String> bucketsPath,
Supplier<LocalModel> model,
String modelId,
InferenceConfigUpdate inferenceConfig
InferenceConfigUpdate inferenceConfig,
XPackLicenseState licenseState
) {
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
modelLoadingService = null;
@@ -113,13 +134,14 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
*/
this.modelId = modelId;
this.inferenceConfig = inferenceConfig;
this.licenseState = licenseState;
}
void setModelId(String modelId) {
public void setModelId(String modelId) {
this.modelId = modelId;
}
void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
public void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
this.inferenceConfig = inferenceConfig;
}
@@ -160,7 +182,7 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
}
@Override
public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) throws IOException {
public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) {
if (model != null) {
return this;
}
@@ -168,10 +190,17 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
context.registerAsyncAction((client, listener) -> {
modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
loadedModel.set(model);
delegate.onResponse(null);
boolean isLicensed = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) ||
licenseState.isAllowedByLicense(model.getLicenseLevel());
if (isLicensed) {
delegate.onResponse(null);
} else {
delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
}
}));
});
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig);
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState);
}
@Override

View File

@@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@@ -38,6 +39,7 @@ public class LocalModel {
private volatile long persistenceQuotient = 100;
private final LongAdder currentInferenceCount;
private final InferenceConfig inferenceConfig;
private final License.OperationMode licenseLevel;
public LocalModel(String modelId,
String nodeId,
@@ -45,6 +47,7 @@ public class LocalModel {
TrainedModelInput input,
Map<String, String> defaultFieldMap,
InferenceConfig modelInferenceConfig,
License.OperationMode licenseLevel,
TrainedModelStatsService trainedModelStatsService) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
@@ -56,6 +59,7 @@ public class LocalModel {
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
this.currentInferenceCount = new LongAdder();
this.inferenceConfig = modelInferenceConfig;
this.licenseLevel = licenseLevel;
}
long ramBytesUsed() {
@@ -66,6 +70,10 @@ public class LocalModel {
return modelId;
}
public License.OperationMode getLicenseLevel() {
return licenseLevel;
}
public InferenceStats getLatestStatsAndReset() {
return statsAccumulator.currentStatsAndReset();
}

View File

@@ -309,6 +309,7 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,
trainedModelConfig.getLicenseLevel(),
modelStatsService));
},
// Failure getting the definition, remove the initial estimation value
@@ -337,6 +338,7 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,
trainedModelConfig.getLicenseLevel(),
modelStatsService);
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);

View File

@@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
@@ -61,7 +62,8 @@ public class InferencePipelineAggregationBuilderTests extends BasePipelineAggreg
.collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5)));
InferencePipelineAggregationBuilder builder =
new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)), bucketPaths);
new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)),
mock(XPackLicenseState.class), bucketPaths);
builder.setModelId(randomAlphaOfLength(6));
if (randomBoolean()) {

View File

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.license.License;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
@@ -73,6 +74,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS,
randomFrom(License.OperationMode.values()),
modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
@@ -102,6 +104,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService);
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
assertThat(result.value(), equalTo(0.0));
@@ -144,6 +147,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
@@ -199,6 +203,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields),
Collections.singletonMap("bar", "bar.keyword"),
RegressionConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{
@@ -226,6 +231,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields),
null,
RegressionConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{
@@ -256,6 +262,7 @@ public class LocalModelTests extends ESTestCase {
new TrainedModelInput(inputFields),
null,
ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService
);
Map<String, Object> fields = new HashMap<String, Object>() {{