[ML][Inference] allowing per-model licensing (#49398) (#49435)

* [ML][Inference] allowing per-model licensing

* changing to internal action + removing pre-mature opt
This commit is contained in:
Benjamin Trent 2019-11-21 09:46:34 -05:00 committed by GitHub
parent f264808a6a
commit d41b2e3f38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 516 additions and 212 deletions

View File

@ -50,6 +50,7 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
@ -71,6 +72,7 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
}
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
@ -88,6 +90,7 @@ public class TrainedModelConfig implements ToXContentObject {
private final TrainedModelInput input;
private final Long estimatedHeapMemory;
private final Long estimatedOperations;
private final String licenseLevel;
TrainedModelConfig(String modelId,
String createdBy,
@ -99,7 +102,8 @@ public class TrainedModelConfig implements ToXContentObject {
Map<String, Object> metadata,
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations) {
Long estimatedOperations,
String licenseLevel) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
@ -111,6 +115,7 @@ public class TrainedModelConfig implements ToXContentObject {
this.input = input;
this.estimatedHeapMemory = estimatedHeapMemory;
this.estimatedOperations = estimatedOperations;
this.licenseLevel = licenseLevel;
}
public String getModelId() {
@ -161,6 +166,10 @@ public class TrainedModelConfig implements ToXContentObject {
return estimatedOperations;
}
public String getLicenseLevel() {
return licenseLevel;
}
public static Builder builder() {
return new Builder();
}
@ -201,6 +210,9 @@ public class TrainedModelConfig implements ToXContentObject {
if (estimatedOperations != null) {
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
}
if (licenseLevel != null) {
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel);
}
builder.endObject();
return builder;
}
@ -225,6 +237,7 @@ public class TrainedModelConfig implements ToXContentObject {
Objects.equals(input, that.input) &&
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(metadata, that.metadata);
}
@ -240,6 +253,7 @@ public class TrainedModelConfig implements ToXContentObject {
estimatedHeapMemory,
estimatedOperations,
metadata,
licenseLevel,
input);
}
@ -257,6 +271,7 @@ public class TrainedModelConfig implements ToXContentObject {
private TrainedModelInput input;
private Long estimatedHeapMemory;
private Long estimatedOperations;
private String licenseLevel;
public Builder setModelId(String modelId) {
this.modelId = modelId;
@ -312,16 +327,21 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}
public Builder setEstimatedOperations(Long estimatedOperations) {
private Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}
private Builder setLicenseLevel(String licenseLevel) {
this.licenseLevel = licenseLevel;
return this;
}
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
@ -334,7 +354,8 @@ public class TrainedModelConfig implements ToXContentObject {
metadata,
input,
estimatedHeapMemory,
estimatedOperations);
estimatedOperations,
licenseLevel);
}
}

View File

@ -66,7 +66,8 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomNonNegativeLong());
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomFrom("platinum", "basic"));
}

View File

@ -112,7 +112,7 @@ import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
@ -175,13 +175,6 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage;
import org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage;
@ -389,7 +382,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
StopDataFrameAnalyticsAction.INSTANCE,
EvaluateDataFrameAction.INSTANCE,
EstimateMemoryUsageAction.INSTANCE,
InferModelAction.INSTANCE,
InternalInferModelAction.INSTANCE,
GetTrainedModelsAction.INSTANCE,
DeleteTrainedModelAction.INSTANCE,
GetTrainedModelsStatsAction.INSTANCE,

View File

