mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-23 21:38:15 +00:00
Ensures the licence is sufficient for the model used in inference
This commit is contained in:
parent
f651487d74
commit
d86435938b
@ -5,25 +5,31 @@
|
|||||||
*/
|
*/
|
||||||
package org.elasticsearch.license;
|
package org.elasticsearch.license;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.SetOnce;
|
||||||
import org.elasticsearch.ElasticsearchSecurityException;
|
import org.elasticsearch.ElasticsearchSecurityException;
|
||||||
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
import org.elasticsearch.action.ingest.PutPipelineAction;
|
import org.elasticsearch.action.ingest.PutPipelineAction;
|
||||||
import org.elasticsearch.action.ingest.PutPipelineRequest;
|
import org.elasticsearch.action.ingest.PutPipelineRequest;
|
||||||
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
|
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
|
||||||
import org.elasticsearch.action.ingest.SimulatePipelineAction;
|
import org.elasticsearch.action.ingest.SimulatePipelineAction;
|
||||||
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
|
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
|
||||||
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
|
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
|
||||||
|
import org.elasticsearch.action.search.SearchRequest;
|
||||||
import org.elasticsearch.action.support.PlainActionFuture;
|
import org.elasticsearch.action.support.PlainActionFuture;
|
||||||
|
import org.elasticsearch.action.support.WriteRequest;
|
||||||
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
||||||
import org.elasticsearch.client.transport.TransportClient;
|
import org.elasticsearch.client.transport.TransportClient;
|
||||||
import org.elasticsearch.cluster.ClusterState;
|
import org.elasticsearch.cluster.ClusterState;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
|
||||||
import org.elasticsearch.common.bytes.BytesArray;
|
import org.elasticsearch.common.bytes.BytesArray;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.unit.TimeValue;
|
import org.elasticsearch.common.unit.TimeValue;
|
||||||
import org.elasticsearch.common.xcontent.XContentType;
|
import org.elasticsearch.common.xcontent.XContentType;
|
||||||
import org.elasticsearch.index.mapper.MapperService;
|
import org.elasticsearch.index.mapper.MapperService;
|
||||||
import org.elasticsearch.license.License.OperationMode;
|
import org.elasticsearch.license.License.OperationMode;
|
||||||
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
|
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
|
||||||
import org.elasticsearch.rest.RestStatus;
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
|
||||||
|
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
|
||||||
import org.elasticsearch.transport.Transport;
|
import org.elasticsearch.transport.Transport;
|
||||||
import org.elasticsearch.xpack.core.TestXPackTransportClient;
|
import org.elasticsearch.xpack.core.TestXPackTransportClient;
|
||||||
import org.elasticsearch.xpack.core.XPackField;
|
import org.elasticsearch.xpack.core.XPackField;
|
||||||
@ -53,12 +59,15 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
|
|||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
|
||||||
import org.elasticsearch.xpack.core.ml.job.config.JobState;
|
import org.elasticsearch.xpack.core.ml.job.config.JobState;
|
||||||
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
|
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
|
||||||
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
|
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.empty;
|
import static org.hamcrest.Matchers.empty;
|
||||||
@ -163,7 +172,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testMachineLearningPutDatafeedActionRestricted() throws Exception {
|
public void testMachineLearningPutDatafeedActionRestricted() {
|
||||||
String jobId = "testmachinelearningputdatafeedactionrestricted";
|
String jobId = "testmachinelearningputdatafeedactionrestricted";
|
||||||
String datafeedId = jobId + "-datafeed";
|
String datafeedId = jobId + "-datafeed";
|
||||||
assertMLAllowed(true);
|
assertMLAllowed(true);
|
||||||
@ -497,7 +506,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testMachineLearningDeleteJobActionNotRestricted() throws Exception {
|
public void testMachineLearningDeleteJobActionNotRestricted() {
|
||||||
String jobId = "testmachinelearningclosejobactionnotrestricted";
|
String jobId = "testmachinelearningclosejobactionnotrestricted";
|
||||||
assertMLAllowed(true);
|
assertMLAllowed(true);
|
||||||
// test that license restricted apis do now work
|
// test that license restricted apis do now work
|
||||||
@ -522,7 +531,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Exception {
|
public void testMachineLearningDeleteDatafeedActionNotRestricted() {
|
||||||
String jobId = "testmachinelearningdeletedatafeedactionnotrestricted";
|
String jobId = "testmachinelearningdeletedatafeedactionnotrestricted";
|
||||||
String datafeedId = jobId + "-datafeed";
|
String datafeedId = jobId + "-datafeed";
|
||||||
assertMLAllowed(true);
|
assertMLAllowed(true);
|
||||||
@ -554,7 +563,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testMachineLearningCreateInferenceProcessorRestricted() throws Exception {
|
public void testMachineLearningCreateInferenceProcessorRestricted() {
|
||||||
String modelId = "modelprocessorlicensetest";
|
String modelId = "modelprocessorlicensetest";
|
||||||
assertMLAllowed(true);
|
assertMLAllowed(true);
|
||||||
putInferenceModel(modelId);
|
putInferenceModel(modelId);
|
||||||
@ -686,7 +695,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|||||||
.actionGet();
|
.actionGet();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testMachineLearningInferModelRestricted() throws Exception {
|
public void testMachineLearningInferModelRestricted() {
|
||||||
String modelId = "modelinfermodellicensetest";
|
String modelId = "modelinfermodellicensetest";
|
||||||
assertMLAllowed(true);
|
assertMLAllowed(true);
|
||||||
putInferenceModel(modelId);
|
putInferenceModel(modelId);
|
||||||
@ -748,6 +757,58 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|||||||
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
|
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testInferenceAggRestricted() {
|
||||||
|
String modelId = "inference-agg-restricted";
|
||||||
|
assertMLAllowed(true);
|
||||||
|
putInferenceModel(modelId);
|
||||||
|
|
||||||
|
// index some data
|
||||||
|
String index = "inference-agg-licence-test";
|
||||||
|
client().admin().indices().prepareCreate(index)
|
||||||
|
.addMapping("type", "feature1", "type=double", "feature2", "type=keyword").get();
|
||||||
|
client().prepareBulk()
|
||||||
|
.add(new IndexRequest(index).source("feature1", "10.0", "feature2", "foo"))
|
||||||
|
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "foo"))
|
||||||
|
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "bar"))
|
||||||
|
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "bar"))
|
||||||
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||||
|
.get();
|
||||||
|
|
||||||
|
TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("foobar").field("feature2");
|
||||||
|
AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_feature1").field("feature1");
|
||||||
|
termsAgg.subAggregation(avgAgg);
|
||||||
|
|
||||||
|
XPackLicenseState licenseState = internalCluster().getInstance(XPackLicenseState.class);
|
||||||
|
ModelLoadingService modelLoading = internalCluster().getInstance(ModelLoadingService.class);
|
||||||
|
|
||||||
|
Map<String, String> bucketPaths = new HashMap<>();
|
||||||
|
bucketPaths.put("feature1", "avg_feature1");
|
||||||
|
InferencePipelineAggregationBuilder inferenceAgg =
|
||||||
|
new InferencePipelineAggregationBuilder("infer_agg", new SetOnce<>(modelLoading), licenseState, bucketPaths);
|
||||||
|
inferenceAgg.setModelId(modelId);
|
||||||
|
|
||||||
|
termsAgg.subAggregation(inferenceAgg);
|
||||||
|
|
||||||
|
SearchRequest search = new SearchRequest(index);
|
||||||
|
search.source().aggregation(termsAgg);
|
||||||
|
client().search(search).actionGet();
|
||||||
|
|
||||||
|
// Pick a license that does not allow machine learning
|
||||||
|
License.OperationMode mode = randomInvalidLicenseType();
|
||||||
|
enableLicensing(mode);
|
||||||
|
assertMLAllowed(false);
|
||||||
|
|
||||||
|
// inferring against a model should now fail
|
||||||
|
SearchRequest invalidSearch = new SearchRequest(index);
|
||||||
|
invalidSearch.source().aggregation(termsAgg);
|
||||||
|
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
|
||||||
|
() -> client().search(invalidSearch).actionGet());
|
||||||
|
|
||||||
|
assertThat(e.status(), is(RestStatus.FORBIDDEN));
|
||||||
|
assertThat(e.getMessage(), containsString("current license is non-compliant for [ml]"));
|
||||||
|
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
|
||||||
|
}
|
||||||
|
|
||||||
private void putInferenceModel(String modelId) {
|
private void putInferenceModel(String modelId) {
|
||||||
TrainedModelConfig config = TrainedModelConfig.builder()
|
TrainedModelConfig config = TrainedModelConfig.builder()
|
||||||
.setParsedDefinition(
|
.setParsedDefinition(
|
||||||
@ -755,13 +816,13 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|||||||
.setTrainedModel(
|
.setTrainedModel(
|
||||||
Tree.builder()
|
Tree.builder()
|
||||||
.setTargetType(TargetType.REGRESSION)
|
.setTargetType(TargetType.REGRESSION)
|
||||||
.setFeatureNames(Arrays.asList("feature1"))
|
.setFeatureNames(Collections.singletonList("feature1"))
|
||||||
.setNodes(TreeNode.builder(0).setLeafValue(1.0))
|
.setNodes(TreeNode.builder(0).setLeafValue(1.0))
|
||||||
.build())
|
.build())
|
||||||
.setPreProcessors(Collections.emptyList()))
|
.setPreProcessors(Collections.emptyList()))
|
||||||
.setModelId(modelId)
|
.setModelId(modelId)
|
||||||
.setDescription("test model for classification")
|
.setDescription("test model for classification")
|
||||||
.setInput(new TrainedModelInput(Arrays.asList("feature1")))
|
.setInput(new TrainedModelInput(Collections.singletonList("feature1")))
|
||||||
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
|
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
|
||||||
.build();
|
.build();
|
||||||
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
|
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
|
||||||
|
@ -980,10 +980,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
|
|||||||
@Override
|
@Override
|
||||||
public List<PipelineAggregationSpec> getPipelineAggregations() {
|
public List<PipelineAggregationSpec> getPipelineAggregations() {
|
||||||
PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
|
PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
|
||||||
in -> new InferencePipelineAggregationBuilder(in, modelLoadingService),
|
in -> new InferencePipelineAggregationBuilder(in, getLicenseState(), modelLoadingService),
|
||||||
(ContextParser<String, ? extends PipelineAggregationBuilder>)
|
(ContextParser<String, ? extends PipelineAggregationBuilder>)
|
||||||
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, name, parser
|
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, getLicenseState(), name, parser));
|
||||||
));
|
|
||||||
spec.addResultReader(InternalInferenceAggregation::new);
|
spec.addResultReader(InternalInferenceAggregation::new);
|
||||||
|
|
||||||
return Collections.singletonList(spec);
|
return Collections.singletonList(spec);
|
||||||
|
@ -10,15 +10,17 @@ import org.apache.lucene.util.SetOnce;
|
|||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.collect.Tuple;
|
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.index.query.QueryRewriteContext;
|
import org.elasticsearch.index.query.QueryRewriteContext;
|
||||||
|
import org.elasticsearch.license.LicenseUtils;
|
||||||
|
import org.elasticsearch.license.XPackLicenseState;
|
||||||
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
|
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||||
|
import org.elasticsearch.xpack.core.XPackField;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
|
||||||
@ -44,10 +46,10 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
|||||||
static final String AGGREGATIONS_RESULTS_FIELD = "value";
|
static final String AGGREGATIONS_RESULTS_FIELD = "value";
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder,
|
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, ParserSupplement> PARSER =
|
||||||
Tuple<SetOnce<ModelLoadingService>, String>> PARSER = new ConstructingObjectParser<>(
|
new ConstructingObjectParser<>(NAME, false,
|
||||||
NAME, false,
|
(args, context) -> new InferencePipelineAggregationBuilder(context.name, context.modelLoadingService,
|
||||||
(args, context) -> new InferencePipelineAggregationBuilder(context.v2(), context.v1(), (Map<String, String>) args[0])
|
context.licenseState, (Map<String, String>) args[0])
|
||||||
);
|
);
|
||||||
|
|
||||||
static {
|
static {
|
||||||
@ -60,34 +62,52 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
|||||||
private final Map<String, String> bucketPathMap;
|
private final Map<String, String> bucketPathMap;
|
||||||
private String modelId;
|
private String modelId;
|
||||||
private InferenceConfigUpdate inferenceConfig;
|
private InferenceConfigUpdate inferenceConfig;
|
||||||
|
private final XPackLicenseState licenseState;
|
||||||
private final SetOnce<ModelLoadingService> modelLoadingService;
|
private final SetOnce<ModelLoadingService> modelLoadingService;
|
||||||
/**
|
/**
|
||||||
* The model. Set to a non-null value during the rewrite phase.
|
* The model. Set to a non-null value during the rewrite phase.
|
||||||
*/
|
*/
|
||||||
private final Supplier<LocalModel> model;
|
private final Supplier<LocalModel> model;
|
||||||
|
|
||||||
|
private static class ParserSupplement {
|
||||||
|
final XPackLicenseState licenseState;
|
||||||
|
final SetOnce<ModelLoadingService> modelLoadingService;
|
||||||
|
final String name;
|
||||||
|
|
||||||
|
ParserSupplement(String name, XPackLicenseState licenseState, SetOnce<ModelLoadingService> modelLoadingService) {
|
||||||
|
this.name = name;
|
||||||
|
this.licenseState = licenseState;
|
||||||
|
this.modelLoadingService = modelLoadingService;
|
||||||
|
}
|
||||||
|
}
|
||||||
public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService,
|
public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService,
|
||||||
|
XPackLicenseState licenseState,
|
||||||
String pipelineAggregatorName,
|
String pipelineAggregatorName,
|
||||||
XContentParser parser) {
|
XContentParser parser) {
|
||||||
Tuple<SetOnce<ModelLoadingService>, String> context = new Tuple<>(modelLoadingService, pipelineAggregatorName);
|
return PARSER.apply(parser, new ParserSupplement(pipelineAggregatorName, licenseState, modelLoadingService));
|
||||||
return PARSER.apply(parser, context);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public InferencePipelineAggregationBuilder(String name, SetOnce<ModelLoadingService> modelLoadingService,
|
public InferencePipelineAggregationBuilder(String name,
|
||||||
|
SetOnce<ModelLoadingService> modelLoadingService,
|
||||||
|
XPackLicenseState licenseState,
|
||||||
Map<String, String> bucketsPath) {
|
Map<String, String> bucketsPath) {
|
||||||
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
|
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
|
||||||
this.modelLoadingService = modelLoadingService;
|
this.modelLoadingService = modelLoadingService;
|
||||||
this.bucketPathMap = bucketsPath;
|
this.bucketPathMap = bucketsPath;
|
||||||
this.model = null;
|
this.model = null;
|
||||||
|
this.licenseState = licenseState;
|
||||||
}
|
}
|
||||||
|
|
||||||
public InferencePipelineAggregationBuilder(StreamInput in, SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
|
public InferencePipelineAggregationBuilder(StreamInput in,
|
||||||
|
XPackLicenseState licenseState,
|
||||||
|
SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
|
||||||
super(in, NAME);
|
super(in, NAME);
|
||||||
modelId = in.readString();
|
modelId = in.readString();
|
||||||
bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
|
bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
|
||||||
inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
|
inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
|
||||||
this.modelLoadingService = modelLoadingService;
|
this.modelLoadingService = modelLoadingService;
|
||||||
this.model = null;
|
this.model = null;
|
||||||
|
this.licenseState = licenseState;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -98,7 +118,8 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
|||||||
Map<String, String> bucketsPath,
|
Map<String, String> bucketsPath,
|
||||||
Supplier<LocalModel> model,
|
Supplier<LocalModel> model,
|
||||||
String modelId,
|
String modelId,
|
||||||
InferenceConfigUpdate inferenceConfig
|
InferenceConfigUpdate inferenceConfig,
|
||||||
|
XPackLicenseState licenseState
|
||||||
) {
|
) {
|
||||||
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
|
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
|
||||||
modelLoadingService = null;
|
modelLoadingService = null;
|
||||||
@ -113,13 +134,14 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
|||||||
*/
|
*/
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
this.inferenceConfig = inferenceConfig;
|
this.inferenceConfig = inferenceConfig;
|
||||||
|
this.licenseState = licenseState;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setModelId(String modelId) {
|
public void setModelId(String modelId) {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
|
public void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
|
||||||
this.inferenceConfig = inferenceConfig;
|
this.inferenceConfig = inferenceConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -160,7 +182,7 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) throws IOException {
|
public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) {
|
||||||
if (model != null) {
|
if (model != null) {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
@ -168,10 +190,17 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
|||||||
context.registerAsyncAction((client, listener) -> {
|
context.registerAsyncAction((client, listener) -> {
|
||||||
modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
|
modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
|
||||||
loadedModel.set(model);
|
loadedModel.set(model);
|
||||||
delegate.onResponse(null);
|
|
||||||
|
boolean isLicensed = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) ||
|
||||||
|
licenseState.isAllowedByLicense(model.getLicenseLevel());
|
||||||
|
if (isLicensed) {
|
||||||
|
delegate.onResponse(null);
|
||||||
|
} else {
|
||||||
|
delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
|
||||||
|
}
|
||||||
}));
|
}));
|
||||||
});
|
});
|
||||||
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig);
|
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package org.elasticsearch.xpack.ml.inference.loadingservice;
|
package org.elasticsearch.xpack.ml.inference.loadingservice;
|
||||||
|
|
||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.license.License;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||||
@ -38,6 +39,7 @@ public class LocalModel {
|
|||||||
private volatile long persistenceQuotient = 100;
|
private volatile long persistenceQuotient = 100;
|
||||||
private final LongAdder currentInferenceCount;
|
private final LongAdder currentInferenceCount;
|
||||||
private final InferenceConfig inferenceConfig;
|
private final InferenceConfig inferenceConfig;
|
||||||
|
private final License.OperationMode licenseLevel;
|
||||||
|
|
||||||
public LocalModel(String modelId,
|
public LocalModel(String modelId,
|
||||||
String nodeId,
|
String nodeId,
|
||||||
@ -45,6 +47,7 @@ public class LocalModel {
|
|||||||
TrainedModelInput input,
|
TrainedModelInput input,
|
||||||
Map<String, String> defaultFieldMap,
|
Map<String, String> defaultFieldMap,
|
||||||
InferenceConfig modelInferenceConfig,
|
InferenceConfig modelInferenceConfig,
|
||||||
|
License.OperationMode licenseLevel,
|
||||||
TrainedModelStatsService trainedModelStatsService) {
|
TrainedModelStatsService trainedModelStatsService) {
|
||||||
this.trainedModelDefinition = trainedModelDefinition;
|
this.trainedModelDefinition = trainedModelDefinition;
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
@ -56,6 +59,7 @@ public class LocalModel {
|
|||||||
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
|
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
|
||||||
this.currentInferenceCount = new LongAdder();
|
this.currentInferenceCount = new LongAdder();
|
||||||
this.inferenceConfig = modelInferenceConfig;
|
this.inferenceConfig = modelInferenceConfig;
|
||||||
|
this.licenseLevel = licenseLevel;
|
||||||
}
|
}
|
||||||
|
|
||||||
long ramBytesUsed() {
|
long ramBytesUsed() {
|
||||||
@ -66,6 +70,10 @@ public class LocalModel {
|
|||||||
return modelId;
|
return modelId;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public License.OperationMode getLicenseLevel() {
|
||||||
|
return licenseLevel;
|
||||||
|
}
|
||||||
|
|
||||||
public InferenceStats getLatestStatsAndReset() {
|
public InferenceStats getLatestStatsAndReset() {
|
||||||
return statsAccumulator.currentStatsAndReset();
|
return statsAccumulator.currentStatsAndReset();
|
||||||
}
|
}
|
||||||
|
@ -309,6 +309,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||||||
trainedModelConfig.getInput(),
|
trainedModelConfig.getInput(),
|
||||||
trainedModelConfig.getDefaultFieldMap(),
|
trainedModelConfig.getDefaultFieldMap(),
|
||||||
inferenceConfig,
|
inferenceConfig,
|
||||||
|
trainedModelConfig.getLicenseLevel(),
|
||||||
modelStatsService));
|
modelStatsService));
|
||||||
},
|
},
|
||||||
// Failure getting the definition, remove the initial estimation value
|
// Failure getting the definition, remove the initial estimation value
|
||||||
@ -337,6 +338,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||||||
trainedModelConfig.getInput(),
|
trainedModelConfig.getInput(),
|
||||||
trainedModelConfig.getDefaultFieldMap(),
|
trainedModelConfig.getDefaultFieldMap(),
|
||||||
inferenceConfig,
|
inferenceConfig,
|
||||||
|
trainedModelConfig.getLicenseLevel(),
|
||||||
modelStatsService);
|
modelStatsService);
|
||||||
synchronized (loadingListeners) {
|
synchronized (loadingListeners) {
|
||||||
listeners = loadingListeners.remove(modelId);
|
listeners = loadingListeners.remove(modelId);
|
||||||
|
@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce;
|
|||||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
|
import org.elasticsearch.license.XPackLicenseState;
|
||||||
import org.elasticsearch.plugins.SearchPlugin;
|
import org.elasticsearch.plugins.SearchPlugin;
|
||||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
|
import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
|
||||||
@ -61,7 +62,8 @@ public class InferencePipelineAggregationBuilderTests extends BasePipelineAggreg
|
|||||||
.collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5)));
|
.collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5)));
|
||||||
|
|
||||||
InferencePipelineAggregationBuilder builder =
|
InferencePipelineAggregationBuilder builder =
|
||||||
new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)), bucketPaths);
|
new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)),
|
||||||
|
mock(XPackLicenseState.class), bucketPaths);
|
||||||
builder.setModelId(randomAlphaOfLength(6));
|
builder.setModelId(randomAlphaOfLength(6));
|
||||||
|
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
|
@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
|
|||||||
|
|
||||||
import org.elasticsearch.action.support.PlainActionFuture;
|
import org.elasticsearch.action.support.PlainActionFuture;
|
||||||
import org.elasticsearch.ingest.IngestDocument;
|
import org.elasticsearch.ingest.IngestDocument;
|
||||||
|
import org.elasticsearch.license.License;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||||
@ -73,6 +74,7 @@ public class LocalModelTests extends ESTestCase {
|
|||||||
new TrainedModelInput(inputFields),
|
new TrainedModelInput(inputFields),
|
||||||
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
||||||
ClassificationConfig.EMPTY_PARAMS,
|
ClassificationConfig.EMPTY_PARAMS,
|
||||||
|
randomFrom(License.OperationMode.values()),
|
||||||
modelStatsService);
|
modelStatsService);
|
||||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||||
put("field.foo", 1.0);
|
put("field.foo", 1.0);
|
||||||
@ -102,6 +104,7 @@ public class LocalModelTests extends ESTestCase {
|
|||||||
new TrainedModelInput(inputFields),
|
new TrainedModelInput(inputFields),
|
||||||
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
||||||
ClassificationConfig.EMPTY_PARAMS,
|
ClassificationConfig.EMPTY_PARAMS,
|
||||||
|
License.OperationMode.PLATINUM,
|
||||||
modelStatsService);
|
modelStatsService);
|
||||||
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
|
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
|
||||||
assertThat(result.value(), equalTo(0.0));
|
assertThat(result.value(), equalTo(0.0));
|
||||||
@ -144,6 +147,7 @@ public class LocalModelTests extends ESTestCase {
|
|||||||
new TrainedModelInput(inputFields),
|
new TrainedModelInput(inputFields),
|
||||||
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
||||||
ClassificationConfig.EMPTY_PARAMS,
|
ClassificationConfig.EMPTY_PARAMS,
|
||||||
|
License.OperationMode.PLATINUM,
|
||||||
modelStatsService);
|
modelStatsService);
|
||||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||||
put("field.foo", 1.0);
|
put("field.foo", 1.0);
|
||||||
@ -199,6 +203,7 @@ public class LocalModelTests extends ESTestCase {
|
|||||||
new TrainedModelInput(inputFields),
|
new TrainedModelInput(inputFields),
|
||||||
Collections.singletonMap("bar", "bar.keyword"),
|
Collections.singletonMap("bar", "bar.keyword"),
|
||||||
RegressionConfig.EMPTY_PARAMS,
|
RegressionConfig.EMPTY_PARAMS,
|
||||||
|
License.OperationMode.PLATINUM,
|
||||||
modelStatsService);
|
modelStatsService);
|
||||||
|
|
||||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||||
@ -226,6 +231,7 @@ public class LocalModelTests extends ESTestCase {
|
|||||||
new TrainedModelInput(inputFields),
|
new TrainedModelInput(inputFields),
|
||||||
null,
|
null,
|
||||||
RegressionConfig.EMPTY_PARAMS,
|
RegressionConfig.EMPTY_PARAMS,
|
||||||
|
License.OperationMode.PLATINUM,
|
||||||
modelStatsService);
|
modelStatsService);
|
||||||
|
|
||||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||||
@ -256,6 +262,7 @@ public class LocalModelTests extends ESTestCase {
|
|||||||
new TrainedModelInput(inputFields),
|
new TrainedModelInput(inputFields),
|
||||||
null,
|
null,
|
||||||
ClassificationConfig.EMPTY_PARAMS,
|
ClassificationConfig.EMPTY_PARAMS,
|
||||||
|
License.OperationMode.PLATINUM,
|
||||||
modelStatsService
|
modelStatsService
|
||||||
);
|
);
|
||||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||||
|
Loading…
x
Reference in New Issue
Block a user