mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-23 21:38:15 +00:00
Ensures the licence is sufficient for the model used in inference
This commit is contained in:
parent
f651487d74
commit
d86435938b
@ -5,25 +5,31 @@
|
||||
*/
|
||||
package org.elasticsearch.license;
|
||||
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.ElasticsearchSecurityException;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.ingest.PutPipelineAction;
|
||||
import org.elasticsearch.action.ingest.PutPipelineRequest;
|
||||
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineAction;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
||||
import org.elasticsearch.client.transport.TransportClient;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.index.mapper.MapperService;
|
||||
import org.elasticsearch.license.License.OperationMode;
|
||||
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
|
||||
import org.elasticsearch.transport.Transport;
|
||||
import org.elasticsearch.xpack.core.TestXPackTransportClient;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
@ -53,12 +59,15 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.JobState;
|
||||
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
|
||||
import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
|
||||
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
@ -163,7 +172,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
||||
}
|
||||
}
|
||||
|
||||
public void testMachineLearningPutDatafeedActionRestricted() throws Exception {
|
||||
public void testMachineLearningPutDatafeedActionRestricted() {
|
||||
String jobId = "testmachinelearningputdatafeedactionrestricted";
|
||||
String datafeedId = jobId + "-datafeed";
|
||||
assertMLAllowed(true);
|
||||
@ -497,7 +506,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
||||
}
|
||||
}
|
||||
|
||||
public void testMachineLearningDeleteJobActionNotRestricted() throws Exception {
|
||||
public void testMachineLearningDeleteJobActionNotRestricted() {
|
||||
String jobId = "testmachinelearningclosejobactionnotrestricted";
|
||||
assertMLAllowed(true);
|
||||
// test that license restricted apis do now work
|
||||
@ -522,7 +531,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
||||
}
|
||||
}
|
||||
|
||||
public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Exception {
|
||||
public void testMachineLearningDeleteDatafeedActionNotRestricted() {
|
||||
String jobId = "testmachinelearningdeletedatafeedactionnotrestricted";
|
||||
String datafeedId = jobId + "-datafeed";
|
||||
assertMLAllowed(true);
|
||||
@ -554,7 +563,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
||||
}
|
||||
}
|
||||
|
||||
public void testMachineLearningCreateInferenceProcessorRestricted() throws Exception {
|
||||
public void testMachineLearningCreateInferenceProcessorRestricted() {
|
||||
String modelId = "modelprocessorlicensetest";
|
||||
assertMLAllowed(true);
|
||||
putInferenceModel(modelId);
|
||||
@ -686,7 +695,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
||||
.actionGet();
|
||||
}
|
||||
|
||||
public void testMachineLearningInferModelRestricted() throws Exception {
|
||||
public void testMachineLearningInferModelRestricted() {
|
||||
String modelId = "modelinfermodellicensetest";
|
||||
assertMLAllowed(true);
|
||||
putInferenceModel(modelId);
|
||||
@ -748,6 +757,58 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
||||
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
|
||||
}
|
||||
|
||||
public void testInferenceAggRestricted() {
|
||||
String modelId = "inference-agg-restricted";
|
||||
assertMLAllowed(true);
|
||||
putInferenceModel(modelId);
|
||||
|
||||
// index some data
|
||||
String index = "inference-agg-licence-test";
|
||||
client().admin().indices().prepareCreate(index)
|
||||
.addMapping("type", "feature1", "type=double", "feature2", "type=keyword").get();
|
||||
client().prepareBulk()
|
||||
.add(new IndexRequest(index).source("feature1", "10.0", "feature2", "foo"))
|
||||
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "foo"))
|
||||
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "bar"))
|
||||
.add(new IndexRequest(index).source("feature1", "20.0", "feature2", "bar"))
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get();
|
||||
|
||||
TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("foobar").field("feature2");
|
||||
AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_feature1").field("feature1");
|
||||
termsAgg.subAggregation(avgAgg);
|
||||
|
||||
XPackLicenseState licenseState = internalCluster().getInstance(XPackLicenseState.class);
|
||||
ModelLoadingService modelLoading = internalCluster().getInstance(ModelLoadingService.class);
|
||||
|
||||
Map<String, String> bucketPaths = new HashMap<>();
|
||||
bucketPaths.put("feature1", "avg_feature1");
|
||||
InferencePipelineAggregationBuilder inferenceAgg =
|
||||
new InferencePipelineAggregationBuilder("infer_agg", new SetOnce<>(modelLoading), licenseState, bucketPaths);
|
||||
inferenceAgg.setModelId(modelId);
|
||||
|
||||
termsAgg.subAggregation(inferenceAgg);
|
||||
|
||||
SearchRequest search = new SearchRequest(index);
|
||||
search.source().aggregation(termsAgg);
|
||||
client().search(search).actionGet();
|
||||
|
||||
// Pick a license that does not allow machine learning
|
||||
License.OperationMode mode = randomInvalidLicenseType();
|
||||
enableLicensing(mode);
|
||||
assertMLAllowed(false);
|
||||
|
||||
// inferring against a model should now fail
|
||||
SearchRequest invalidSearch = new SearchRequest(index);
|
||||
invalidSearch.source().aggregation(termsAgg);
|
||||
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
|
||||
() -> client().search(invalidSearch).actionGet());
|
||||
|
||||
assertThat(e.status(), is(RestStatus.FORBIDDEN));
|
||||
assertThat(e.getMessage(), containsString("current license is non-compliant for [ml]"));
|
||||
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
|
||||
}
|
||||
|
||||
private void putInferenceModel(String modelId) {
|
||||
TrainedModelConfig config = TrainedModelConfig.builder()
|
||||
.setParsedDefinition(
|
||||
@ -755,13 +816,13 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
||||
.setTrainedModel(
|
||||
Tree.builder()
|
||||
.setTargetType(TargetType.REGRESSION)
|
||||
.setFeatureNames(Arrays.asList("feature1"))
|
||||
.setFeatureNames(Collections.singletonList("feature1"))
|
||||
.setNodes(TreeNode.builder(0).setLeafValue(1.0))
|
||||
.build())
|
||||
.setPreProcessors(Collections.emptyList()))
|
||||
.setModelId(modelId)
|
||||
.setDescription("test model for classification")
|
||||
.setInput(new TrainedModelInput(Arrays.asList("feature1")))
|
||||
.setInput(new TrainedModelInput(Collections.singletonList("feature1")))
|
||||
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
|
||||
.build();
|
||||
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
|
||||
|
@ -980,10 +980,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
|
||||
@Override
|
||||
public List<PipelineAggregationSpec> getPipelineAggregations() {
|
||||
PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
|
||||
in -> new InferencePipelineAggregationBuilder(in, modelLoadingService),
|
||||
in -> new InferencePipelineAggregationBuilder(in, getLicenseState(), modelLoadingService),
|
||||
(ContextParser<String, ? extends PipelineAggregationBuilder>)
|
||||
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, name, parser
|
||||
));
|
||||
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, getLicenseState(), name, parser));
|
||||
spec.addResultReader(InternalInferenceAggregation::new);
|
||||
|
||||
return Collections.singletonList(spec);
|
||||
|
@ -10,15 +10,17 @@ import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryRewriteContext;
|
||||
import org.elasticsearch.license.LicenseUtils;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
|
||||
@ -44,10 +46,10 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
||||
static final String AGGREGATIONS_RESULTS_FIELD = "value";
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder,
|
||||
Tuple<SetOnce<ModelLoadingService>, String>> PARSER = new ConstructingObjectParser<>(
|
||||
NAME, false,
|
||||
(args, context) -> new InferencePipelineAggregationBuilder(context.v2(), context.v1(), (Map<String, String>) args[0])
|
||||
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, ParserSupplement> PARSER =
|
||||
new ConstructingObjectParser<>(NAME, false,
|
||||
(args, context) -> new InferencePipelineAggregationBuilder(context.name, context.modelLoadingService,
|
||||
context.licenseState, (Map<String, String>) args[0])
|
||||
);
|
||||
|
||||
static {
|
||||
@ -60,34 +62,52 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
||||
private final Map<String, String> bucketPathMap;
|
||||
private String modelId;
|
||||
private InferenceConfigUpdate inferenceConfig;
|
||||
private final XPackLicenseState licenseState;
|
||||
private final SetOnce<ModelLoadingService> modelLoadingService;
|
||||
/**
|
||||
* The model. Set to a non-null value during the rewrite phase.
|
||||
*/
|
||||
private final Supplier<LocalModel> model;
|
||||
|
||||
private static class ParserSupplement {
|
||||
final XPackLicenseState licenseState;
|
||||
final SetOnce<ModelLoadingService> modelLoadingService;
|
||||
final String name;
|
||||
|
||||
ParserSupplement(String name, XPackLicenseState licenseState, SetOnce<ModelLoadingService> modelLoadingService) {
|
||||
this.name = name;
|
||||
this.licenseState = licenseState;
|
||||
this.modelLoadingService = modelLoadingService;
|
||||
}
|
||||
}
|
||||
public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService,
|
||||
XPackLicenseState licenseState,
|
||||
String pipelineAggregatorName,
|
||||
XContentParser parser) {
|
||||
Tuple<SetOnce<ModelLoadingService>, String> context = new Tuple<>(modelLoadingService, pipelineAggregatorName);
|
||||
return PARSER.apply(parser, context);
|
||||
return PARSER.apply(parser, new ParserSupplement(pipelineAggregatorName, licenseState, modelLoadingService));
|
||||
}
|
||||
|
||||
public InferencePipelineAggregationBuilder(String name, SetOnce<ModelLoadingService> modelLoadingService,
|
||||
public InferencePipelineAggregationBuilder(String name,
|
||||
SetOnce<ModelLoadingService> modelLoadingService,
|
||||
XPackLicenseState licenseState,
|
||||
Map<String, String> bucketsPath) {
|
||||
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
|
||||
this.modelLoadingService = modelLoadingService;
|
||||
this.bucketPathMap = bucketsPath;
|
||||
this.model = null;
|
||||
this.licenseState = licenseState;
|
||||
}
|
||||
|
||||
public InferencePipelineAggregationBuilder(StreamInput in, SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
|
||||
public InferencePipelineAggregationBuilder(StreamInput in,
|
||||
XPackLicenseState licenseState,
|
||||
SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
|
||||
super(in, NAME);
|
||||
modelId = in.readString();
|
||||
bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
|
||||
inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
|
||||
this.modelLoadingService = modelLoadingService;
|
||||
this.model = null;
|
||||
this.licenseState = licenseState;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -98,7 +118,8 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
||||
Map<String, String> bucketsPath,
|
||||
Supplier<LocalModel> model,
|
||||
String modelId,
|
||||
InferenceConfigUpdate inferenceConfig
|
||||
InferenceConfigUpdate inferenceConfig,
|
||||
XPackLicenseState licenseState
|
||||
) {
|
||||
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
|
||||
modelLoadingService = null;
|
||||
@ -113,13 +134,14 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
||||
*/
|
||||
this.modelId = modelId;
|
||||
this.inferenceConfig = inferenceConfig;
|
||||
this.licenseState = licenseState;
|
||||
}
|
||||
|
||||
void setModelId(String modelId) {
|
||||
public void setModelId(String modelId) {
|
||||
this.modelId = modelId;
|
||||
}
|
||||
|
||||
void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
|
||||
public void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
|
||||
this.inferenceConfig = inferenceConfig;
|
||||
}
|
||||
|
||||
@ -160,7 +182,7 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
||||
}
|
||||
|
||||
@Override
|
||||
public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) throws IOException {
|
||||
public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) {
|
||||
if (model != null) {
|
||||
return this;
|
||||
}
|
||||
@ -168,10 +190,17 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
|
||||
context.registerAsyncAction((client, listener) -> {
|
||||
modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
|
||||
loadedModel.set(model);
|
||||
delegate.onResponse(null);
|
||||
|
||||
boolean isLicensed = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) ||
|
||||
licenseState.isAllowedByLicense(model.getLicenseLevel());
|
||||
if (isLicensed) {
|
||||
delegate.onResponse(null);
|
||||
} else {
|
||||
delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
|
||||
}
|
||||
}));
|
||||
});
|
||||
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig);
|
||||
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -6,6 +6,7 @@
|
||||
package org.elasticsearch.xpack.ml.inference.loadingservice;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
@ -38,6 +39,7 @@ public class LocalModel {
|
||||
private volatile long persistenceQuotient = 100;
|
||||
private final LongAdder currentInferenceCount;
|
||||
private final InferenceConfig inferenceConfig;
|
||||
private final License.OperationMode licenseLevel;
|
||||
|
||||
public LocalModel(String modelId,
|
||||
String nodeId,
|
||||
@ -45,6 +47,7 @@ public class LocalModel {
|
||||
TrainedModelInput input,
|
||||
Map<String, String> defaultFieldMap,
|
||||
InferenceConfig modelInferenceConfig,
|
||||
License.OperationMode licenseLevel,
|
||||
TrainedModelStatsService trainedModelStatsService) {
|
||||
this.trainedModelDefinition = trainedModelDefinition;
|
||||
this.modelId = modelId;
|
||||
@ -56,6 +59,7 @@ public class LocalModel {
|
||||
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
|
||||
this.currentInferenceCount = new LongAdder();
|
||||
this.inferenceConfig = modelInferenceConfig;
|
||||
this.licenseLevel = licenseLevel;
|
||||
}
|
||||
|
||||
long ramBytesUsed() {
|
||||
@ -66,6 +70,10 @@ public class LocalModel {
|
||||
return modelId;
|
||||
}
|
||||
|
||||
public License.OperationMode getLicenseLevel() {
|
||||
return licenseLevel;
|
||||
}
|
||||
|
||||
public InferenceStats getLatestStatsAndReset() {
|
||||
return statsAccumulator.currentStatsAndReset();
|
||||
}
|
||||
|
@ -309,6 +309,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
||||
trainedModelConfig.getInput(),
|
||||
trainedModelConfig.getDefaultFieldMap(),
|
||||
inferenceConfig,
|
||||
trainedModelConfig.getLicenseLevel(),
|
||||
modelStatsService));
|
||||
},
|
||||
// Failure getting the definition, remove the initial estimation value
|
||||
@ -337,6 +338,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
||||
trainedModelConfig.getInput(),
|
||||
trainedModelConfig.getDefaultFieldMap(),
|
||||
inferenceConfig,
|
||||
trainedModelConfig.getLicenseLevel(),
|
||||
modelStatsService);
|
||||
synchronized (loadingListeners) {
|
||||
listeners = loadingListeners.remove(modelId);
|
||||
|
@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.plugins.SearchPlugin;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
|
||||
@ -61,7 +62,8 @@ public class InferencePipelineAggregationBuilderTests extends BasePipelineAggreg
|
||||
.collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5)));
|
||||
|
||||
InferencePipelineAggregationBuilder builder =
|
||||
new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)), bucketPaths);
|
||||
new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)),
|
||||
mock(XPackLicenseState.class), bucketPaths);
|
||||
builder.setModelId(randomAlphaOfLength(6));
|
||||
|
||||
if (randomBoolean()) {
|
||||
|
@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
|
||||
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
@ -73,6 +74,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
new TrainedModelInput(inputFields),
|
||||
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
randomFrom(License.OperationMode.values()),
|
||||
modelStatsService);
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
put("field.foo", 1.0);
|
||||
@ -102,6 +104,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
new TrainedModelInput(inputFields),
|
||||
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService);
|
||||
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
|
||||
assertThat(result.value(), equalTo(0.0));
|
||||
@ -144,6 +147,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
new TrainedModelInput(inputFields),
|
||||
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService);
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
put("field.foo", 1.0);
|
||||
@ -199,6 +203,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
new TrainedModelInput(inputFields),
|
||||
Collections.singletonMap("bar", "bar.keyword"),
|
||||
RegressionConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService);
|
||||
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
@ -226,6 +231,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
new TrainedModelInput(inputFields),
|
||||
null,
|
||||
RegressionConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService);
|
||||
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
@ -256,6 +262,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
new TrainedModelInput(inputFields),
|
||||
null,
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService
|
||||
);
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
|
Loading…
x
Reference in New Issue
Block a user