@ -24,12 +24,12 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
public class InferModelAction extends ActionType<InferModelAction.Response> {
public class InternalInferModelAction extends ActionType<InternalInferModelAction.Response> {
public static final InferModelAction INSTANCE = new InferModelAction();
public static final String NAME = "cluster:admin/xpack/ml/inference/infer";
public static final InternalInferModelAction INSTANCE = new InternalInferModelAction();
public static final String NAME = "cluster:internal/xpack/ml/inference/infer";
private InferModelAction() {
private InternalInferModelAction() {
super(NAME, Response::new);
}
@ -38,21 +38,27 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
private final String modelId;
private final List<Map<String, Object>> objectsToInfer;
private final InferenceConfig config;
private final boolean previouslyLicensed;
public Request(String modelId) {
this(modelId, Collections.emptyList(), new RegressionConfig());
public Request(String modelId, boolean previouslyLicensed) {
this(modelId, Collections.emptyList(), new RegressionConfig(), previouslyLicensed);
}
public Request(String modelId, List<Map<String, Object>> objectsToInfer, InferenceConfig inferenceConfig) {
public Request(String modelId,
List<Map<String, Object>> objectsToInfer,
InferenceConfig inferenceConfig,
boolean previouslyLicensed) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer"));
this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
this.previouslyLicensed = previouslyLicensed;
}
public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfig config) {
public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfig config, boolean previouslyLicensed) {
this(modelId,
Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")),
config);
config,
previouslyLicensed);
}
public Request(StreamInput in) throws IOException {
@ -60,6 +66,7 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
this.modelId = in.readString();
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
this.config = in.readNamedWriteable(InferenceConfig.class);
this.previouslyLicensed = in.readBoolean();
}
public String getModelId() {
@ -74,6 +81,10 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
return config;
}
public boolean isPreviouslyLicensed() {
return previouslyLicensed;
}
@Override
public ActionRequestValidationException validate() {
return null;
@ -85,21 +96,23 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
out.writeString(modelId);
out.writeCollection(objectsToInfer, StreamOutput::writeMap);
out.writeNamedWriteable(config);
out.writeBoolean(previouslyLicensed);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferModelAction.Request that = (InferModelAction.Request) o;
InternalInferModelAction.Request that = (InternalInferModelAction.Request) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(config, that.config)
&& Objects.equals(previouslyLicensed, that.previouslyLicensed)
&& Objects.equals(objectsToInfer, that.objectsToInfer);
}
@Override
public int hashCode() {
return Objects.hash(modelId, objectsToInfer, config);
return Objects.hash(modelId, objectsToInfer, config, previouslyLicensed);
}
}
@ -107,37 +120,68 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
public static class Response extends ActionResponse {
private final List<InferenceResults> inferenceResults;
private final boolean isLicensed;
public Response(List<InferenceResults> inferenceResults) {
public Response(List<InferenceResults> inferenceResults, boolean isLicensed) {
super();
this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
this.isLicensed = isLicensed;
}
public Response(StreamInput in) throws IOException {
super(in);
this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class));
this.isLicensed = in.readBoolean();
}
public List<InferenceResults> getInferenceResults() {
return inferenceResults;
}
public boolean isLicensed() {
return isLicensed;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteableList(inferenceResults);
out.writeBoolean(isLicensed);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferModelAction.Response that = (InferModelAction.Response) o;
return Objects.equals(inferenceResults, that.inferenceResults);
InternalInferModelAction.Response that = (InternalInferModelAction.Response) o;
return isLicensed == that.isLicensed && Objects.equals(inferenceResults, that.inferenceResults);
}
@Override
public int hashCode() {
return Objects.hash(inferenceResults);
return Objects.hash(inferenceResults, isLicensed);
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private List<InferenceResults> inferenceResults;
private boolean isLicensed;
public Builder setInferenceResults(List<InferenceResults> inferenceResults) {
this.inferenceResults = inferenceResults;
return this;
}
public Builder setLicensed(boolean licensed) {
isLicensed = licensed;
return this;
}
public Response build() {
return new Response(inferenceResults, isLicensed);
}
}
}

View File

@ -17,6 +17,8 @@ import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
@ -48,6 +50,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
@ -73,6 +76,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
INPUT);
parser.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
parser.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
return parser;
}
@ -90,6 +94,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private final TrainedModelInput input;
private final long estimatedHeapMemory;
private final long estimatedOperations;
private final License.OperationMode licenseLevel;
private final TrainedModelDefinition definition;
@ -103,7 +108,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Map<String, Object> metadata,
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations) {
Long estimatedOperations,
String licenseLevel) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
@ -122,6 +128,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0");
}
this.estimatedOperations = estimatedOperations;
this.licenseLevel = License.OperationMode.resolve(ExceptionsHelper.requireNonNull(licenseLevel, LICENSE_LEVEL));
}
public TrainedModelConfig(StreamInput in) throws IOException {
@ -136,6 +143,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
input = new TrainedModelInput(in);
estimatedHeapMemory = in.readVLong();
estimatedOperations = in.readVLong();
licenseLevel = License.OperationMode.resolve(in.readString());
}
public String getModelId() {
@ -187,6 +195,25 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return estimatedOperations;
}
public License.OperationMode getLicenseLevel() {
return licenseLevel;
}
public boolean isAvailableWithLicense(XPackLicenseState licenseState) {
// Basic is always true
if (licenseLevel.equals(License.OperationMode.BASIC)) {
return true;
}
// The model license does not matter, this is the highest licensed level
if (licenseState.isActive() && XPackLicenseState.isPlatinumOrTrialOperationMode(licenseState.getOperationMode())) {
return true;
}
// catch the rest, if the license is active and is at least the required model license
return licenseState.isActive() && License.OperationMode.compare(licenseState.getOperationMode(), licenseLevel) >= 0;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
@ -200,6 +227,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
input.writeTo(out);
out.writeVLong(estimatedHeapMemory);
out.writeVLong(estimatedOperations);
out.writeString(licenseLevel.description());
}
@Override
@ -229,6 +257,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
new ByteSizeValue(estimatedHeapMemory));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
builder.endObject();
return builder;
}
@ -253,6 +282,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Objects.equals(input, that.input) &&
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(metadata, that.metadata);
}
@ -268,7 +298,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
metadata,
estimatedHeapMemory,
estimatedOperations,
input);
input,
licenseLevel);
}
public static class Builder {
@ -284,6 +315,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private TrainedModelDefinition definition;
private Long estimatedHeapMemory;
private Long estimatedOperations;
private String licenseLevel = License.OperationMode.PLATINUM.description();
public Builder setModelId(String modelId) {
this.modelId = modelId;
@ -349,6 +381,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
public Builder setLicenseLevel(String licenseLevel) {
this.licenseLevel = licenseLevel;
return this;
}
// TODO move to REST level instead of here in the builder
public void validate() {
// We require a definition to be available here even though it will be stored in a different doc
@ -366,28 +403,17 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
MlStrings.ID_LENGTH_LIMIT));
}
if (version != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", VERSION.getPreferredName());
}
checkIllegalSetting(version, VERSION.getPreferredName());
checkIllegalSetting(createdBy, CREATED_BY.getPreferredName());
checkIllegalSetting(createTime, CREATE_TIME.getPreferredName());
checkIllegalSetting(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
checkIllegalSetting(estimatedOperations, ESTIMATED_OPERATIONS.getPreferredName());
checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName());
}
if (createdBy != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
CREATED_BY.getPreferredName());
}
if (createTime != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
CREATE_TIME.getPreferredName());
}
if (estimatedHeapMemory != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
}
if (estimatedOperations != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
ESTIMATED_OPERATIONS.getPreferredName());
private static void checkIllegalSetting(Object value, String setting) {
if (value != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", setting);
}
}
@ -403,7 +429,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
metadata,
input,
estimatedHeapMemory,
estimatedOperations);
estimatedOperations,
licenseLevel);
}
}

