* [ML][Inference] allowing per-model licensing
* changing to internal action + removing premature opt
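Trained model configurations gain a license_level field, the inference action becomes the internal-only InternalInferModelAction (cluster:internal/xpack/ml/inference/infer), and license enforcement moves from the ingest processor factory into the transport action, so a pipeline that has previously run inference keeps working (with an audit warning) after a license downgrade.

As a rough sketch of how a caller builds the renamed request after this change (not part of the diff; the model id and field value are made up):

    import java.util.Collections;
    import java.util.Map;
    import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
    import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;

    // One document to run through the model. The trailing boolean is the new
    // "previouslyLicensed" hint: true means inference already succeeded for this
    // caller, so the action may keep serving it (reporting isLicensed=false) even
    // if the per-model license check no longer passes.
    Map<String, Object> doc = Collections.singletonMap("feature", 1.0);
    InternalInferModelAction.Request request =
        new InternalInferModelAction.Request("my-model", doc, new RegressionConfig(), false);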
This commit is contained in:
parent f264808a6a
commit d41b2e3f38
@@ -50,6 +50,7 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,

@@ -71,6 +72,7 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
}
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {

@@ -88,6 +90,7 @@ public class TrainedModelConfig implements ToXContentObject {
private final TrainedModelInput input;
private final Long estimatedHeapMemory;
private final Long estimatedOperations;
private final String licenseLevel;
TrainedModelConfig(String modelId,
String createdBy,

@@ -99,7 +102,8 @@ public class TrainedModelConfig implements ToXContentObject {
Map<String, Object> metadata,
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations) {
Long estimatedOperations,
String licenseLevel) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;

@@ -111,6 +115,7 @@ public class TrainedModelConfig implements ToXContentObject {
this.input = input;
this.estimatedHeapMemory = estimatedHeapMemory;
this.estimatedOperations = estimatedOperations;
this.licenseLevel = licenseLevel;
}
public String getModelId() {

@@ -161,6 +166,10 @@ public class TrainedModelConfig implements ToXContentObject {
return estimatedOperations;
}
public String getLicenseLevel() {
return licenseLevel;
}
public static Builder builder() {
return new Builder();
}

@@ -201,6 +210,9 @@ public class TrainedModelConfig implements ToXContentObject {
if (estimatedOperations != null) {
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
}
if (licenseLevel != null) {
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel);
}
builder.endObject();
return builder;
}

@@ -225,6 +237,7 @@ public class TrainedModelConfig implements ToXContentObject {
Objects.equals(input, that.input) &&
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(metadata, that.metadata);
}

@@ -240,6 +253,7 @@ public class TrainedModelConfig implements ToXContentObject {
estimatedHeapMemory,
estimatedOperations,
metadata,
licenseLevel,
input);
}

@@ -257,6 +271,7 @@ public class TrainedModelConfig implements ToXContentObject {
private TrainedModelInput input;
private Long estimatedHeapMemory;
private Long estimatedOperations;
private String licenseLevel;
public Builder setModelId(String modelId) {
this.modelId = modelId;

@@ -312,16 +327,21 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}
public Builder setEstimatedOperations(Long estimatedOperations) {
private Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}
private Builder setLicenseLevel(String licenseLevel) {
this.licenseLevel = licenseLevel;
return this;
}
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,

@@ -334,7 +354,8 @@ public class TrainedModelConfig implements ToXContentObject {
metadata,
input,
estimatedHeapMemory,
estimatedOperations);
estimatedOperations,
licenseLevel);
}
}
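For illustration only, the client-side parser above now also accepts a license_level field; a hypothetical config document (field values invented) that TrainedModelConfig.fromXContent would map through setLicenseLevel:

    // Hypothetical JSON accepted by the lenient client parser after this change.
    String json = "{"
        + "\"model_id\": \"my-model\","
        + "\"input\": {\"field_names\": [\"col1\"]},"
        + "\"estimated_heap_memory_usage_bytes\": 1024,"
        + "\"estimated_operations\": 10,"
        + "\"license_level\": \"platinum\""
        + "}";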
@@ -66,7 +66,8 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomNonNegativeLong());
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomFrom("platinum", "basic"));
}
@@ -112,7 +112,7 @@ import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;

@@ -175,13 +175,6 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage;
import org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage;

@@ -389,7 +382,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
StopDataFrameAnalyticsAction.INSTANCE,
EvaluateDataFrameAction.INSTANCE,
EstimateMemoryUsageAction.INSTANCE,
InferModelAction.INSTANCE,
InternalInferModelAction.INSTANCE,
GetTrainedModelsAction.INSTANCE,
DeleteTrainedModelAction.INSTANCE,
GetTrainedModelsStatsAction.INSTANCE,
@@ -24,12 +24,12 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
public class InferModelAction extends ActionType<InferModelAction.Response> {
public class InternalInferModelAction extends ActionType<InternalInferModelAction.Response> {
public static final InferModelAction INSTANCE = new InferModelAction();
public static final String NAME = "cluster:admin/xpack/ml/inference/infer";
public static final InternalInferModelAction INSTANCE = new InternalInferModelAction();
public static final String NAME = "cluster:internal/xpack/ml/inference/infer";
private InferModelAction() {
private InternalInferModelAction() {
super(NAME, Response::new);
}

@@ -38,21 +38,27 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
private final String modelId;
private final List<Map<String, Object>> objectsToInfer;
private final InferenceConfig config;
private final boolean previouslyLicensed;
public Request(String modelId) {
this(modelId, Collections.emptyList(), new RegressionConfig());
public Request(String modelId, boolean previouslyLicensed) {
this(modelId, Collections.emptyList(), new RegressionConfig(), previouslyLicensed);
}
public Request(String modelId, List<Map<String, Object>> objectsToInfer, InferenceConfig inferenceConfig) {
public Request(String modelId,
List<Map<String, Object>> objectsToInfer,
InferenceConfig inferenceConfig,
boolean previouslyLicensed) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer"));
this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
this.previouslyLicensed = previouslyLicensed;
}
public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfig config) {
public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfig config, boolean previouslyLicensed) {
this(modelId,
Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")),
config);
config,
previouslyLicensed);
}
public Request(StreamInput in) throws IOException {

@@ -60,6 +66,7 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
this.modelId = in.readString();
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
this.config = in.readNamedWriteable(InferenceConfig.class);
this.previouslyLicensed = in.readBoolean();
}
public String getModelId() {

@@ -74,6 +81,10 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
return config;
}
public boolean isPreviouslyLicensed() {
return previouslyLicensed;
}
@Override
public ActionRequestValidationException validate() {
return null;

@@ -85,21 +96,23 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
out.writeString(modelId);
out.writeCollection(objectsToInfer, StreamOutput::writeMap);
out.writeNamedWriteable(config);
out.writeBoolean(previouslyLicensed);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferModelAction.Request that = (InferModelAction.Request) o;
InternalInferModelAction.Request that = (InternalInferModelAction.Request) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(config, that.config)
&& Objects.equals(previouslyLicensed, that.previouslyLicensed)
&& Objects.equals(objectsToInfer, that.objectsToInfer);
}
@Override
public int hashCode() {
return Objects.hash(modelId, objectsToInfer, config);
return Objects.hash(modelId, objectsToInfer, config, previouslyLicensed);
}
}

@@ -107,37 +120,68 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
public static class Response extends ActionResponse {
private final List<InferenceResults> inferenceResults;
private final boolean isLicensed;
public Response(List<InferenceResults> inferenceResults) {
public Response(List<InferenceResults> inferenceResults, boolean isLicensed) {
super();
this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
this.isLicensed = isLicensed;
}
public Response(StreamInput in) throws IOException {
super(in);
this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class));
this.isLicensed = in.readBoolean();
}
public List<InferenceResults> getInferenceResults() {
return inferenceResults;
}
public boolean isLicensed() {
return isLicensed;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteableList(inferenceResults);
out.writeBoolean(isLicensed);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferModelAction.Response that = (InferModelAction.Response) o;
return Objects.equals(inferenceResults, that.inferenceResults);
InternalInferModelAction.Response that = (InternalInferModelAction.Response) o;
return isLicensed == that.isLicensed && Objects.equals(inferenceResults, that.inferenceResults);
}
@Override
public int hashCode() {
return Objects.hash(inferenceResults);
return Objects.hash(inferenceResults, isLicensed);
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private List<InferenceResults> inferenceResults;
private boolean isLicensed;
public Builder setInferenceResults(List<InferenceResults> inferenceResults) {
this.inferenceResults = inferenceResults;
return this;
}
public Builder setLicensed(boolean licensed) {
isLicensed = licensed;
return this;
}
public Response build() {
return new Response(inferenceResults, isLicensed);
}
}
}
@@ -17,6 +17,8 @@ import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;

@@ -48,6 +50,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);

@@ -73,6 +76,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
INPUT);
parser.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
parser.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
return parser;
}

@@ -90,6 +94,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private final TrainedModelInput input;
private final long estimatedHeapMemory;
private final long estimatedOperations;
private final License.OperationMode licenseLevel;
private final TrainedModelDefinition definition;

@@ -103,7 +108,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Map<String, Object> metadata,
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations) {
Long estimatedOperations,
String licenseLevel) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
this.version = ExceptionsHelper.requireNonNull(version, VERSION);

@@ -122,6 +128,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0");
}
this.estimatedOperations = estimatedOperations;
this.licenseLevel = License.OperationMode.resolve(ExceptionsHelper.requireNonNull(licenseLevel, LICENSE_LEVEL));
}
public TrainedModelConfig(StreamInput in) throws IOException {

@@ -136,6 +143,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
input = new TrainedModelInput(in);
estimatedHeapMemory = in.readVLong();
estimatedOperations = in.readVLong();
licenseLevel = License.OperationMode.resolve(in.readString());
}
public String getModelId() {

@@ -187,6 +195,25 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return estimatedOperations;
}
public License.OperationMode getLicenseLevel() {
return licenseLevel;
}
public boolean isAvailableWithLicense(XPackLicenseState licenseState) {
// Basic is always true
if (licenseLevel.equals(License.OperationMode.BASIC)) {
return true;
}
// The model license does not matter, this is the highest licensed level
if (licenseState.isActive() && XPackLicenseState.isPlatinumOrTrialOperationMode(licenseState.getOperationMode())) {
return true;
}
// catch the rest, if the license is active and is at least the required model license
return licenseState.isActive() && License.OperationMode.compare(licenseState.getOperationMode(), licenseLevel) >= 0;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);

@@ -200,6 +227,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
input.writeTo(out);
out.writeVLong(estimatedHeapMemory);
out.writeVLong(estimatedOperations);
out.writeString(licenseLevel.description());
}
@Override

@@ -229,6 +257,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
new ByteSizeValue(estimatedHeapMemory));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
builder.endObject();
return builder;
}

@@ -253,6 +282,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Objects.equals(input, that.input) &&
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(metadata, that.metadata);
}

@@ -268,7 +298,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
metadata,
estimatedHeapMemory,
estimatedOperations,
input);
input,
licenseLevel);
}
public static class Builder {

@@ -284,6 +315,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private TrainedModelDefinition definition;
private Long estimatedHeapMemory;
private Long estimatedOperations;
private String licenseLevel = License.OperationMode.PLATINUM.description();
public Builder setModelId(String modelId) {
this.modelId = modelId;

@@ -349,6 +381,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
public Builder setLicenseLevel(String licenseLevel) {
this.licenseLevel = licenseLevel;
return this;
}
// TODO move to REST level instead of here in the builder
public void validate() {
// We require a definition to be available here even though it will be stored in a different doc

@@ -366,28 +403,17 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
MlStrings.ID_LENGTH_LIMIT));
}
if (version != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", VERSION.getPreferredName());
}
checkIllegalSetting(version, VERSION.getPreferredName());
checkIllegalSetting(createdBy, CREATED_BY.getPreferredName());
checkIllegalSetting(createTime, CREATE_TIME.getPreferredName());
checkIllegalSetting(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
checkIllegalSetting(estimatedOperations, ESTIMATED_OPERATIONS.getPreferredName());
checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName());
}
if (createdBy != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
CREATED_BY.getPreferredName());
}
if (createTime != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
CREATE_TIME.getPreferredName());
}
if (estimatedHeapMemory != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
}
if (estimatedOperations != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
ESTIMATED_OPERATIONS.getPreferredName());
private static void checkIllegalSetting(Object value, String setting) {
if (value != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", setting);
}
}

@@ -403,7 +429,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
metadata,
input,
estimatedHeapMemory,
estimatedOperations);
estimatedOperations,
licenseLevel);
}
}
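For reference, the rule implemented by isAvailableWithLicense above collapses to a single expression (an illustrative restatement, not code added by the commit):

    // A model stored at "basic" is always usable; otherwise the cluster license must be
    // active and either platinum/trial or at least the model's own level.
    boolean usable =
        licenseLevel == License.OperationMode.BASIC
            || (licenseState.isActive()
                && (XPackLicenseState.isPlatinumOrTrialOperationMode(licenseState.getOperationMode())
                    || License.OperationMode.compare(licenseState.getOperationMode(), licenseLevel) >= 0));

So, for example, an active gold cluster license can use basic- and gold-level models but not a platinum-level one, matching the assertions in the tests below.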
@@ -8,7 +8,7 @@ package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;

@@ -22,19 +22,21 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
public class InferModelActionRequestTests extends AbstractWireSerializingTestCase<Request> {
public class InternalInferModelActionRequestTests extends AbstractWireSerializingTestCase<Request> {
@Override
protected Request createTestInstance() {
return randomBoolean() ?
new Request(
randomAlphaOfLength(10),
Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()),
randomInferenceConfig()) :
Stream.generate(InternalInferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()),
randomInferenceConfig(),
randomBoolean()) :
new Request(
randomAlphaOfLength(10),
randomMap(),
randomInferenceConfig());
randomInferenceConfig(),
randomBoolean());
}
private static InferenceConfig randomInferenceConfig() {
@@ -8,7 +8,7 @@ package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;

@@ -16,12 +16,10 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class InferModelActionResponseTests extends AbstractWireSerializingTestCase<Response> {
public class InternalInferModelActionResponseTests extends AbstractWireSerializingTestCase<Response> {
@Override
protected Response createTestInstance() {

@@ -29,7 +27,8 @@ public class InferModelActionResponseTests extends AbstractWireSerializingTestCa
return new Response(
Stream.generate(() -> randomInferenceResult(resultType))
.limit(randomIntBetween(0, 10))
.collect(Collectors.toList()));
.collect(Collectors.toList()),
randomBoolean());
}
private static InferenceResults randomInferenceResult(String resultType) {

@@ -50,9 +49,7 @@ public class InferModelActionResponseTests extends AbstractWireSerializingTestCa
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
}
}
@@ -16,6 +16,8 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;

@@ -37,11 +39,29 @@ import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNA
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TrainedModelConfigTests extends AbstractSerializingTestCase<TrainedModelConfig> {
private boolean lenient;
public static TrainedModelConfig.Builder createTestInstance(String modelId) {
List<String> tags = Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false));
return TrainedModelConfig.builder()
.setInput(TrainedModelInputTests.createRandomInput())
.setMetadata(randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)))
.setCreateTime(Instant.ofEpochMilli(randomNonNegativeLong()))
.setVersion(Version.CURRENT)
.setModelId(modelId)
.setCreatedBy(randomAlphaOfLength(10))
.setDescription(randomBoolean() ? null : randomAlphaOfLength(100))
.setEstimatedHeapMemory(randomNonNegativeLong())
.setEstimatedOperations(randomNonNegativeLong())
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setTags(tags);
}
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();

@@ -64,19 +84,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
@Override
protected TrainedModelConfig createTestInstance() {
List<String> tags = Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false));
return new TrainedModelConfig(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()),
null, // is not parsed so should not be provided
tags,
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
TrainedModelInputTests.createRandomInput(),
randomNonNegativeLong(),
randomNonNegativeLong());
return createTestInstance(randomAlphaOfLength(10)).build();
}
@Override

@@ -121,7 +129,8 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
TrainedModelInputTests.createRandomInput(),
randomNonNegativeLong(),
randomNonNegativeLong());
randomNonNegativeLong(),
"platinum");
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
assertThat(reference.utf8ToString(), containsString("definition"));

@@ -179,4 +188,39 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
}
public void testIsAvailableWithLicense() {
TrainedModelConfig.Builder builder = createTestInstance(randomAlphaOfLength(10));
XPackLicenseState licenseState = mock(XPackLicenseState.class);
when(licenseState.isActive()).thenReturn(false);
when(licenseState.getOperationMode()).thenReturn(License.OperationMode.BASIC);
assertFalse(builder.setLicenseLevel(License.OperationMode.PLATINUM.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.BASIC.description()).build().isAvailableWithLicense(licenseState));
when(licenseState.isActive()).thenReturn(true);
when(licenseState.getOperationMode()).thenReturn(License.OperationMode.PLATINUM);
assertTrue(builder.setLicenseLevel(License.OperationMode.PLATINUM.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.BASIC.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.GOLD.description()).build().isAvailableWithLicense(licenseState));
when(licenseState.isActive()).thenReturn(false);
assertFalse(builder.setLicenseLevel(License.OperationMode.PLATINUM.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.BASIC.description()).build().isAvailableWithLicense(licenseState));
assertFalse(builder.setLicenseLevel(License.OperationMode.GOLD.description()).build().isAvailableWithLicense(licenseState));
when(licenseState.isActive()).thenReturn(true);
when(licenseState.getOperationMode()).thenReturn(License.OperationMode.GOLD);
assertFalse(builder.setLicenseLevel(License.OperationMode.PLATINUM.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.BASIC.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.GOLD.description()).build().isAvailableWithLicense(licenseState));
when(licenseState.isActive()).thenReturn(false);
assertFalse(builder.setLicenseLevel(License.OperationMode.PLATINUM.description()).build().isAvailableWithLicense(licenseState));
assertTrue(builder.setLicenseLevel(License.OperationMode.BASIC.description()).build().isAvailableWithLicense(licenseState));
assertFalse(builder.setLicenseLevel(License.OperationMode.GOLD.description()).build().isAvailableWithLicense(licenseState));
}
}
@@ -370,6 +370,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
"  \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
"  \"description\": \"test model for regression\",\n" +
"  \"version\": \"7.6.0\",\n" +
"  \"license_level\": \"platinum\",\n" +
"  \"created_by\": \"ml_test\",\n" +
"  \"estimated_heap_memory_usage_bytes\": 0," +
"  \"estimated_operations\": 0," +

@@ -503,6 +504,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
"  \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
"  \"description\": \"test model for classification\",\n" +
"  \"version\": \"7.6.0\",\n" +
"  \"license_level\": \"platinum\",\n" +
"  \"created_by\": \"benwtrent\",\n" +
"  \"estimated_heap_memory_usage_bytes\": 0," +
"  \"estimated_operations\": 0," +
@@ -18,6 +18,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.license.License;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;

@@ -193,6 +194,7 @@ public class TrainedModelIT extends ESRestTestCase {
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
.setEstimatedOperations(0)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setEstimatedHeapMemory(0)
.build()
.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
@@ -99,7 +99,7 @@ import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;

@@ -168,7 +168,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction;
import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction;
import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction;
import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsStatsAction;
import org.elasticsearch.xpack.ml.action.TransportInferModelAction;
import org.elasticsearch.xpack.ml.action.TransportInternalInferModelAction;
import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsAction;
import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction;
import org.elasticsearch.xpack.ml.action.TransportKillProcessAction;

@@ -346,9 +346,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client,
parameters.ingestService.getClusterService(),
this.settings,
parameters.ingestService,
getLicenseState());
getLicenseState().addListener(inferenceFactory);
parameters.ingestService);
parameters.ingestService.addIngestClusterStateListener(inferenceFactory);
return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory);
}

@@ -831,7 +829,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class),
new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class),
new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class),
new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class),
new ActionHandler<>(InternalInferModelAction.INSTANCE, TransportInternalInferModelAction.class),
new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class),
new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class),
new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class)
@@ -16,38 +16,41 @@ import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
public class TransportInferModelAction extends HandledTransportAction<InferModelAction.Request, InferModelAction.Response> {
public class TransportInternalInferModelAction extends HandledTransportAction<Request, Response> {
private final ModelLoadingService modelLoadingService;
private final Client client;
private final XPackLicenseState licenseState;
private final TrainedModelProvider trainedModelProvider;
@Inject
public TransportInferModelAction(TransportService transportService,
ActionFilters actionFilters,
ModelLoadingService modelLoadingService,
Client client,
XPackLicenseState licenseState) {
super(InferModelAction.NAME, transportService, actionFilters, InferModelAction.Request::new);
public TransportInternalInferModelAction(TransportService transportService,
ActionFilters actionFilters,
ModelLoadingService modelLoadingService,
Client client,
XPackLicenseState licenseState,
TrainedModelProvider trainedModelProvider) {
super(InternalInferModelAction.NAME, transportService, actionFilters, InternalInferModelAction.Request::new);
this.modelLoadingService = modelLoadingService;
this.client = client;
this.licenseState = licenseState;
this.trainedModelProvider = trainedModelProvider;
}
@Override
protected void doExecute(Task task, InferModelAction.Request request, ActionListener<InferModelAction.Response> listener) {
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
if (licenseState.isMachineLearningAllowed() == false) {
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
return;
}
Response.Builder responseBuilder = Response.builder();
ActionListener<Model> getModelListener = ActionListener.wrap(
model -> {

@@ -63,13 +66,28 @@ public class TransportInferModelAction extends HandledTransportAction<InferModel
typedChainTaskExecutor.execute(ActionListener.wrap(
inferenceResultsInterfaces ->
listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces)),
listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build()),
listener::onFailure
));
},
listener::onFailure
);
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
if (licenseState.isMachineLearningAllowed()) {
responseBuilder.setLicensed(true);
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
} else {
trainedModelProvider.getTrainedModel(request.getModelId(), false, ActionListener.wrap(
trainedModelConfig -> {
responseBuilder.setLicensed(trainedModelConfig.isAvailableWithLicense(licenseState));
if (trainedModelConfig.isAvailableWithLicense(licenseState) || request.isPreviouslyLicensed()) {
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
} else {
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
}
},
listener::onFailure
));
}
}
}
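In short, the new license handling in doExecute above works as follows (a paraphrase of the diff, not additional code in the commit):

    // if (licenseState.isMachineLearningAllowed())      -> respond isLicensed=true and run inference
    // else: load the stored TrainedModelConfig and
    //   if (config.isAvailableWithLicense(licenseState) || request.isPreviouslyLicensed())
    //        -> still run inference, reporting isLicensed from the model's own check,
    //           which lets the ingest processor audit a warning instead of failing
    //   else -> fail with LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)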
@@ -14,6 +14,7 @@ import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;

@@ -157,6 +158,7 @@ public class AnalyticsResultProcessor {
.setEstimatedHeapMemory(definition.ramBytesUsed())
.setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
.setInput(new TrainedModelInput(fieldNames))
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.build();
}
@@ -26,22 +26,20 @@ import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.license.LicenseStateListener;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

@@ -73,8 +71,12 @@ public class InferenceProcessor extends AbstractProcessor {
private final InferenceConfig inferenceConfig;
private final Map<String, String> fieldMapping;
private final boolean includeModelMetadata;
private final InferenceAuditor auditor;
private volatile boolean previouslyLicensed;
private final AtomicBoolean shouldAudit = new AtomicBoolean(true);
public InferenceProcessor(Client client,
InferenceAuditor auditor,
String tag,
String targetField,
String modelId,

@@ -85,6 +87,7 @@ public class InferenceProcessor extends AbstractProcessor {
super(tag);
this.client = ExceptionsHelper.requireNonNull(client, "client");
this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
this.auditor = ExceptionsHelper.requireNonNull(auditor, "auditor");
this.modelInfoField = ExceptionsHelper.requireNonNull(modelInfoField, MODEL_INFO_FIELD);
this.includeModelMetadata = includeModelMetadata;
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);

@@ -100,22 +103,32 @@ public class InferenceProcessor extends AbstractProcessor {
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
executeAsyncWithOrigin(client,
ML_ORIGIN,
InferModelAction.INSTANCE,
InternalInferModelAction.INSTANCE,
this.buildRequest(ingestDocument),
ActionListener.wrap(
r -> {
try {
mutateDocument(r, ingestDocument);
handler.accept(ingestDocument, null);
} catch(ElasticsearchException ex) {
handler.accept(ingestDocument, ex);
}
},
r -> handleResponse(r, ingestDocument, handler),
e -> handler.accept(ingestDocument, e)
));
}
InferModelAction.Request buildRequest(IngestDocument ingestDocument) {
void handleResponse(InternalInferModelAction.Response response,
IngestDocument ingestDocument,
BiConsumer<IngestDocument, Exception> handler) {
if (previouslyLicensed == false) {
previouslyLicensed = true;
}
if (response.isLicensed() == false) {
auditWarningAboutLicenseIfNecessary();
}
try {
mutateDocument(response, ingestDocument);
handler.accept(ingestDocument, null);
} catch(ElasticsearchException ex) {
handler.accept(ingestDocument, ex);
}
}
InternalInferModelAction.Request buildRequest(IngestDocument ingestDocument) {
Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
if (fieldMapping != null) {
fieldMapping.forEach((src, dest) -> {

@@ -125,10 +138,19 @@ public class InferenceProcessor extends AbstractProcessor {
}
});
}
return new InferModelAction.Request(modelId, fields, inferenceConfig);
return new InternalInferModelAction.Request(modelId, fields, inferenceConfig, previouslyLicensed);
}
void mutateDocument(InferModelAction.Response response, IngestDocument ingestDocument) {
void auditWarningAboutLicenseIfNecessary() {
if (shouldAudit.compareAndSet(true, false)) {
auditor.warning(
modelId,
"This cluster is no longer licensed to use this model in the inference ingest processor. " +
"Please update your license information.");
}
}
void mutateDocument(InternalInferModelAction.Response response, IngestDocument ingestDocument) {
if (response.getInferenceResults().isEmpty()) {
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
}

@@ -148,28 +170,25 @@ public class InferenceProcessor extends AbstractProcessor {
return TYPE;
}
public static final class Factory implements Processor.Factory, Consumer<ClusterState>, LicenseStateListener {
public static final class Factory implements Processor.Factory, Consumer<ClusterState> {
private static final Logger logger = LogManager.getLogger(Factory.class);
private final Client client;
private final IngestService ingestService;
private final XPackLicenseState licenseState;
private final InferenceAuditor auditor;
private volatile int currentInferenceProcessors;
private volatile int maxIngestProcessors;
private volatile Version minNodeVersion = Version.CURRENT;
private volatile boolean inferenceAllowed;
public Factory(Client client,
ClusterService clusterService,
Settings settings,
IngestService ingestService,
XPackLicenseState licenseState) {
IngestService ingestService) {
this.client = client;
this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings);
this.ingestService = ingestService;
this.licenseState = licenseState;
this.inferenceAllowed = licenseState.isMachineLearningAllowed();
this.auditor = new InferenceAuditor(client, clusterService.getNodeName());
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors);
}

@@ -211,10 +230,6 @@ public class InferenceProcessor extends AbstractProcessor {
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config)
throws Exception {
if (inferenceAllowed == false) {
throw LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING);
}
if (this.maxIngestProcessors <= currentInferenceProcessors) {
throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " +
"Adjust the setting [{}]: [{}] if a greater number is desired.",

@@ -236,6 +251,7 @@ public class InferenceProcessor extends AbstractProcessor {
modelInfoField += "." + tag;
}
return new InferenceProcessor(client,
auditor,
tag,
targetField,
modelId,

@@ -289,9 +305,5 @@ public class InferenceProcessor extends AbstractProcessor {
}
}
@Override
public void licenseStateChanged() {
this.inferenceAllowed = licenseState.isMachineLearningAllowed();
}
}
}
@ -8,6 +8,7 @@ package org.elasticsearch.license;
|
|||
import org.elasticsearch.ElasticsearchSecurityException;
|
||||
import org.elasticsearch.action.ingest.PutPipelineAction;
|
||||
import org.elasticsearch.action.ingest.PutPipelineRequest;
|
||||
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineAction;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
|
||||
|
@ -32,7 +33,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteDatafeedAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
|
||||
|
@ -56,8 +57,10 @@ import static org.hamcrest.Matchers.containsString;
|
|||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasItem;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
||||
public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
|
||||
|
||||
|
@ -565,7 +568,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
|
|||
" }\n" +
|
||||
" }\n" +
|
||||
" }]}\n";
|
||||
// test that license restricted apis do now work
|
||||
// Creating a pipeline should work
|
||||
PlainActionFuture<AcknowledgedResponse> putPipelineListener = PlainActionFuture.newFuture();
|
||||
client().execute(PutPipelineAction.INSTANCE,
|
||||
new PutPipelineRequest("test_infer_license_pipeline",
|
||||
|
@ -575,6 +578,12 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
|
|||
AcknowledgedResponse putPipelineResponse = putPipelineListener.actionGet();
|
||||
assertTrue(putPipelineResponse.isAcknowledged());
|
||||
|
||||
client().prepareIndex("infer_license_test", MapperService.SINGLE_MAPPING_NAME)
|
||||
.setPipeline("test_infer_license_pipeline")
|
||||
.setSource("{}", XContentType.JSON)
|
||||
.execute()
|
||||
.actionGet();
|
||||
|
||||
String simulateSource = "{\n" +
|
||||
" \"pipeline\": \n" +
|
||||
pipeline +
|
||||
|
@ -594,37 +603,52 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
|
|||
|
||||
assertThat(simulatePipelineListener.actionGet().getResults(), is(not(empty())));
|
||||
|
||||
|
||||
// Pick a license that does not allow machine learning
|
||||
License.OperationMode mode = randomInvalidLicenseType();
|
||||
enableLicensing(mode);
|
||||
assertMLAllowed(false);
|
||||
|
||||
// creating a new pipeline should fail
|
||||
// Inference against the previous pipeline should still work
|
||||
try {
|
||||
client().prepareIndex("infer_license_test", MapperService.SINGLE_MAPPING_NAME)
|
||||
.setPipeline("test_infer_license_pipeline")
|
||||
.setSource("{}", XContentType.JSON)
|
||||
.execute()
|
||||
.actionGet();
|
||||
} catch (ElasticsearchSecurityException ex) {
|
||||
fail(ex.getMessage());
|
||||
}

// Creating a new pipeline with an inference processor should work
putPipelineListener = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline_again",
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON),
putPipelineListener);
putPipelineResponse = putPipelineListener.actionGet();
assertTrue(putPipelineResponse.isAcknowledged());

// Inference against the new pipeline should fail since it has never previously succeeded
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<AcknowledgedResponse> listener = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline_failure",
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON),
listener);
listener.actionGet();
client().prepareIndex("infer_license_test", MapperService.SINGLE_MAPPING_NAME)
.setPipeline("test_infer_license_pipeline_again")
.setSource("{}", XContentType.JSON)
.execute()
.actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));

// Simulating the pipeline should fail
e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<SimulatePipelineResponse> listener = PlainActionFuture.newFuture();
client().execute(SimulatePipelineAction.INSTANCE,
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
listener);
listener.actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
SimulateDocumentBaseResult simulateResponse = (SimulateDocumentBaseResult)client().execute(SimulatePipelineAction.INSTANCE,
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON))
.actionGet()
.getResults()
.get(0);
assertThat(simulateResponse.getFailure(), is(not(nullValue())));
assertThat((simulateResponse.getFailure()).getCause(), is(instanceOf(ElasticsearchSecurityException.class)));

// Pick a license that does allow machine learning
mode = randomValidLicenseType();
@ -646,21 +670,37 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
simulatePipelineListenerNewLicense);

assertThat(simulatePipelineListenerNewLicense.actionGet().getResults(), is(not(empty())));

//both ingest pipelines should work

client().prepareIndex("infer_license_test", MapperService.SINGLE_MAPPING_NAME)
.setPipeline("test_infer_license_pipeline")
.setSource("{}", XContentType.JSON)
.execute()
.actionGet();
client().prepareIndex("infer_license_test", MapperService.SINGLE_MAPPING_NAME)
.setPipeline("test_infer_license_pipeline_again")
.setSource("{}", XContentType.JSON)
.execute()
.actionGet();
}

public void testMachineLearningInferModelRestricted() throws Exception {
public void testMachineLearningInferModelRestricted() {
String modelId = "modelinfermodellicensetest";
assertMLAllowed(true);
putInferenceModel(modelId);

PlainActionFuture<InferModelAction.Response> inferModelSuccess = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
PlainActionFuture<InternalInferModelAction.Response> inferModelSuccess = PlainActionFuture.newFuture();
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
new RegressionConfig(),
false
), inferModelSuccess);
assertThat(inferModelSuccess.actionGet().getInferenceResults(), is(not(empty())));
InternalInferModelAction.Response response = inferModelSuccess.actionGet();
assertThat(response.getInferenceResults(), is(not(empty())));
assertThat(response.isLicensed(), is(true));

// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
@ -669,28 +709,40 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {

// inferring against a model should now fail
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
), listener);
listener.actionGet();
new RegressionConfig(),
false
)).actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));

// Inferring with previously Licensed == true should pass, but indicate license issues
inferModelSuccess = PlainActionFuture.newFuture();
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig(),
true
), inferModelSuccess);
response = inferModelSuccess.actionGet();
assertThat(response.getInferenceResults(), is(not(empty())));
assertThat(response.isLicensed(), is(false));

// Pick a license that does allow machine learning
mode = randomValidLicenseType();
enableLicensing(mode);
assertMLAllowed(true);

PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
PlainActionFuture<InternalInferModelAction.Response> listener = PlainActionFuture.newFuture();
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
new RegressionConfig(),
false
), listener);
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
}
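
// A minimal sketch, based only on the test usages in the hunk above, of how the internal
// action is driven: the request carries the model id, the documents to run inference on, an
// inference config, and a flag saying whether inference previously succeeded under a valid
// license; the response exposes the results plus whether the current license still allows ML.
// The model id below is illustrative, not from the commit.
InternalInferModelAction.Request sketchRequest = new InternalInferModelAction.Request(
    "my_model",                                         // model id (illustrative)
    Collections.singletonList(Collections.emptyMap()),  // documents/fields to infer against
    new RegressionConfig(),                             // inference configuration
    false);                                             // previouslyLicensed
InternalInferModelAction.Response sketchResponse =
    client().execute(InternalInferModelAction.INSTANCE, sketchRequest).actionGet();
boolean allowed = sketchResponse.isLicensed();          // false once the license is non-compliant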
@ -703,6 +755,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
" \"description\": \"test model for classification\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"created_by\": \"benwtrent\",\n" +
" \"license_level\": \"platinum\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0,\n" +
" \"estimated_operations\": 0,\n" +
" \"created_time\": 0\n" +

@ -12,7 +12,11 @@ import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.routing.allocation.decider.AwarenessAllocationDecider;
import org.elasticsearch.cluster.routing.OperationRouting;
import org.elasticsearch.cluster.service.ClusterApplierService;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.cluster.service.MasterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
@ -87,8 +91,7 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
new InferenceProcessor.Factory(parameters.client,
parameters.ingestService.getClusterService(),
Settings.EMPTY,
parameters.ingestService,
licenseState));
parameters.ingestService));

factoryMap.put("not_inference", new NotInferenceProcessor.Factory());

@ -105,9 +108,15 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
ThreadPool tp = mock(ThreadPool.class);
client = mock(Client.class);
clusterService = mock(ClusterService.class);
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY,
Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS));
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
Settings settings = Settings.builder().put("node.name", "InferenceProcessorFactoryTests_node").build();
ClusterSettings clusterSettings = new ClusterSettings(settings,
new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS,
MasterService.MASTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING,
OperationRouting.USE_ADAPTIVE_REPLICA_SELECTION_SETTING,
ClusterService.USER_DEFINED_META_DATA,
AwarenessAllocationDecider.CLUSTER_ROUTING_ALLOCATION_AWARENESS_ATTRIBUTE_SETTING,
ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING)));
clusterService = new ClusterService(settings, clusterSettings, tp);
ingestService = new IngestService(clusterService, tp, null, null,
null, Collections.singletonList(SKINNY_INGEST_PLUGIN), client);
}

@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
@ -138,6 +139,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class));

TrainedModelConfig storedModel = storedModelCaptor.getValue();
assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM));
assertThat(storedModel.getModelId(), containsString(JOB_ID));
assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
assertThat(storedModel.getCreatedBy(), equalTo("data-frame-analytics"));

@ -14,7 +14,11 @@ import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.routing.allocation.decider.AwarenessAllocationDecider;
import org.elasticsearch.cluster.routing.OperationRouting;
import org.elasticsearch.cluster.service.ClusterApplierService;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.cluster.service.MasterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
@ -36,8 +40,10 @@ import org.junit.Before;

import java.io.IOException;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import static org.hamcrest.Matchers.equalTo;
@ -55,8 +61,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
new InferenceProcessor.Factory(parameters.client,
parameters.ingestService.getClusterService(),
Settings.EMPTY,
parameters.ingestService,
licenseState));
parameters.ingestService));
}
};
private Client client;
@ -68,10 +73,15 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
public void setUpVariables() {
ThreadPool tp = mock(ThreadPool.class);
client = mock(Client.class);
clusterService = mock(ClusterService.class);
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY,
Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS));
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
Settings settings = Settings.builder().put("node.name", "InferenceProcessorFactoryTests_node").build();
ClusterSettings clusterSettings = new ClusterSettings(settings,
new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS,
MasterService.MASTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING,
OperationRouting.USE_ADAPTIVE_REPLICA_SELECTION_SETTING,
ClusterService.USER_DEFINED_META_DATA,
AwarenessAllocationDecider.CLUSTER_ROUTING_ALLOCATION_AWARENESS_ATTRIBUTE_SETTING,
ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING)));
clusterService = new ClusterService(settings, clusterSettings, tp);
ingestService = new IngestService(clusterService, tp, null, null,
null, Collections.singletonList(SKINNY_PLUGIN), client);
licenseState = mock(XPackLicenseState.class);
@ -84,8 +94,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
ingestService);
processorFactory.accept(buildClusterState(metaData));

assertThat(processorFactory.numInferenceProcessors(), equalTo(0));
@ -102,8 +111,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.builder().put(InferenceProcessor.MAX_INFERENCE_PROCESSORS.getKey(), 1).build(),
ingestService,
licenseState);
ingestService);

processorFactory.accept(buildClusterStateWithModelReferences("model1"));

@ -118,8 +126,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
ingestService);

Map<String, Object> config = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
@ -160,8 +167,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
ingestService);
processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1"));

Map<String, Object> regression = new HashMap<String, Object>() {{
@ -204,8 +210,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
ingestService);

Map<String, Object> regression = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
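
// A minimal sketch of the slimmed-down factory wiring exercised above, assuming only what the
// hunks show: the explicit Settings and XPackLicenseState constructor arguments are gone, and
// the factory is built from the client, the cluster service and the ingest service. The
// variable names reuse the ones already present in these tests.
InferenceProcessor.Factory sketchFactory =
    new InferenceProcessor.Factory(client, clusterService, ingestService);
sketchFactory.accept(buildClusterState(metaData));            // track pipelines referencing models
int inferenceProcessors = sketchFactory.numInferenceProcessors(); // 0 when none are configured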

@ -8,11 +8,12 @@ package org.elasticsearch.xpack.ml.inference.ingest;
import org.elasticsearch.client.Client;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.junit.Before;

import java.util.ArrayList;
@ -23,21 +24,30 @@ import java.util.Map;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.core.Is.is;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class InferenceProcessorTests extends ESTestCase {

private Client client;
private InferenceAuditor auditor;

@Before
public void setUpVariables() {
client = mock(Client.class);
auditor = mock(InferenceAuditor.class);
}

public void testMutateDocumentWithClassification() {
String targetField = "classification_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"classification_model",
@ -50,8 +60,9 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);

InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", null)));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", null)),
true);
inferenceProcessor.mutateDocument(response, document);

assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
@ -63,6 +74,7 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentClassificationTopNClasses() {
String targetField = "classification_value_probabilities";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"classification_model",
@ -79,8 +91,9 @@ public class InferenceProcessorTests extends ESTestCase {
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));

InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes)));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes)),
true);
inferenceProcessor.mutateDocument(response, document);

assertThat((List<Map<?,?>>)document.getFieldValue(targetField, List.class),
@ -92,6 +105,7 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentRegression() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"regression_model",
@ -104,8 +118,8 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);

InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)), true);
inferenceProcessor.mutateDocument(response, document);

assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
@ -116,6 +130,7 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentNoModelMetaData() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"regression_model",
@ -128,8 +143,8 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);

InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)), true);
inferenceProcessor.mutateDocument(response, document);

assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
@ -139,6 +154,7 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentModelMetaDataExistingField() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"regression_model",
@ -157,8 +173,8 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);

InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)), true);
inferenceProcessor.mutateDocument(response, document);

assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
@ -174,6 +190,7 @@ public class InferenceProcessorTests extends ESTestCase {
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);

InferenceProcessor processor = new InferenceProcessor(client,
auditor,
"my_processor",
"my_field",
modelId,
@ -204,6 +221,7 @@ public class InferenceProcessorTests extends ESTestCase {
}};

InferenceProcessor processor = new InferenceProcessor(client,
auditor,
"my_processor",
"my_field",
modelId,
@ -227,4 +245,50 @@ public class InferenceProcessorTests extends ESTestCase {
}};
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
}

public void testHandleResponseLicenseChanged() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"regression_model",
new RegressionConfig(),
Collections.emptyMap(),
"ml.my_processor",
true);

Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);

assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(false));

InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)), true);
inferenceProcessor.handleResponse(response, document, (doc, ex) -> {
assertThat(doc, is(not(nullValue())));
assertThat(ex, is(nullValue()));
});

assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(true));

response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)), false);

inferenceProcessor.handleResponse(response, document, (doc, ex) -> {
assertThat(doc, is(not(nullValue())));
assertThat(ex, is(nullValue()));
});

assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(true));

inferenceProcessor.handleResponse(response, document, (doc, ex) -> {
assertThat(doc, is(not(nullValue())));
assertThat(ex, is(nullValue()));
});

verify(auditor, times(1)).warning(eq("regression_model"), any(String.class));
}

}
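
// A hedged sketch of the behaviour the test above pins down; this is an illustration, not the
// processor's actual implementation, and the names (LicenseLatch, markResponse) are invented.
// Observed behaviour: requests start with previouslyLicensed == false, the flag latches to true
// after a response reporting isLicensed() == true, and a single audit warning is emitted once a
// later response reports isLicensed() == false.
class LicenseLatch {
    private volatile boolean previouslyLicensed = false;
    private volatile boolean warned = false;

    void markResponse(boolean licensed, Runnable auditWarning) {
        if (licensed) {
            previouslyLicensed = true;          // remember that inference was licensed at least once
        } else if (warned == false) {
            warned = true;
            auditWarning.run();                 // warn exactly once when the license stops covering ML
        }
    }

    boolean isPreviouslyLicensed() {
        return previouslyLicensed;
    }
}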

@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@ -22,7 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.junit.Before;

@ -68,6 +69,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
.setTrainedModel(buildClassification(true))
.setModelId(modelId1))
.setVersion(Version.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setCreateTime(Instant.now())
.setEstimatedOperations(0)
.setEstimatedHeapMemory(0)
@ -119,20 +121,20 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
}});

// Test regression
InferModelAction.Request request = new InferModelAction.Request(modelId1, toInfer, new RegressionConfig());
InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet();
InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId1, toInfer, new RegressionConfig(), true);
InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
contains(1.3, 1.25));

request = new InferModelAction.Request(modelId1, toInfer2, new RegressionConfig());
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
request = new InternalInferModelAction.Request(modelId1, toInfer2, new RegressionConfig(), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
contains(1.65, 1.55));

// Test classification
request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(0));
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(0), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults()
.stream()
.map(i -> ((SingleValueInferenceResults)i).valueAsString())
@ -140,8 +142,8 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
contains("not_to_be", "to_be"));

// Get top classes
request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2));
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();

ClassificationInferenceResults classificationInferenceResults =
(ClassificationInferenceResults)response.getInferenceResults().get(0);
@ -159,8 +161,8 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));

// Test that top classes restrict the number returned
request = new InferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1));
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();

classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
assertThat(classificationInferenceResults.getTopClasses(), hasSize(1));
@ -169,9 +171,13 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {

public void testInferMissingModel() {
String model = "test-infer-missing-model";
InferModelAction.Request request = new InferModelAction.Request(model, Collections.emptyList(), new RegressionConfig());
InternalInferModelAction.Request request = new InternalInferModelAction.Request(
model,
Collections.emptyList(),
new RegressionConfig(),
true);
try {
client().execute(InferModelAction.INSTANCE, request).actionGet();
client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
} catch (ElasticsearchException ex) {
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
}

@ -10,6 +10,7 @@ import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@ -144,6 +145,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
.setDescription("trained model config for test")
.setModelId(modelId)
.setVersion(Version.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setEstimatedHeapMemory(0)
.setEstimatedOperations(0)
.setInput(TrainedModelInputTests.createRandomInput());
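
// A short, hedged sketch of what these integration tests cover: a model configuration stored
// with the "platinum" license level should come back carrying that level. The model id below is
// illustrative; the builder calls mirror the fixture above, and the server-side config exposes
// the level as an operation mode, as the analytics test earlier in this diff asserts.
TrainedModelConfig.Builder storedConfig = TrainedModelConfig.builder()
    .setModelId("round-trip-model")
    .setVersion(Version.CURRENT)
    .setLicenseLevel(License.OperationMode.PLATINUM.description())
    .setEstimatedHeapMemory(0)
    .setEstimatedOperations(0)
    .setInput(TrainedModelInputTests.createRandomInput());
// After storing and re-reading via TrainedModelProvider, the expectation would be:
// assertThat(fetched.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM));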