View File

@ -8,7 +8,7 @@ package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@ -22,19 +22,21 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
public class InferModelActionRequestTests extends AbstractWireSerializingTestCase<Request> {
public class InternalInferModelActionRequestTests extends AbstractWireSerializingTestCase<Request> {
@Override
protected Request createTestInstance() {
return randomBoolean() ?
new Request(
randomAlphaOfLength(10),
Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()),
randomInferenceConfig()) :
Stream.generate(InternalInferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()),
randomInferenceConfig(),
randomBoolean()) :
new Request(
randomAlphaOfLength(10),
randomMap(),
randomInferenceConfig());
randomInferenceConfig(),
randomBoolean());
}
private static InferenceConfig randomInferenceConfig() {

View File

@ -8,7 +8,7 @@ package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
@ -16,12 +16,10 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class InferModelActionResponseTests extends AbstractWireSerializingTestCase<Response> {
public class InternalInferModelActionResponseTests extends AbstractWireSerializingTestCase<Response> {
@Override
protected Response createTestInstance() {
@ -29,7 +27,8 @@ public class InferModelActionResponseTests extends AbstractWireSerializingTestCa
return new Response(
Stream.generate(() -> randomInferenceResult(resultType))
.limit(randomIntBetween(0, 10))
.collect(Collectors.toList()));
.collect(Collectors.toList()),
randomBoolean());
}
private static InferenceResults randomInferenceResult(String resultType) {
@ -50,9 +49,7 @@ public class InferModelActionResponseTests extends AbstractWireSerializingTestCa
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
}
}

View File

@ -16,6 +16,8 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
@ -37,11 +39,29 @@ import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNA
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TrainedModelConfigTests extends AbstractSerializingTestCase<TrainedModelConfig> {
private boolean lenient;
public static TrainedModelConfig.Builder createTestInstance(String modelId) {
List<String> tags = Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false));
return TrainedModelConfig.builder()
.setInput(TrainedModelInputTests.createRandomInput())
.setMetadata(randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)))
.setCreateTime(Instant.ofEpochMilli(randomNonNegativeLong()))
.setVersion(Version.CURRENT)
.setModelId(modelId)
.setCreatedBy(randomAlphaOfLength(10))
.setDescription(randomBoolean() ? null : randomAlphaOfLength(100))
.setEstimatedHeapMemory(randomNonNegativeLong())
.setEstimatedOperations(randomNonNegativeLong())
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setTags(tags);
}
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
@ -64,19 +84,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
@Override
protected TrainedModelConfig createTestInstance() {
List<String> tags = Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false));
return new TrainedModelConfig(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()),
null, // is not parsed so should not be provided
tags,
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
TrainedModelInputTests.createRandomInput(),
randomNonNegativeLong(),
randomNonNegativeLong());
return createTestInstance(randomAlphaOfLength(10)).build();
}
@Override
@ -121,7 +129,8 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
TrainedModelInputTests.createRandomInput(),
randomNonNegativeLong(),
randomNonNegativeLong());
randomNonNegativeLong(),
"platinum");
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
assertThat(reference.utf8ToString(), containsString("definition"));
@ -179,4 +188,39 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
}
public void testIsAvailableWithLicense() {
TrainedModelConfig.Builder builder = createTestInstance(randomAlphaOfLength(10));
XPackLicenseState licenseState = mock(XPackLicenseState.class);
when(licenseState.isActive()).thenReturn(false);
when(licenseState.getOperationMode()).thenReturn(License.OperationMode.BASIC);
assertFalse(builder.setLicenseLevel(License.OperationMode.PLATINUM.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.BASIC.description()).build().isAvailableWithLicense(licenseState));
when(licenseState.isActive()).thenReturn(true);
when(licenseState.getOperationMode()).thenReturn(License.OperationMode.PLATINUM);
assertTrue(builder.setLicenseLevel(License.OperationMode.PLATINUM.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.BASIC.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.GOLD.description()).build().isAvailableWithLicense(licenseState));
when(licenseState.isActive()).thenReturn(false);
assertFalse(builder.setLicenseLevel(License.OperationMode.PLATINUM.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.BASIC.description()).build().isAvailableWithLicense(licenseState));
assertFalse(builder.setLicenseLevel(License.OperationMode.GOLD.description()).build().isAvailableWithLicense(licenseState));
when(licenseState.isActive()).thenReturn(true);
when(licenseState.getOperationMode()).thenReturn(License.OperationMode.GOLD);
assertFalse(builder.setLicenseLevel(License.OperationMode.PLATINUM.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.BASIC.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.GOLD.description()).build().isAvailableWithLicense(licenseState));
when(licenseState.isActive()).thenReturn(false);
assertFalse(builder.setLicenseLevel(License.OperationMode.PLATINUM.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.BASIC.description()).build().isAvailableWithLicense(licenseState));
assertFalse(builder.setLicenseLevel(License.OperationMode.GOLD.description()).build().isAvailableWithLicense(licenseState));
}
}

View File

@ -370,6 +370,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for regression\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"ml_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
@ -503,6 +504,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"benwtrent\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +

View File

@ -18,6 +18,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.license.License;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@ -193,6 +194,7 @@ public class TrainedModelIT extends ESRestTestCase {
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
.setEstimatedOperations(0)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setEstimatedHeapMemory(0)
.build()
.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));

View File

@ -99,7 +99,7 @@ import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
@ -168,7 +168,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction;
import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction;
import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction;
import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsStatsAction;
import org.elasticsearch.xpack.ml.action.TransportInferModelAction;
import org.elasticsearch.xpack.ml.action.TransportInternalInferModelAction;
import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsAction;
import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction;
import org.elasticsearch.xpack.ml.action.TransportKillProcessAction;
@ -346,9 +346,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client,
parameters.ingestService.getClusterService(),
this.settings,
parameters.ingestService,
getLicenseState());
getLicenseState().addListener(inferenceFactory);
parameters.ingestService);
parameters.ingestService.addIngestClusterStateListener(inferenceFactory);
return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory);
}
@ -831,7 +829,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class),
new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class),
new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class),
new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class),
new ActionHandler<>(InternalInferModelAction.INSTANCE, TransportInternalInferModelAction.class),
new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class),
new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class),
new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class)

View File

@ -16,38 +16,41 @@ import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
public class TransportInferModelAction extends HandledTransportAction<InferModelAction.Request, InferModelAction.Response> {
public class TransportInternalInferModelAction extends HandledTransportAction<Request, Response> {
private final ModelLoadingService modelLoadingService;
private final Client client;
private final XPackLicenseState licenseState;
private final TrainedModelProvider trainedModelProvider;
@Inject
public TransportInferModelAction(TransportService transportService,
ActionFilters actionFilters,
ModelLoadingService modelLoadingService,
Client client,
XPackLicenseState licenseState) {
super(InferModelAction.NAME, transportService, actionFilters, InferModelAction.Request::new);
public TransportInternalInferModelAction(TransportService transportService,
ActionFilters actionFilters,
ModelLoadingService modelLoadingService,
Client client,
XPackLicenseState licenseState,
TrainedModelProvider trainedModelProvider) {
super(InternalInferModelAction.NAME, transportService, actionFilters, InternalInferModelAction.Request::new);
this.modelLoadingService = modelLoadingService;
this.client = client;
this.licenseState = licenseState;
this.trainedModelProvider = trainedModelProvider;
}
@Override
protected void doExecute(Task task, InferModelAction.Request request, ActionListener<InferModelAction.Response> listener) {
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
if (licenseState.isMachineLearningAllowed() == false) {
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
return;
}
Response.Builder responseBuilder = Response.builder();
ActionListener<Model> getModelListener = ActionListener.wrap(
model -> {
@ -63,13 +66,28 @@ public class TransportInferModelAction extends HandledTransportAction<InferModel
typedChainTaskExecutor.execute(ActionListener.wrap(
inferenceResultsInterfaces ->
listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces)),
listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build()),
listener::onFailure
));
},
listener::onFailure
);
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
if (licenseState.isMachineLearningAllowed()) {
responseBuilder.setLicensed(true);
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
} else {
trainedModelProvider.getTrainedModel(request.getModelId(), false, ActionListener.wrap(
trainedModelConfig -> {
responseBuilder.setLicensed(trainedModelConfig.isAvailableWithLicense(licenseState));
if (trainedModelConfig.isAvailableWithLicense(licenseState) || request.isPreviouslyLicensed()) {
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
} else {
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
}
},
listener::onFailure
));
}
}
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
@ -157,6 +158,7 @@ public class AnalyticsResultProcessor {
.setEstimatedHeapMemory(definition.ramBytesUsed())
.setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
.setInput(new TrainedModelInput(fieldNames))
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.build();
}

View File

@ -26,22 +26,20 @@ import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.license.LicenseStateListener;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@ -73,8 +71,12 @@ public class InferenceProcessor extends AbstractProcessor {
private final InferenceConfig inferenceConfig;
private final Map<String, String> fieldMapping;
private final boolean includeModelMetadata;
private final InferenceAuditor auditor;
private volatile boolean previouslyLicensed;
private final AtomicBoolean shouldAudit = new AtomicBoolean(true);
public InferenceProcessor(Client client,
InferenceAuditor auditor,
String tag,
String targetField,
String modelId,
@ -85,6 +87,7 @@ public class InferenceProcessor extends AbstractProcessor {
super(tag);
this.client = ExceptionsHelper.requireNonNull(client, "client");
this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
this.auditor = ExceptionsHelper.requireNonNull(auditor, "auditor");
this.modelInfoField = ExceptionsHelper.requireNonNull(modelInfoField, MODEL_INFO_FIELD);
this.includeModelMetadata = includeModelMetadata;
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
@ -100,22 +103,32 @@ public class InferenceProcessor extends AbstractProcessor {
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
executeAsyncWithOrigin(client,
ML_ORIGIN,
InferModelAction.INSTANCE,
InternalInferModelAction.INSTANCE,
this.buildRequest(ingestDocument),
ActionListener.wrap(
r -> {
try {
mutateDocument(r, ingestDocument);
handler.accept(ingestDocument, null);
} catch(ElasticsearchException ex) {
handler.accept(ingestDocument, ex);
}
},
r -> handleResponse(r, ingestDocument, handler),
e -> handler.accept(ingestDocument, e)
));
}
InferModelAction.Request buildRequest(IngestDocument ingestDocument) {
void handleResponse(InternalInferModelAction.Response response,
IngestDocument ingestDocument,
BiConsumer<IngestDocument, Exception> handler) {
if (previouslyLicensed == false) {
previouslyLicensed = true;
}
if (response.isLicensed() == false) {
auditWarningAboutLicenseIfNecessary();
}
try {
mutateDocument(response, ingestDocument);
handler.accept(ingestDocument, null);
} catch(ElasticsearchException ex) {
handler.accept(ingestDocument, ex);
}
}
InternalInferModelAction.Request buildRequest(IngestDocument ingestDocument) {
Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
if (fieldMapping != null) {
fieldMapping.forEach((src, dest) -> {
@ -125,10 +138,19 @@ public class InferenceProcessor extends AbstractProcessor {
}
});
}
return new InferModelAction.Request(modelId, fields, inferenceConfig);
return new InternalInferModelAction.Request(modelId, fields, inferenceConfig, previouslyLicensed);
}
void mutateDocument(InferModelAction.Response response, IngestDocument ingestDocument) {
void auditWarningAboutLicenseIfNecessary() {
if (shouldAudit.compareAndSet(true, false)) {
auditor.warning(
modelId,
"This cluster is no longer licensed to use this model in the inference ingest processor. " +
"Please update your license information.");
}
}
void mutateDocument(InternalInferModelAction.Response response, IngestDocument ingestDocument) {
if (response.getInferenceResults().isEmpty()) {
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
}
@ -148,28 +170,25 @@ public class InferenceProcessor extends AbstractProcessor {
return TYPE;
}
public static final class Factory implements Processor.Factory, Consumer<ClusterState>, LicenseStateListener {
public static final class Factory implements Processor.Factory, Consumer<ClusterState> {
private static final Logger logger = LogManager.getLogger(Factory.class);
private final Client client;
private final IngestService ingestService;
private final XPackLicenseState licenseState;
private final InferenceAuditor auditor;
private volatile int currentInferenceProcessors;
private volatile int maxIngestProcessors;
private volatile Version minNodeVersion = Version.CURRENT;
private volatile boolean inferenceAllowed;
public Factory(Client client,
ClusterService clusterService,
Settings settings,
IngestService ingestService,
XPackLicenseState licenseState) {
IngestService ingestService) {
this.client = client;
this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings);
this.ingestService = ingestService;
this.licenseState = licenseState;
this.inferenceAllowed = licenseState.isMachineLearningAllowed();
this.auditor = new InferenceAuditor(client, clusterService.getNodeName());
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors);
}
@ -211,10 +230,6 @@ public class InferenceProcessor extends AbstractProcessor {
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config)
throws Exception {
if (inferenceAllowed == false) {
throw LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING);
}
if (this.maxIngestProcessors <= currentInferenceProcessors) {
throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " +
"Adjust the setting [{}]: [{}] if a greater number is desired.",
@ -236,6 +251,7 @@ public class InferenceProcessor extends AbstractProcessor {
modelInfoField += "." + tag;
}
return new InferenceProcessor(client,
auditor,
tag,
targetField,
modelId,
@ -289,9 +305,5 @@ public class InferenceProcessor extends AbstractProcessor {
}
}
@Override
public void licenseStateChanged() {
this.inferenceAllowed = licenseState.isMachineLearningAllowed();
}
}
}

View File

@ -8,6 +8,7 @@ package org.elasticsearch.license;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ingest.PutPipelineAction;
import org.elasticsearch.action.ingest.PutPipelineRequest;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineAction;
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
@ -32,7 +33,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
@ -56,8 +57,10 @@ import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
@ -565,7 +568,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
" }\n" +
" }\n" +
" }]}\n";
// test that license restricted apis do now work
// Creating a pipeline should work
PlainActionFuture<AcknowledgedResponse> putPipelineListener = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline",
@ -575,6 +578,12 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
AcknowledgedResponse putPipelineResponse = putPipelineListener.actionGet();
assertTrue(putPipelineResponse.isAcknowledged());
client().prepareIndex("infer_license_test", MapperService.SINGLE_MAPPING_NAME)
.setPipeline("test_infer_license_pipeline")
.setSource("{}", XContentType.JSON)
.execute()
.actionGet();
String simulateSource = "{\n" +
" \"pipeline\": \n" +
pipeline +
@ -594,37 +603,52 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
assertThat(simulatePipelineListener.actionGet().getResults(), is(not(empty())));
// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);
// creating a new pipeline should fail
// Inference against the previous pipeline should still work
try {
client().prepareIndex("infer_license_test", MapperService.SINGLE_MAPPING_NAME)
.setPipeline("test_infer_license_pipeline")
.setSource("{}", XContentType.JSON)
.execute()
.actionGet();
} catch (ElasticsearchSecurityException ex) {
fail(ex.getMessage());
}
// Creating a new pipeline with an inference processor should work
putPipelineListener = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline_again",
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON),
putPipelineListener);
putPipelineResponse = putPipelineListener.actionGet();
assertTrue(putPipelineResponse.isAcknowledged());
// Inference against the new pipeline should fail since it has never previously succeeded
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<AcknowledgedResponse> listener = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline_failure",
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON),
listener);
listener.actionGet();
client().prepareIndex("infer_license_test", MapperService.SINGLE_MAPPING_NAME)
.setPipeline("test_infer_license_pipeline_again")
.setSource("{}", XContentType.JSON)
.execute()
.actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
// Simulating the pipeline should fail
e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<SimulatePipelineResponse> listener = PlainActionFuture.newFuture();
client().execute(SimulatePipelineAction.INSTANCE,
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
listener);
listener.actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
SimulateDocumentBaseResult simulateResponse = (SimulateDocumentBaseResult)client().execute(SimulatePipelineAction.INSTANCE,
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON))
.actionGet()
.getResults()
.get(0);
assertThat(simulateResponse.getFailure(), is(not(nullValue())));
assertThat((simulateResponse.getFailure()).getCause(), is(instanceOf(ElasticsearchSecurityException.class)));
// Pick a license that does allow machine learning
mode = randomValidLicenseType();
@ -646,21 +670,37 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
simulatePipelineListenerNewLicense);
assertThat(simulatePipelineListenerNewLicense.actionGet().getResults(), is(not(empty())));
//both ingest pipelines should work
client().prepareIndex("infer_license_test", MapperService.SINGLE_MAPPING_NAME)
.setPipeline("test_infer_license_pipeline")
.setSource("{}", XContentType.JSON)
.execute()
.actionGet();
client().prepareIndex("infer_license_test", MapperService.SINGLE_MAPPING_NAME)
.setPipeline("test_infer_license_pipeline_again")
.setSource("{}", XContentType.JSON)
.execute()
.actionGet();
}
public void testMachineLearningInferModelRestricted() throws Exception {
public void testMachineLearningInferModelRestricted() {
String modelId = "modelinfermodellicensetest";
assertMLAllowed(true);
putInferenceModel(modelId);
PlainActionFuture<InferModelAction.Response> inferModelSuccess = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
PlainActionFuture<InternalInferModelAction.Response> inferModelSuccess = PlainActionFuture.newFuture();
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
new RegressionConfig(),
false
), inferModelSuccess);
assertThat(inferModelSuccess.actionGet().getInferenceResults(), is(not(empty())));
InternalInferModelAction.Response response = inferModelSuccess.actionGet();
assertThat(response.getInferenceResults(), is(not(empty())));
assertThat(response.isLicensed(), is(true));
// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
@ -669,28 +709,40 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
// inferring against a model should now fail
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
), listener);
listener.actionGet();
new RegressionConfig(),
false
)).actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
// Inferring with previously Licensed == true should pass, but indicate license issues
inferModelSuccess = PlainActionFuture.newFuture();
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig(),
true
), inferModelSuccess);
response = inferModelSuccess.actionGet();
assertThat(response.getInferenceResults(), is(not(empty())));
assertThat(response.isLicensed(), is(false));
// Pick a license that does allow machine learning
mode = randomValidLicenseType();
enableLicensing(mode);
assertMLAllowed(true);
PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
PlainActionFuture<InternalInferModelAction.Response> listener = PlainActionFuture.newFuture();
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
new RegressionConfig(),
false
), listener);
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
}
@ -703,6 +755,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
" \"description\": \"test model for classification\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"created_by\": \"benwtrent\",\n" +
" \"license_level\": \"platinum\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0,\n" +
" \"estimated_operations\": 0,\n" +
" \"created_time\": 0\n" +

View File

@ -12,7 +12,11 @@ import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.routing.allocation.decider.AwarenessAllocationDecider;
import org.elasticsearch.cluster.routing.OperationRouting;
import org.elasticsearch.cluster.service.ClusterApplierService;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.cluster.service.MasterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
@ -87,8 +91,7 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
new InferenceProcessor.Factory(parameters.client,
parameters.ingestService.getClusterService(),
Settings.EMPTY,
parameters.ingestService,
licenseState));
parameters.ingestService));
factoryMap.put("not_inference", new NotInferenceProcessor.Factory());
@ -105,9 +108,15 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
ThreadPool tp = mock(ThreadPool.class);
client = mock(Client.class);
clusterService = mock(ClusterService.class);
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY,
Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS));
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
Settings settings = Settings.builder().put("node.name", "InferenceProcessorFactoryTests_node").build();
ClusterSettings clusterSettings = new ClusterSettings(settings,
new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS,
MasterService.MASTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING,
OperationRouting.USE_ADAPTIVE_REPLICA_SELECTION_SETTING,
ClusterService.USER_DEFINED_META_DATA,
AwarenessAllocationDecider.CLUSTER_ROUTING_ALLOCATION_AWARENESS_ATTRIBUTE_SETTING,
ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING)));
clusterService = new ClusterService(settings, clusterSettings, tp);
ingestService = new IngestService(clusterService, tp, null, null,
null, Collections.singletonList(SKINNY_INGEST_PLUGIN), client);
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
@ -138,6 +139,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class));
TrainedModelConfig storedModel = storedModelCaptor.getValue();
assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM));
assertThat(storedModel.getModelId(), containsString(JOB_ID));
assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
assertThat(storedModel.getCreatedBy(), equalTo("data-frame-analytics"));

View File

@ -14,7 +14,11 @@ import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.routing.allocation.decider.AwarenessAllocationDecider;
import org.elasticsearch.cluster.routing.OperationRouting;
import org.elasticsearch.cluster.service.ClusterApplierService;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.cluster.service.MasterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
@ -36,8 +40,10 @@ import org.junit.Before;
import java.io.IOException;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
@ -55,8 +61,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
new InferenceProcessor.Factory(parameters.client,
parameters.ingestService.getClusterService(),
Settings.EMPTY,
parameters.ingestService,
licenseState));
parameters.ingestService));
}
};
private Client client;
@ -68,10 +73,15 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
public void setUpVariables() {
ThreadPool tp = mock(ThreadPool.class);
client = mock(Client.class);
clusterService = mock(ClusterService.class);
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY,
Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS));
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
Settings settings = Settings.builder().put("node.name", "InferenceProcessorFactoryTests_node").build();
ClusterSettings clusterSettings = new ClusterSettings(settings,
new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS,
MasterService.MASTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING,
OperationRouting.USE_ADAPTIVE_REPLICA_SELECTION_SETTING,
ClusterService.USER_DEFINED_META_DATA,
AwarenessAllocationDecider.CLUSTER_ROUTING_ALLOCATION_AWARENESS_ATTRIBUTE_SETTING,
ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING)));
clusterService = new ClusterService(settings, clusterSettings, tp);
ingestService = new IngestService(clusterService, tp, null, null,
null, Collections.singletonList(SKINNY_PLUGIN), client);
licenseState = mock(XPackLicenseState.class);
@ -84,8 +94,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
ingestService);
processorFactory.accept(buildClusterState(metaData));
assertThat(processorFactory.numInferenceProcessors(), equalTo(0));
@ -102,8 +111,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.builder().put(InferenceProcessor.MAX_INFERENCE_PROCESSORS.getKey(), 1).build(),
ingestService,
licenseState);
ingestService);
processorFactory.accept(buildClusterStateWithModelReferences("model1"));
@ -118,8 +126,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
ingestService);
Map<String, Object> config = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
@ -160,8 +167,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
ingestService);
processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1"));
Map<String, Object> regression = new HashMap<String, Object>() {{
@ -204,8 +210,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
ingestService);
Map<String, Object> regression = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());

View File

@ -8,11 +8,12 @@ package org.elasticsearch.xpack.ml.inference.ingest;
import org.elasticsearch.client.Client;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.junit.Before;
import java.util.ArrayList;
@ -23,21 +24,30 @@ import java.util.Map;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.core.Is.is;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class InferenceProcessorTests extends ESTestCase {
private Client client;
private InferenceAuditor auditor;
@Before
public void setUpVariables() {
client = mock(Client.class);
auditor = mock(InferenceAuditor.class);
}
public void testMutateDocumentWithClassification() {
String targetField = "classification_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"classification_model",
@ -50,8 +60,9 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", null)));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", null)),
true);
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
@ -63,6 +74,7 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentClassificationTopNClasses() {
String targetField = "classification_value_probabilities";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"classification_model",
@ -79,8 +91,9 @@ public class InferenceProcessorTests extends ESTestCase {
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes)));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes)),
true);
inferenceProcessor.mutateDocument(response, document);
assertThat((List<Map<?,?>>)document.getFieldValue(targetField, List.class),
@ -92,6 +105,7 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentRegression() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"regression_model",
@ -104,8 +118,8 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)), true);
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
@ -116,6 +130,7 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentNoModelMetaData() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"regression_model",
@ -128,8 +143,8 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)), true);
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
@ -139,6 +154,7 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentModelMetaDataExistingField() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"regression_model",
@ -157,8 +173,8 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)), true);
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
@ -174,6 +190,7 @@ public class InferenceProcessorTests extends ESTestCase {
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);
InferenceProcessor processor = new InferenceProcessor(client,
auditor,
"my_processor",
"my_field",
modelId,
@ -204,6 +221,7 @@ public class InferenceProcessorTests extends ESTestCase {
}};
InferenceProcessor processor = new InferenceProcessor(client,
auditor,
"my_processor",
"my_field",
modelId,
@ -227,4 +245,50 @@ public class InferenceProcessorTests extends ESTestCase {
}};
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
}
public void testHandleResponseLicenseChanged() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"regression_model",
new RegressionConfig(),
Collections.emptyMap(),
"ml.my_processor",
true);
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(false));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)), true);
inferenceProcessor.handleResponse(response, document, (doc, ex) -> {
assertThat(doc, is(not(nullValue())));
assertThat(ex, is(nullValue()));
});
assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(true));
response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)), false);
inferenceProcessor.handleResponse(response, document, (doc, ex) -> {
assertThat(doc, is(not(nullValue())));
assertThat(ex, is(nullValue()));
});
assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(true));
inferenceProcessor.handleResponse(response, document, (doc, ex) -> {
assertThat(doc, is(not(nullValue())));
assertThat(ex, is(nullValue()));
});
verify(auditor, times(1)).warning(eq("regression_model"), any(String.class));
}
}

View File

@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@ -22,7 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.junit.Before;
@ -68,6 +69,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
.setTrainedModel(buildClassification(true))
.setModelId(modelId1))
.setVersion(Version.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setCreateTime(Instant.now())
.setEstimatedOperations(0)
.setEstimatedHeapMemory(0)
@ -119,20 +121,20 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
}});
// Test regression
InferModelAction.Request request = new InferModelAction.Request(modelId1, toInfer, new RegressionConfig());
InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet();
InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId1, toInfer, new RegressionConfig(), true);
InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
contains(1.3, 1.25));
request = new InferModelAction.Request(modelId1, toInfer2, new RegressionConfig());
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
request = new InternalInferModelAction.Request(modelId1, toInfer2, new RegressionConfig(), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
contains(1.65, 1.55));
// Test classification
request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(0));
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(0), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults()
.stream()
.map(i -> ((SingleValueInferenceResults)i).valueAsString())
@ -140,8 +142,8 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
contains("not_to_be", "to_be"));
// Get top classes
request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2));
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
ClassificationInferenceResults classificationInferenceResults =
(ClassificationInferenceResults)response.getInferenceResults().get(0);
@ -159,8 +161,8 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
// Test that top classes restrict the number returned
request = new InferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1));
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
assertThat(classificationInferenceResults.getTopClasses(), hasSize(1));
@ -169,9 +171,13 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
public void testInferMissingModel() {
String model = "test-infer-missing-model";
InferModelAction.Request request = new InferModelAction.Request(model, Collections.emptyList(), new RegressionConfig());
InternalInferModelAction.Request request = new InternalInferModelAction.Request(
model,
Collections.emptyList(),
new RegressionConfig(),
true);
try {
client().execute(InferModelAction.INSTANCE, request).actionGet();
client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
} catch (ElasticsearchException ex) {
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@ -144,6 +145,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
.setDescription("trained model config for test")
.setModelId(modelId)
.setVersion(Version.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setEstimatedHeapMemory(0)
.setEstimatedOperations(0)
.setInput(TrainedModelInputTests.createRandomInput());