* [ML] ML Model Inference Ingest Processor (#49052)

* [ML][Inference] adds lazy model loader and inference (#47410)

  This adds a couple of things:

  - A model loader service that is accessible via transport calls. The service loads models and caches them; they stay loaded until no processor references them any longer.
  - A Model class and its first subclass, LocalModel, used to cache model information and run inference.
  - A transport action and handler for requests to infer against a local model.

Related Feature PRs:

* [ML][Inference] Adjust inference configuration option API (#47812)
* [ML][Inference] adds logistic_regression output aggregator (#48075)
* [ML][Inference] Adding read/del trained models (#47882)
* [ML][Inference] Adding inference ingest processor (#47859)
* [ML][Inference] fixing classification inference for ensemble (#48463)
* [ML][Inference] Adding model memory estimations (#48323)
* [ML][Inference] adding more options to inference processor (#48545)
* [ML][Inference] handle string values better in feature extraction (#48584)
* [ML][Inference] Adding _stats endpoint for inference (#48492)
* [ML][Inference] add inference processors and trained models to usage (#47869)
* [ML][Inference] add new flag for optionally including model definition (#48718)
* [ML][Inference] adding license checks (#49056)
* [ML][Inference] Adding memory and compute estimates to inference (#48955)
* fixing version of indexed docs for model inference
parent 48f53efd9a
commit eefe7688ce
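The model caching behaviour described in the commit message (models are loaded lazily, cached, and released once no ingest processor references them any longer) can be illustrated with a minimal, hedged sketch. The class and method names below are hypothetical and are not part of this commit; the real service is exposed via transport calls, while this sketch only shows the reference-counting idea.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

// Hypothetical sketch of a reference-counted model cache: load a model on first
// use, reuse it while referenced, evict it when the last referencing processor
// releases it.
final class ModelCacheSketch<M> {

    private static final class Entry<T> {
        final T model;
        final AtomicInteger refCount = new AtomicInteger();
        Entry(T model) { this.model = model; }
    }

    private final Map<String, Entry<M>> cache = new ConcurrentHashMap<>();
    private final Function<String, M> loader; // assumption: synchronous lookup of a stored model

    ModelCacheSketch(Function<String, M> loader) {
        this.loader = loader;
    }

    // Load (or reuse) the model and record one more referencing processor.
    M acquire(String modelId) {
        Entry<M> entry = cache.computeIfAbsent(modelId, id -> new Entry<>(loader.apply(id)));
        entry.refCount.incrementAndGet();
        return entry.model;
    }

    // Release one reference; drop the cached model once nothing references it.
    void release(String modelId) {
        cache.computeIfPresent(modelId, (id, entry) ->
            entry.refCount.decrementAndGet() <= 0 ? null : entry);
    }
}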
@@ -22,6 +22,7 @@ import org.elasticsearch.Version;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -47,6 +48,8 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");

public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
@@ -66,6 +69,8 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
}

public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
@@ -81,6 +86,8 @@ public class TrainedModelConfig implements ToXContentObject {
private final List<String> tags;
private final Map<String, Object> metadata;
private final TrainedModelInput input;
private final Long estimatedHeapMemory;
private final Long estimatedOperations;

TrainedModelConfig(String modelId,
String createdBy,
@@ -90,7 +97,9 @@ public class TrainedModelConfig implements ToXContentObject {
TrainedModelDefinition definition,
List<String> tags,
Map<String, Object> metadata,
TrainedModelInput input) {
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
@@ -100,6 +109,8 @@ public class TrainedModelConfig implements ToXContentObject {
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = input;
this.estimatedHeapMemory = estimatedHeapMemory;
this.estimatedOperations = estimatedOperations;
}

public String getModelId() {
@@ -138,6 +149,18 @@ public class TrainedModelConfig implements ToXContentObject {
return input;
}

public ByteSizeValue getEstimatedHeapMemory() {
return estimatedHeapMemory == null ? null : new ByteSizeValue(estimatedHeapMemory);
}

public Long getEstimatedHeapMemoryBytes() {
return estimatedHeapMemory;
}

public Long getEstimatedOperations() {
return estimatedOperations;
}

public static Builder builder() {
return new Builder();
}
@@ -172,6 +195,12 @@ public class TrainedModelConfig implements ToXContentObject {
if (input != null) {
builder.field(INPUT.getPreferredName(), input);
}
if (estimatedHeapMemory != null) {
builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory);
}
if (estimatedOperations != null) {
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
}
builder.endObject();
return builder;
}
@@ -194,6 +223,8 @@ public class TrainedModelConfig implements ToXContentObject {
Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) &&
Objects.equals(input, that.input) &&
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(metadata, that.metadata);
}

@@ -206,6 +237,8 @@ public class TrainedModelConfig implements ToXContentObject {
definition,
description,
tags,
estimatedHeapMemory,
estimatedOperations,
metadata,
input);
}
@@ -222,6 +255,8 @@ public class TrainedModelConfig implements ToXContentObject {
private List<String> tags;
private TrainedModelDefinition definition;
private TrainedModelInput input;
private Long estimatedHeapMemory;
private Long estimatedOperations;

public Builder setModelId(String modelId) {
this.modelId = modelId;
@@ -277,6 +312,16 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}

public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}

public Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}

public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
@@ -287,7 +332,9 @@ public class TrainedModelConfig implements ToXContentObject {
definition,
tags,
metadata,
input);
input,
estimatedHeapMemory,
estimatedOperations);
}
}

@@ -64,7 +64,10 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomNonNegativeLong());

}

@Override

@@ -33,6 +33,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

public class IngestStats implements Writeable, ToXContentFragment {
@@ -150,6 +151,21 @@ public class IngestStats implements Writeable, ToXContentFragment {
return processorStats;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats that = (IngestStats) o;
return Objects.equals(totalStats, that.totalStats)
&& Objects.equals(pipelineStats, that.pipelineStats)
&& Objects.equals(processorStats, that.processorStats);
}

@Override
public int hashCode() {
return Objects.hash(totalStats, pipelineStats, processorStats);
}

public static class Stats implements Writeable, ToXContentFragment {

private final long ingestCount;
@@ -218,6 +234,22 @@ public class IngestStats implements Writeable, ToXContentFragment {
builder.field("failed", ingestFailedCount);
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.Stats that = (IngestStats.Stats) o;
return Objects.equals(ingestCount, that.ingestCount)
&& Objects.equals(ingestTimeInMillis, that.ingestTimeInMillis)
&& Objects.equals(ingestFailedCount, that.ingestFailedCount)
&& Objects.equals(ingestCurrent, that.ingestCurrent);
}

@Override
public int hashCode() {
return Objects.hash(ingestCount, ingestTimeInMillis, ingestFailedCount, ingestCurrent);
}
}

/**
@@ -270,6 +302,20 @@ public class IngestStats implements Writeable, ToXContentFragment {
public Stats getStats() {
return stats;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.PipelineStat that = (IngestStats.PipelineStat) o;
return Objects.equals(pipelineId, that.pipelineId)
&& Objects.equals(stats, that.stats);
}

@Override
public int hashCode() {
return Objects.hash(pipelineId, stats);
}
}

/**
@@ -297,5 +343,21 @@ public class IngestStats implements Writeable, ToXContentFragment {
public Stats getStats() {
return stats;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.ProcessorStat that = (IngestStats.ProcessorStat) o;
return Objects.equals(name, that.name)
&& Objects.equals(type, that.type)
&& Objects.equals(stats, that.stats);
}

@Override
public int hashCode() {
return Objects.hash(name, type, stats);
}
}
}

@@ -88,6 +88,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction;
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
@@ -109,6 +110,9 @@ import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction;
import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
@@ -153,6 +157,19 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.P
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
@@ -371,6 +388,10 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
StopDataFrameAnalyticsAction.INSTANCE,
EvaluateDataFrameAction.INSTANCE,
EstimateMemoryUsageAction.INSTANCE,
InferModelAction.INSTANCE,
GetTrainedModelsAction.INSTANCE,
DeleteTrainedModelAction.INSTANCE,
GetTrainedModelsStatsAction.INSTANCE,
// security
ClearRealmCacheAction.INSTANCE,
ClearRolesCacheAction.INSTANCE,
@@ -519,6 +540,16 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
new NamedWriteableRegistry.Entry(OutputAggregator.class,
LogisticRegression.NAME.getPreferredName(),
LogisticRegression::new),
// ML - Inference Results
new NamedWriteableRegistry.Entry(InferenceResults.class,
ClassificationInferenceResults.NAME,
ClassificationInferenceResults::new),
new NamedWriteableRegistry.Entry(InferenceResults.class,
RegressionInferenceResults.NAME,
RegressionInferenceResults::new),
// ML - Inference Configuration
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new),
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new),

// monitoring
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),

@@ -29,10 +29,12 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
public static final String CREATED_BY = "created_by";
public static final String NODE_COUNT = "node_count";
public static final String DATA_FRAME_ANALYTICS_JOBS_FIELD = "data_frame_analytics_jobs";
public static final String INFERENCE_FIELD = "inference";

private final Map<String, Object> jobsUsage;
private final Map<String, Object> datafeedsUsage;
private final Map<String, Object> analyticsUsage;
private final Map<String, Object> inferenceUsage;
private final int nodeCount;

public MachineLearningFeatureSetUsage(boolean available,
@@ -40,11 +42,13 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
Map<String, Object> jobsUsage,
Map<String, Object> datafeedsUsage,
Map<String, Object> analyticsUsage,
Map<String, Object> inferenceUsage,
int nodeCount) {
super(XPackField.MACHINE_LEARNING, available, enabled);
this.jobsUsage = Objects.requireNonNull(jobsUsage);
this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage);
this.analyticsUsage = Objects.requireNonNull(analyticsUsage);
this.inferenceUsage = Objects.requireNonNull(inferenceUsage);
this.nodeCount = nodeCount;
}

@@ -57,12 +61,17 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
} else {
this.analyticsUsage = Collections.emptyMap();
}
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
this.inferenceUsage = in.readMap();
} else {
this.inferenceUsage = Collections.emptyMap();
}
if (in.getVersion().onOrAfter(Version.V_6_5_0)) {
this.nodeCount = in.readInt();
} else {
this.nodeCount = -1;
}
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
@@ -72,10 +81,13 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
if (out.getVersion().onOrAfter(Version.V_7_4_0)) {
out.writeMap(analyticsUsage);
}
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
out.writeMap(inferenceUsage);
}
if (out.getVersion().onOrAfter(Version.V_6_5_0)) {
out.writeInt(nodeCount);
}
}
}

@Override
protected void innerXContent(XContentBuilder builder, Params params) throws IOException {
@@ -83,6 +95,7 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
builder.field(JOBS_FIELD, jobsUsage);
builder.field(DATAFEEDS_FIELD, datafeedsUsage);
builder.field(DATA_FRAME_ANALYTICS_JOBS_FIELD, analyticsUsage);
builder.field(INFERENCE_FIELD, inferenceUsage);
if (nodeCount >= 0) {
builder.field(NODE_COUNT, nodeCount);
}

@@ -0,0 +1,81 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;

public class DeleteTrainedModelAction extends ActionType<AcknowledgedResponse> {

public static final DeleteTrainedModelAction INSTANCE = new DeleteTrainedModelAction();
public static final String NAME = "cluster:admin/xpack/ml/inference/delete";

private DeleteTrainedModelAction() {
super(NAME, AcknowledgedResponse::new);
}

public static class Request extends AcknowledgedRequest<Request> implements ToXContentFragment {

private String id;

public Request(StreamInput in) throws IOException {
super(in);
id = in.readString();
}

public Request() {}

public Request(String id) {
this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID);
}

public String getId() {
return id;
}

@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), id);
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
DeleteTrainedModelAction.Request request = (DeleteTrainedModelAction.Request) o;
return Objects.equals(id, request.id);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(id);
}

@Override
public int hashCode() {
return Objects.hash(id);
}
}

}

@@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;


public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {

public static final GetTrainedModelsAction INSTANCE = new GetTrainedModelsAction();
public static final String NAME = "cluster:monitor/xpack/ml/inference/get";

private GetTrainedModelsAction() {
super(NAME, Response::new);
}

public static class Request extends AbstractGetResourcesRequest {

public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");

private final boolean includeModelDefinition;

public Request(String id, boolean includeModelDefinition) {
setResourceId(id);
setAllowNoResources(true);
this.includeModelDefinition = includeModelDefinition;
}

public Request(StreamInput in) throws IOException {
super(in);
this.includeModelDefinition = in.readBoolean();
}

@Override
public String getResourceIdField() {
return TrainedModelConfig.MODEL_ID.getPreferredName();
}

public boolean isIncludeModelDefinition() {
return includeModelDefinition;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(includeModelDefinition);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), includeModelDefinition);
}

@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
Request other = (Request) obj;
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition;
}
}

public static class Response extends AbstractGetResourcesResponse<TrainedModelConfig> {

public static final ParseField RESULTS_FIELD = new ParseField("trained_model_configs");

public Response(StreamInput in) throws IOException {
super(in);
}

public Response(QueryPage<TrainedModelConfig> trainedModels) {
super(trainedModels);
}

@Override
protected Reader<TrainedModelConfig> getReader() {
return TrainedModelConfig::new;
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private long totalCount;
private List<TrainedModelConfig> configs = Collections.emptyList();

private Builder() {
}

public Builder setTotalCount(long totalCount) {
this.totalCount = totalCount;
return this;
}

public Builder setModels(List<TrainedModelConfig> configs) {
this.configs = configs;
return this;
}

public Response build() {
return new Response(new QueryPage<>(configs, totalCount, RESULTS_FIELD));
}
}
}

}

@ -0,0 +1,194 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.action.ActionRequestBuilder;
|
||||
import org.elasticsearch.action.ActionType;
|
||||
import org.elasticsearch.client.ElasticsearchClient;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestStats;
|
||||
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
|
||||
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
|
||||
import org.elasticsearch.xpack.core.action.util.QueryPage;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStatsAction.Response> {
|
||||
|
||||
public static final GetTrainedModelsStatsAction INSTANCE = new GetTrainedModelsStatsAction();
|
||||
public static final String NAME = "cluster:monitor/xpack/ml/inference/stats/get";
|
||||
|
||||
public static final ParseField MODEL_ID = new ParseField("model_id");
|
||||
public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
|
||||
|
||||
private GetTrainedModelsStatsAction() {
|
||||
super(NAME, GetTrainedModelsStatsAction.Response::new);
|
||||
}
|
||||
|
||||
public static class Request extends AbstractGetResourcesRequest {
|
||||
|
||||
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
|
||||
|
||||
public Request() {
|
||||
setAllowNoResources(true);
|
||||
}
|
||||
|
||||
public Request(String id) {
|
||||
setResourceId(id);
|
||||
setAllowNoResources(true);
|
||||
}
|
||||
|
||||
public Request(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getResourceIdField() {
|
||||
return TrainedModelConfig.MODEL_ID.getPreferredName();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class RequestBuilder extends ActionRequestBuilder<Request, Response> {
|
||||
|
||||
public RequestBuilder(ElasticsearchClient client, GetTrainedModelsStatsAction action) {
|
||||
super(client, action, new Request());
|
||||
}
|
||||
}
|
||||
|
||||
public static class Response extends AbstractGetResourcesResponse<Response.TrainedModelStats> {
|
||||
|
||||
public static class TrainedModelStats implements ToXContentObject, Writeable {
|
||||
private final String modelId;
|
||||
private final IngestStats ingestStats;
|
||||
private final int pipelineCount;
|
||||
|
||||
private static final IngestStats EMPTY_INGEST_STATS = new IngestStats(new IngestStats.Stats(0, 0, 0, 0),
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap());
|
||||
|
||||
public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount) {
|
||||
this.modelId = Objects.requireNonNull(modelId);
|
||||
this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats;
|
||||
if (pipelineCount < 0) {
|
||||
throw new ElasticsearchException("[{}] must be a greater than or equal to 0", PIPELINE_COUNT.getPreferredName());
|
||||
}
|
||||
this.pipelineCount = pipelineCount;
|
||||
}
|
||||
|
||||
public TrainedModelStats(StreamInput in) throws IOException {
|
||||
modelId = in.readString();
|
||||
ingestStats = new IngestStats(in);
|
||||
pipelineCount = in.readVInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(MODEL_ID.getPreferredName(), modelId);
|
||||
builder.field(PIPELINE_COUNT.getPreferredName(), pipelineCount);
|
||||
if (pipelineCount > 0) {
|
||||
// Ingest stats is a fragment
|
||||
ingestStats.toXContent(builder, params);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(modelId);
|
||||
ingestStats.writeTo(out);
|
||||
out.writeVInt(pipelineCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(modelId, ingestStats, pipelineCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (obj == null) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
TrainedModelStats other = (TrainedModelStats) obj;
|
||||
return Objects.equals(this.modelId, other.modelId)
|
||||
&& Objects.equals(this.ingestStats, other.ingestStats)
|
||||
&& Objects.equals(this.pipelineCount, other.pipelineCount);
|
||||
}
|
||||
}
|
||||
|
||||
public static final ParseField RESULTS_FIELD = new ParseField("trained_model_stats");
|
||||
|
||||
public Response(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
}
|
||||
|
||||
public Response(QueryPage<Response.TrainedModelStats> trainedModels) {
|
||||
super(trainedModels);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Reader<Response.TrainedModelStats> getReader() {
|
||||
return Response.TrainedModelStats::new;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private long totalModelCount;
|
||||
private Set<String> expandedIds;
|
||||
private Map<String, IngestStats> ingestStatsMap;
|
||||
|
||||
public Builder setTotalModelCount(long totalModelCount) {
|
||||
this.totalModelCount = totalModelCount;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setExpandedIds(Set<String> expandedIds) {
|
||||
this.expandedIds = expandedIds;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Set<String> getExpandedIds() {
|
||||
return this.expandedIds;
|
||||
}
|
||||
|
||||
public Builder setIngestStatsByModelId(Map<String, IngestStats> ingestStatsByModelId) {
|
||||
this.ingestStatsMap = ingestStatsByModelId;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Response build() {
|
||||
List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIds.size());
|
||||
expandedIds.forEach(id -> {
|
||||
IngestStats ingestStats = ingestStatsMap.get(id);
|
||||
trainedModelStats.add(new TrainedModelStats(id, ingestStats, ingestStats == null ?
|
||||
0 :
|
||||
ingestStats.getPipelineStats().size()));
|
||||
});
|
||||
return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.action.ActionRequest;
|
||||
import org.elasticsearch.action.ActionRequestValidationException;
|
||||
import org.elasticsearch.action.ActionResponse;
|
||||
import org.elasticsearch.action.ActionType;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class InferModelAction extends ActionType<InferModelAction.Response> {
|
||||
|
||||
public static final InferModelAction INSTANCE = new InferModelAction();
|
||||
public static final String NAME = "cluster:admin/xpack/ml/inference/infer";
|
||||
|
||||
private InferModelAction() {
|
||||
super(NAME, Response::new);
|
||||
}
|
||||
|
||||
public static class Request extends ActionRequest {
|
||||
|
||||
private final String modelId;
|
||||
private final List<Map<String, Object>> objectsToInfer;
|
||||
private final InferenceConfig config;
|
||||
|
||||
public Request(String modelId) {
|
||||
this(modelId, Collections.emptyList(), new RegressionConfig());
|
||||
}
|
||||
|
||||
public Request(String modelId, List<Map<String, Object>> objectsToInfer, InferenceConfig inferenceConfig) {
|
||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
|
||||
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer"));
|
||||
this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
|
||||
}
|
||||
|
||||
public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfig config) {
|
||||
this(modelId,
|
||||
Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")),
|
||||
config);
|
||||
}
|
||||
|
||||
public Request(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
this.modelId = in.readString();
|
||||
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
|
||||
this.config = in.readNamedWriteable(InferenceConfig.class);
|
||||
}
|
||||
|
||||
public String getModelId() {
|
||||
return modelId;
|
||||
}
|
||||
|
||||
public List<Map<String, Object>> getObjectsToInfer() {
|
||||
return objectsToInfer;
|
||||
}
|
||||
|
||||
public InferenceConfig getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActionRequestValidationException validate() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
out.writeString(modelId);
|
||||
out.writeCollection(objectsToInfer, StreamOutput::writeMap);
|
||||
out.writeNamedWriteable(config);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
InferModelAction.Request that = (InferModelAction.Request) o;
|
||||
return Objects.equals(modelId, that.modelId)
|
||||
&& Objects.equals(config, that.config)
|
||||
&& Objects.equals(objectsToInfer, that.objectsToInfer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(modelId, objectsToInfer, config);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class Response extends ActionResponse {
|
||||
|
||||
private final List<InferenceResults> inferenceResults;
|
||||
|
||||
public Response(List<InferenceResults> inferenceResults) {
|
||||
super();
|
||||
this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
|
||||
}
|
||||
|
||||
public Response(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class));
|
||||
}
|
||||
|
||||
public List<InferenceResults> getInferenceResults() {
|
||||
return inferenceResults;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeNamedWriteableList(inferenceResults);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
InferModelAction.Response that = (InferModelAction.Response) o;
|
||||
return Objects.equals(inferenceResults, that.inferenceResults);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(inferenceResults);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -8,7 +8,13 @@ package org.elasticsearch.xpack.core.ml.inference;
|
|||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
|
||||
|
@ -110,6 +116,18 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
|
|||
LogisticRegression.NAME.getPreferredName(),
|
||||
LogisticRegression::new));
|
||||
|
||||
// Inference Results
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
|
||||
ClassificationInferenceResults.NAME,
|
||||
ClassificationInferenceResults::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
|
||||
RegressionInferenceResults.NAME,
|
||||
RegressionInferenceResults::new));
|
||||
|
||||
// Inference Configs
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new));
|
||||
|
||||
return namedWriteables;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.common.Strings;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
@ -34,6 +35,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
|
||||
public static final String NAME = "trained_model_config";
|
||||
|
||||
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
|
||||
|
||||
public static final ParseField MODEL_ID = new ParseField("model_id");
|
||||
public static final ParseField CREATED_BY = new ParseField("created_by");
|
||||
public static final ParseField VERSION = new ParseField("version");
|
||||
|
@ -43,6 +46,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
public static final ParseField TAGS = new ParseField("tags");
|
||||
public static final ParseField METADATA = new ParseField("metadata");
|
||||
public static final ParseField INPUT = new ParseField("input");
|
||||
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
|
||||
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
|
||||
|
||||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
||||
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
|
||||
|
@ -66,6 +71,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
parser.declareObject(TrainedModelConfig.Builder::setInput,
|
||||
(p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields),
|
||||
INPUT);
|
||||
parser.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
|
||||
parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -81,6 +88,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
private final List<String> tags;
|
||||
private final Map<String, Object> metadata;
|
||||
private final TrainedModelInput input;
|
||||
private final long estimatedHeapMemory;
|
||||
private final long estimatedOperations;
|
||||
|
||||
private final TrainedModelDefinition definition;
|
||||
|
||||
|
@ -92,7 +101,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
TrainedModelDefinition definition,
|
||||
List<String> tags,
|
||||
Map<String, Object> metadata,
|
||||
TrainedModelInput input) {
|
||||
TrainedModelInput input,
|
||||
Long estimatedHeapMemory,
|
||||
Long estimatedOperations) {
|
||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
||||
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
|
||||
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
|
||||
|
@ -102,6 +113,15 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
|
||||
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
|
||||
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
|
||||
if (ExceptionsHelper.requireNonNull(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES) < 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"[" + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName() + "] must be greater than or equal to 0");
|
||||
}
|
||||
this.estimatedHeapMemory = estimatedHeapMemory;
|
||||
if (ExceptionsHelper.requireNonNull(estimatedOperations, ESTIMATED_OPERATIONS) < 0) {
|
||||
throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0");
|
||||
}
|
||||
this.estimatedOperations = estimatedOperations;
|
||||
}
|
||||
|
||||
public TrainedModelConfig(StreamInput in) throws IOException {
|
||||
|
@ -114,6 +134,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
tags = Collections.unmodifiableList(in.readList(StreamInput::readString));
|
||||
metadata = in.readMap();
|
||||
input = new TrainedModelInput(in);
|
||||
estimatedHeapMemory = in.readVLong();
|
||||
estimatedOperations = in.readVLong();
|
||||
}
|
||||
|
||||
public String getModelId() {
|
||||
|
@ -157,6 +179,14 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
return new Builder();
|
||||
}
|
||||
|
||||
public long getEstimatedHeapMemory() {
|
||||
return estimatedHeapMemory;
|
||||
}
|
||||
|
||||
public long getEstimatedOperations() {
|
||||
return estimatedOperations;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(modelId);
|
||||
|
@ -168,6 +198,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
out.writeCollection(tags, StreamOutput::writeString);
|
||||
out.writeMap(metadata);
|
||||
input.writeTo(out);
|
||||
out.writeVLong(estimatedHeapMemory);
|
||||
out.writeVLong(estimatedOperations);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -180,7 +212,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
builder.field(DESCRIPTION.getPreferredName(), description);
|
||||
}
|
||||
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
|
||||
|
||||
// We don't store the definition in the same document as the configuration
|
||||
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
|
||||
builder.field(DEFINITION.getPreferredName(), definition);
|
||||
|
@ -193,6 +224,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
||||
}
|
||||
builder.field(INPUT.getPreferredName(), input);
|
||||
builder.humanReadableField(
|
||||
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
|
||||
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
|
||||
new ByteSizeValue(estimatedHeapMemory));
|
||||
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -215,6 +251,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
Objects.equals(definition, that.definition) &&
|
||||
Objects.equals(tags, that.tags) &&
|
||||
Objects.equals(input, that.input) &&
|
||||
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
|
||||
Objects.equals(estimatedOperations, that.estimatedOperations) &&
|
||||
Objects.equals(metadata, that.metadata);
|
||||
}
|
||||
|
||||
|
@ -228,6 +266,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
description,
|
||||
tags,
|
||||
metadata,
|
||||
estimatedHeapMemory,
|
||||
estimatedOperations,
|
||||
input);
|
||||
}
|
||||
|
||||
|
@ -242,6 +282,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
private Map<String, Object> metadata;
|
||||
private TrainedModelInput input;
|
||||
private TrainedModelDefinition definition;
|
||||
private Long estimatedHeapMemory;
|
||||
private Long estimatedOperations;
|
||||
|
||||
public Builder setModelId(String modelId) {
|
||||
this.modelId = modelId;
|
||||
|
@ -297,6 +339,16 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setEstimatedHeapMemory(long estimatedHeapMemory) {
|
||||
this.estimatedHeapMemory = estimatedHeapMemory;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setEstimatedOperations(long estimatedOperations) {
|
||||
this.estimatedOperations = estimatedOperations;
|
||||
return this;
|
||||
}
|
||||
|
||||
// TODO move to REST level instead of here in the builder
|
||||
public void validate() {
|
||||
// We require a definition to be available here even though it will be stored in a different doc
|
||||
|
@ -327,6 +379,16 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
|
||||
CREATE_TIME.getPreferredName());
|
||||
}
|
||||
|
||||
if (estimatedHeapMemory != null) {
|
||||
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
|
||||
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
|
||||
}
|
||||
|
||||
if (estimatedOperations != null) {
|
||||
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
|
||||
ESTIMATED_OPERATIONS.getPreferredName());
|
||||
}
|
||||
}
|
||||
|
||||
public TrainedModelConfig build() {
|
||||
|
@ -339,7 +401,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
definition,
|
||||
tags,
|
||||
metadata,
|
||||
input);
|
||||
input,
|
||||
estimatedHeapMemory,
|
||||
estimatedOperations);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,12 +5,16 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference;
|
||||
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.Accountables;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
@ -19,6 +23,8 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
|
|||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
||||
|
@ -27,13 +33,18 @@ import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
|||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||
public class TrainedModelDefinition implements ToXContentObject, Writeable, Accountable {
|
||||
|
||||
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TrainedModelDefinition.class);
|
||||
public static final String NAME = "trained_model_definition";
|
||||
public static final String HEAP_MEMORY_ESTIMATION = "heap_memory_estimation";
|
||||
|
||||
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
|
||||
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
|
||||
|
@ -106,6 +117,11 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
|||
true,
|
||||
PREPROCESSORS.getPreferredName(),
|
||||
preProcessors);
|
||||
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) {
|
||||
builder.humanReadableField(HEAP_MEMORY_ESTIMATION + "_bytes",
|
||||
HEAP_MEMORY_ESTIMATION,
|
||||
new ByteSizeValue(ramBytesUsed()));
|
||||
}
|
||||
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
|
||||
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
||||
assert modelId != null;
|
||||
|
@ -123,6 +139,15 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
|||
return preProcessors;
|
||||
}
|
||||
|
||||
private void preProcess(Map<String, Object> fields) {
|
||||
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
|
||||
}
|
||||
|
||||
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
|
||||
preProcess(fields);
|
||||
return trainedModel.infer(fields, config);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
|
@ -143,6 +168,24 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
|||
return Objects.hash(trainedModel, preProcessors, modelId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long size = SHALLOW_SIZE;
|
||||
size += RamUsageEstimator.sizeOf(trainedModel);
|
||||
size += RamUsageEstimator.sizeOfCollection(preProcessors);
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<Accountable> getChildResources() {
|
||||
List<Accountable> accountables = new ArrayList<>(preProcessors.size() + 2);
|
||||
accountables.add(Accountables.namedAccountable("trained_model", trainedModel));
|
||||
for(PreProcessor preProcessor : preProcessors) {
|
||||
accountables.add(Accountables.namedAccountable("pre_processor_" + preProcessor.getName(), preProcessor));
|
||||
}
|
||||
return accountables;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private List<PreProcessor> preProcessors;
|
||||
|
|
|
@ -5,7 +5,9 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
|
@ -25,6 +27,8 @@ import java.util.Objects;
|
|||
*/
|
||||
public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
|
||||
|
||||
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class);
|
||||
|
||||
public static final ParseField NAME = new ParseField("frequency_encoding");
|
||||
public static final ParseField FIELD = new ParseField("field");
|
||||
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
|
||||
|
@ -143,4 +147,17 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
return Objects.hash(field, featureName, frequencyMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long size = SHALLOW_SIZE;
|
||||
size += RamUsageEstimator.sizeOf(field);
|
||||
size += RamUsageEstimator.sizeOf(featureName);
|
||||
size += RamUsageEstimator.sizeOfMap(frequencyMap);
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
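// A hedged note on the heap accounting above: ramBytesUsed() follows Lucene's
// RamUsageEstimator pattern, shallow instance size plus the deep size of each referenced
// field, so the estimate grows with the number of entries in frequencyMap rather than
// being a fixed constant per pre-processor.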
|
||||
|
|
|
@ -5,7 +5,9 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
|
@ -23,6 +25,7 @@ import java.util.Objects;
|
|||
*/
|
||||
public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
|
||||
|
||||
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class);
|
||||
public static final ParseField NAME = new ParseField("one_hot_encoding");
|
||||
public static final ParseField FIELD = new ParseField("field");
|
||||
public static final ParseField HOT_MAP = new ParseField("hot_map");
|
||||
|
@ -127,4 +130,16 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
return Objects.hash(field, hotMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long size = SHALLOW_SIZE;
|
||||
size += RamUsageEstimator.sizeOf(field);
|
||||
size += RamUsageEstimator.sizeOfMap(hotMap);
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
|
||||
|
||||
|
@ -14,7 +15,7 @@ import java.util.Map;
|
|||
* Describes a pre-processor for a defined machine learning model
|
||||
* This processor should take a set of fields and return the modified set of fields.
|
||||
*/
|
||||
public interface PreProcessor extends NamedXContentObject, NamedWriteable {
|
||||
public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable {
|
||||
|
||||
/**
|
||||
* Process the given fields and their values and return the modified map.
|
||||
|
|
|
@ -5,7 +5,9 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
|
@ -25,6 +27,7 @@ import java.util.Objects;
|
|||
*/
|
||||
public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
|
||||
|
||||
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class);
|
||||
public static final ParseField NAME = new ParseField("target_mean_encoding");
|
||||
public static final ParseField FIELD = new ParseField("field");
|
||||
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
|
||||
|
@ -158,4 +161,17 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
return Objects.hash(field, featureName, meanMap, defaultValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long size = SHALLOW_SIZE;
|
||||
size += RamUsageEstimator.sizeOf(field);
|
||||
size += RamUsageEstimator.sizeOf(featureName);
|
||||
size += RamUsageEstimator.sizeOfMap(meanMap);
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,175 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ClassificationInferenceResults extends SingleValueInferenceResults {
|
||||
|
||||
public static final String NAME = "classification";
|
||||
public static final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label");
|
||||
public static final ParseField TOP_CLASSES = new ParseField("top_classes");
|
||||
|
||||
private final String classificationLabel;
|
||||
private final List<TopClassEntry> topClasses;
|
||||
|
||||
public ClassificationInferenceResults(double value, String classificationLabel, List<TopClassEntry> topClasses) {
|
||||
super(value);
|
||||
this.classificationLabel = classificationLabel;
|
||||
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
|
||||
}
|
||||
|
||||
public ClassificationInferenceResults(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
this.classificationLabel = in.readOptionalString();
|
||||
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
|
||||
}
|
||||
|
||||
public String getClassificationLabel() {
|
||||
return classificationLabel;
|
||||
}
|
||||
|
||||
public List<TopClassEntry> getTopClasses() {
|
||||
return topClasses;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
out.writeOptionalString(classificationLabel);
|
||||
out.writeCollection(topClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
if (classificationLabel != null) {
|
||||
builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel);
|
||||
}
|
||||
if (topClasses.isEmpty() == false) {
|
||||
builder.field(TOP_CLASSES.getPreferredName(), topClasses);
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
ClassificationInferenceResults that = (ClassificationInferenceResults) object;
|
||||
return Objects.equals(value(), that.value()) &&
|
||||
Objects.equals(classificationLabel, that.classificationLabel) &&
|
||||
Objects.equals(topClasses, that.topClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(value(), classificationLabel, topClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String valueAsString() {
|
||||
return classificationLabel == null ? super.valueAsString() : classificationLabel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String resultField) {
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
ExceptionsHelper.requireNonNull(resultField, "resultField");
|
||||
if (topClasses.isEmpty()) {
|
||||
document.setFieldValue(resultField, valueAsString());
|
||||
} else {
|
||||
document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public static class TopClassEntry implements ToXContentObject, Writeable {
|
||||
|
||||
public final ParseField CLASSIFICATION = new ParseField("classification");
|
||||
public final ParseField PROBABILITY = new ParseField("probability");
|
||||
|
||||
private final String classification;
|
||||
private final double probability;
|
||||
|
||||
public TopClassEntry(String classification, Double probability) {
|
||||
this.classification = ExceptionsHelper.requireNonNull(classification, CLASSIFICATION);
|
||||
this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY);
|
||||
}
|
||||
|
||||
public TopClassEntry(StreamInput in) throws IOException {
|
||||
this.classification = in.readString();
|
||||
this.probability = in.readDouble();
|
||||
}
|
||||
|
||||
public String getClassification() {
|
||||
return classification;
|
||||
}
|
||||
|
||||
public double getProbability() {
|
||||
return probability;
|
||||
}
|
||||
|
||||
public Map<String, Object> asValueMap() {
|
||||
Map<String, Object> map = new HashMap<>(2);
|
||||
map.put(CLASSIFICATION.getPreferredName(), classification);
|
||||
map.put(PROBABILITY.getPreferredName(), probability);
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(classification);
|
||||
out.writeDouble(probability);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASSIFICATION.getPreferredName(), classification);
|
||||
builder.field(PROBABILITY.getPreferredName(), probability);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
TopClassEntry that = (TopClassEntry) object;
|
||||
return Objects.equals(classification, that.classification) &&
|
||||
Objects.equals(probability, that.probability);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(classification, probability);
|
||||
}
|
||||
}
|
||||
}
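// Illustrative sketch of writeResult(..) above (the document, field name, and label are
// assumptions): with no top classes only the label is written, otherwise the list of
// {classification, probability} maps is written instead.
//
//   new ClassificationInferenceResults(1.0, "dog", null)
//       .writeResult(ingestDocument, "ml.inference.predicted_value");
//   // ingestDocument now contains "ml.inference.predicted_value" -> "dog"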
@@ -0,0 +1,16 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

public interface InferenceResults extends NamedXContentObject, NamedWriteable {

    void writeResult(IngestDocument document, String resultField);

}
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class RawInferenceResults extends SingleValueInferenceResults {
|
||||
|
||||
public static final String NAME = "raw";
|
||||
|
||||
public RawInferenceResults(double value) {
|
||||
super(value);
|
||||
}
|
||||
|
||||
public RawInferenceResults(StreamInput in) throws IOException {
|
||||
super(in.readDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
XContentBuilder innerToXContent(XContentBuilder builder, Params params) {
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
RawInferenceResults that = (RawInferenceResults) object;
|
||||
return Objects.equals(value(), that.value());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(value());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String resultField) {
|
||||
throw new UnsupportedOperationException("[raw] does not support writing inference results");
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class RegressionInferenceResults extends SingleValueInferenceResults {
|
||||
|
||||
public static final String NAME = "regression";
|
||||
|
||||
public RegressionInferenceResults(double value) {
|
||||
super(value);
|
||||
}
|
||||
|
||||
public RegressionInferenceResults(StreamInput in) throws IOException {
|
||||
super(in.readDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
XContentBuilder innerToXContent(XContentBuilder builder, Params params) {
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
RegressionInferenceResults that = (RegressionInferenceResults) object;
|
||||
return Objects.equals(value(), that.value());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(value());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String resultField) {
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
ExceptionsHelper.requireNonNull(resultField, "resultField");
|
||||
document.setFieldValue(resultField, value());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public abstract class SingleValueInferenceResults implements InferenceResults {
|
||||
|
||||
public final ParseField VALUE = new ParseField("value");
|
||||
|
||||
private final double value;
|
||||
|
||||
SingleValueInferenceResults(StreamInput in) throws IOException {
|
||||
value = in.readDouble();
|
||||
}
|
||||
|
||||
SingleValueInferenceResults(double value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public Double value() {
|
||||
return value;
|
||||
}
|
||||
|
||||
public String valueAsString() {
|
||||
return String.valueOf(value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(value);
|
||||
}
|
||||
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(VALUE.getPreferredName(), value);
|
||||
innerToXContent(builder, params);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException;
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ClassificationConfig implements InferenceConfig {
|
||||
|
||||
public static final String NAME = "classification";
|
||||
|
||||
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
|
||||
|
||||
public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0);
|
||||
|
||||
private final int numTopClasses;
|
||||
|
||||
public static ClassificationConfig fromMap(Map<String, Object> map) {
|
||||
Map<String, Object> options = new HashMap<>(map);
|
||||
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
|
||||
if (options.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
|
||||
}
|
||||
return new ClassificationConfig(numTopClasses);
|
||||
}
|
||||
|
||||
public ClassificationConfig(Integer numTopClasses) {
|
||||
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
|
||||
}
|
||||
|
||||
public ClassificationConfig(StreamInput in) throws IOException {
|
||||
this.numTopClasses = in.readInt();
|
||||
}
|
||||
|
||||
public int getNumTopClasses() {
|
||||
return numTopClasses;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeInt(numTopClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ClassificationConfig that = (ClassificationConfig) o;
|
||||
return Objects.equals(numTopClasses, that.numTopClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(numTopClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (numTopClasses != 0) {
|
||||
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTargetTypeSupported(TargetType targetType) {
|
||||
return TargetType.CLASSIFICATION.equals(targetType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Version getMinimalSupportedVersion() {
|
||||
return MIN_SUPPORTED_VERSION;
|
||||
}
|
||||
|
||||
}
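// Illustrative sketch of fromMap(..) above (the option value is an assumption): only
// "num_top_classes" is recognised; any other key is rejected with a bad request.
//
//   ClassificationConfig config =
//       ClassificationConfig.fromMap(Collections.singletonMap("num_top_classes", 3));
//   // config.getNumTopClasses() == 3; an empty map yields the default of 0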
@@ -0,0 +1,21 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;


public interface InferenceConfig extends NamedXContentObject, NamedWriteable {

    boolean isTargetTypeSupported(TargetType targetType);

    /**
     * All nodes in the cluster must be at least this version
     */
    Version getMinimalSupportedVersion();
}
@ -0,0 +1,89 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
public final class InferenceHelpers {
|
||||
|
||||
private InferenceHelpers() { }
|
||||
|
||||
public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List<Double> probabilities,
|
||||
List<String> classificationLabels,
|
||||
int numToInclude) {
|
||||
if (numToInclude == 0) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
int[] sortedIndices = IntStream.range(0, probabilities.size())
|
||||
.boxed()
|
||||
.sorted(Comparator.comparing(probabilities::get).reversed())
|
||||
.mapToInt(i -> i)
|
||||
.toArray();
|
||||
|
||||
if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
|
||||
throw ExceptionsHelper
|
||||
.serverError(
|
||||
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
|
||||
null,
|
||||
probabilities.size(),
|
||||
classificationLabels.size());
|
||||
}
|
||||
|
||||
List<String> labels = classificationLabels == null ?
|
||||
// If we do not have the labels we should still return the top classification values; they will just be numeric.
|
||||
IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) :
|
||||
classificationLabels;
|
||||
|
||||
int count = numToInclude < 0 ? probabilities.size() : Math.min(numToInclude, probabilities.size());
|
||||
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
|
||||
for(int i = 0; i < count; i++) {
|
||||
int idx = sortedIndices[i];
|
||||
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx)));
|
||||
}
|
||||
|
||||
return topClassEntries;
|
||||
}
|
||||
|
||||
public static String classificationLabel(double inferenceValue, @Nullable List<String> classificationLabels) {
|
||||
assert inferenceValue == Math.rint(inferenceValue);
|
||||
if (classificationLabels == null) {
|
||||
return String.valueOf(inferenceValue);
|
||||
}
|
||||
int label = Double.valueOf(inferenceValue).intValue();
|
||||
if (label < 0 || label >= classificationLabels.size()) {
|
||||
throw ExceptionsHelper.serverError(
|
||||
"model returned classification value of [{}] which is not a valid index in classification labels [{}]",
|
||||
null,
|
||||
label,
|
||||
classificationLabels);
|
||||
}
|
||||
return classificationLabels.get(label);
|
||||
}
|
||||
|
||||
public static Double toDouble(Object value) {
|
||||
if (value instanceof Number) {
|
||||
return ((Number)value).doubleValue();
|
||||
}
|
||||
if (value instanceof String) {
|
||||
try {
|
||||
return Double.valueOf((String)value);
|
||||
} catch (NumberFormatException nfe) {
|
||||
assert false : "value is not properly formatted double [" + value + "]";
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
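// A worked example for topClasses(..) above (numbers and labels are illustrative): given
// probabilities [0.1, 0.7, 0.2] and labels ["cat", "dog", "fish"], the indices sorted by
// descending probability are [1, 2, 0], so numToInclude = 2 yields
// [{dog, 0.7}, {fish, 0.2}]. A negative numToInclude returns every class, 0 returns an
// empty list, and a null label list falls back to the stringified class indices.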
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Used by ensemble to pass into sub-models.
|
||||
*/
|
||||
public class NullInferenceConfig implements InferenceConfig {
|
||||
|
||||
public static final NullInferenceConfig INSTANCE = new NullInferenceConfig();
|
||||
|
||||
private NullInferenceConfig() { }
|
||||
|
||||
@Override
|
||||
public boolean isTargetTypeSupported(TargetType targetType) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Version getMinimalSupportedVersion() {
|
||||
return Version.CURRENT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return "null";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "null";
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
return builder;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class RegressionConfig implements InferenceConfig {
|
||||
|
||||
public static final String NAME = "regression";
|
||||
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
|
||||
|
||||
public static RegressionConfig fromMap(Map<String, Object> map) {
|
||||
if (map.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
|
||||
}
|
||||
return new RegressionConfig();
|
||||
}
|
||||
|
||||
public RegressionConfig() {
|
||||
}
|
||||
|
||||
public RegressionConfig(StreamInput in) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTargetTypeSupported(TargetType targetType) {
|
||||
return TargetType.REGRESSION.equals(targetType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Version getMinimalSupportedVersion() {
|
||||
return MIN_SUPPORTED_VERSION;
|
||||
}
|
||||
|
||||
}
|
|
@ -5,14 +5,16 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public interface TrainedModel extends NamedXContentObject, NamedWriteable {
|
||||
public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {
|
||||
|
||||
/**
|
||||
* @return List of featureNames expected by the model. In the order that they are expected
|
||||
|
@ -23,16 +25,11 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable {
|
|||
* Infer against the provided fields
|
||||
*
|
||||
* @param fields The fields and their values to infer against
|
||||
* @param config The configuration options for inference
|
||||
* @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0).
|
||||
* For regression this is continuous.
|
||||
*/
|
||||
double infer(Map<String, Object> fields);
|
||||
|
||||
/**
|
||||
* @param fields similar to {@link TrainedModel#infer(Map)}, but fields are already in order and doubles
|
||||
* @return The predicted value.
|
||||
*/
|
||||
double infer(List<Double> fields);
|
||||
InferenceResults infer(Map<String, Object> fields, InferenceConfig config);
|
||||
|
||||
/**
|
||||
* @return {@link TargetType} for the model.
|
||||
|
@ -40,26 +37,7 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable {
|
|||
TargetType targetType();
|
||||
|
||||
/**
|
||||
* This gathers the probabilities for each potential classification value.
|
||||
*
|
||||
* The probabilities are indexed by classification ordinal label encoding.
|
||||
* The length of this list is equal to the number of classification labels.
|
||||
*
|
||||
* This only should return if the implementation model is inferring classification values and not regression
|
||||
* @param fields The fields and their values to infer against
|
||||
* @return The probabilities of each classification value
|
||||
*/
|
||||
List<Double> classificationProbability(Map<String, Object> fields);
|
||||
|
||||
/**
|
||||
* @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles
|
||||
* @return The probabilities of each classification value
|
||||
*/
|
||||
List<Double> classificationProbability(List<Double> fields);
|
||||
|
||||
/**
|
||||
* The ordinal encoded list of the classification labels.
|
||||
* @return Oridinal encoded list of classification labels.
|
||||
* @return Ordinal encoded list of classification labels.
|
||||
*/
|
||||
@Nullable
|
||||
List<String> classificationLabels();
|
||||
|
@ -72,4 +50,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable {
|
|||
* @throws org.elasticsearch.ElasticsearchException if validations fail
|
||||
*/
|
||||
void validate();
|
||||
|
||||
/**
|
||||
* @return The estimated number of operations required at inference time
|
||||
*/
|
||||
long estimatedNumOperations();
|
||||
}
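// A hedged sketch of the new calling convention above (the model variable is an
// assumption): implementations now return a typed InferenceResults instead of a bare
// double, and the InferenceConfig selects how the aggregated value is interpreted.
//
//   InferenceResults results = model.infer(fields, new ClassificationConfig(2));
//   // a classification model yields ClassificationInferenceResults with the top two classes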
|
||||
|
|
|
@ -5,6 +5,9 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
|
||||
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.Accountables;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -12,7 +15,16 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
||||
|
@ -20,14 +32,20 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.OptionalDouble;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
|
||||
|
||||
public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
|
||||
|
||||
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Ensemble.class);
|
||||
// TODO should we have regression/classification sub-classes that accept the builder?
|
||||
public static final ParseField NAME = new ParseField("ensemble");
|
||||
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
|
||||
|
@ -106,14 +124,18 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
|
|||
}
|
||||
|
||||
@Override
|
||||
public double infer(Map<String, Object> fields) {
|
||||
List<Double> processedInferences = inferAndProcess(fields);
|
||||
return outputAggregator.aggregate(processedInferences);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double infer(List<Double> fields) {
|
||||
throw new UnsupportedOperationException("Ensemble requires map containing field names and values");
|
||||
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
|
||||
if (config.isTargetTypeSupported(targetType) == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
|
||||
}
|
||||
List<Double> inferenceResults = this.models.stream().map(model -> {
|
||||
InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE);
|
||||
assert results instanceof SingleValueInferenceResults;
|
||||
return ((SingleValueInferenceResults)results).value();
|
||||
}).collect(Collectors.toList());
|
||||
List<Double> processed = outputAggregator.processValues(inferenceResults);
|
||||
return buildResults(processed, config);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -121,18 +143,27 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
|
|||
return targetType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> classificationProbability(Map<String, Object> fields) {
|
||||
if ((targetType == TargetType.CLASSIFICATION) == false) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
|
||||
private InferenceResults buildResults(List<Double> processedInferences, InferenceConfig config) {
|
||||
// A NullInferenceConfig means the caller only wants the raw aggregated value
|
||||
if (config instanceof NullInferenceConfig) {
|
||||
return new RawInferenceResults(outputAggregator.aggregate(processedInferences));
|
||||
}
|
||||
switch(targetType) {
|
||||
case REGRESSION:
|
||||
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences));
|
||||
case CLASSIFICATION:
|
||||
ClassificationConfig classificationConfig = (ClassificationConfig) config;
|
||||
List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
|
||||
processedInferences,
|
||||
classificationLabels,
|
||||
classificationConfig.getNumTopClasses());
|
||||
double value = outputAggregator.aggregate(processedInferences);
|
||||
return new ClassificationInferenceResults(value,
|
||||
classificationLabel(value, classificationLabels),
|
||||
topClasses);
|
||||
default:
|
||||
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
|
||||
}
|
||||
return inferAndProcess(fields);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> classificationProbability(List<Double> fields) {
|
||||
throw new UnsupportedOperationException("Ensemble requires map containing field names and values");
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -140,11 +171,6 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
|
|||
return classificationLabels;
|
||||
}
|
||||
|
||||
private List<Double> inferAndProcess(Map<String, Object> fields) {
|
||||
List<Double> modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList());
|
||||
return outputAggregator.processValues(modelInferences);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -204,6 +230,13 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
|
|||
|
||||
@Override
|
||||
public void validate() {
|
||||
if (outputAggregator.compatibleWith(targetType) == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"aggregate_output [{}] is not compatible with target_type [{}]",
|
||||
outputAggregator.getName(),
|
||||
this.targetType
|
||||
);
|
||||
}
|
||||
if (outputAggregator.expectedValueSize() != null &&
|
||||
outputAggregator.expectedValueSize() != models.size()) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
|
@ -219,10 +252,38 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
|
|||
this.models.forEach(TrainedModel::validate);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long estimatedNumOperations() {
|
||||
OptionalDouble avg = models.stream().mapToLong(TrainedModel::estimatedNumOperations).average();
|
||||
assert avg.isPresent() : "unexpected null when calculating number of operations";
|
||||
// Average the estimated operations across the sub-models, then add the operations needed to process and aggregate their values with the outputAggregator
|
||||
return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1);
|
||||
}
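// Worked example for the estimate above (counts are illustrative): three sub-models
// estimating 10, 20 and 30 operations average to 20, so the ensemble estimate is
// ceil(20) + 2 * (3 - 1) = 24, the second term covering value processing and
// aggregation in the output aggregator.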
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long size = SHALLOW_SIZE;
|
||||
size += RamUsageEstimator.sizeOfCollection(featureNames);
|
||||
size += RamUsageEstimator.sizeOfCollection(classificationLabels);
|
||||
size += RamUsageEstimator.sizeOfCollection(models);
|
||||
size += outputAggregator.ramBytesUsed();
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<Accountable> getChildResources() {
|
||||
List<Accountable> accountables = new ArrayList<>(models.size() + 1);
|
||||
for (TrainedModel model : models) {
|
||||
accountables.add(Accountables.namedAccountable(model.getName(), model));
|
||||
}
|
||||
accountables.add(Accountables.namedAccountable(outputAggregator.getName(), outputAggregator));
|
||||
return Collections.unmodifiableCollection(accountables);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private List<String> featureNames;
|
||||
private List<TrainedModel> trainedModels;
|
||||
|
|
|
@ -6,16 +6,17 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
|
||||
|
||||
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.IntStream;
|
||||
|
@ -24,6 +25,7 @@ import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.sigmoid
|
|||
|
||||
public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
|
||||
|
||||
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class);
|
||||
public static final ParseField NAME = new ParseField("logistic_regression");
|
||||
public static final ParseField WEIGHTS = new ParseField("weights");
|
||||
|
||||
|
@ -48,19 +50,23 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
|
|||
return LENIENT_PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final List<Double> weights;
|
||||
private final double[] weights;
|
||||
|
||||
LogisticRegression() {
|
||||
this((List<Double>) null);
|
||||
}
|
||||
|
||||
public LogisticRegression(List<Double> weights) {
|
||||
this.weights = weights == null ? null : Collections.unmodifiableList(weights);
|
||||
private LogisticRegression(List<Double> weights) {
|
||||
this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
|
||||
}
|
||||
|
||||
public LogisticRegression(double[] weights) {
|
||||
this.weights = weights;
|
||||
}
|
||||
|
||||
public LogisticRegression(StreamInput in) throws IOException {
|
||||
if (in.readBoolean()) {
|
||||
this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
|
||||
this.weights = in.readDoubleArray();
|
||||
} else {
|
||||
this.weights = null;
|
||||
}
|
||||
|
@ -68,18 +74,18 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
|
|||
|
||||
@Override
|
||||
public Integer expectedValueSize() {
|
||||
return this.weights == null ? null : this.weights.size();
|
||||
return this.weights == null ? null : this.weights.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> processValues(List<Double> values) {
|
||||
Objects.requireNonNull(values, "values must not be null");
|
||||
if (weights != null && values.size() != weights.size()) {
|
||||
if (weights != null && values.size() != weights.length) {
|
||||
throw new IllegalArgumentException("values must be the same length as weights.");
|
||||
}
|
||||
double summation = weights == null ?
|
||||
values.stream().mapToDouble(Double::valueOf).sum() :
|
||||
IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).sum();
|
||||
IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).sum();
|
||||
double probOfClassOne = sigmoid(summation);
|
||||
assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0;
|
||||
return Arrays.asList(1.0 - probOfClassOne, probOfClassOne);
|
||||
|
@ -108,6 +114,11 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean compatibleWith(TargetType targetType) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -117,7 +128,7 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
|
|||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeBoolean(weights != null);
|
||||
if (weights != null) {
|
||||
out.writeCollection(weights, StreamOutput::writeDouble);
|
||||
out.writeDoubleArray(weights);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -136,12 +147,17 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
LogisticRegression that = (LogisticRegression) o;
|
||||
return Objects.equals(weights, that.weights);
|
||||
return Arrays.equals(weights, that.weights);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(weights);
|
||||
return Arrays.hashCode(weights);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
|
||||
return SHALLOW_SIZE + weightSize;
|
||||
}
|
||||
}
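// A worked example for processValues(..) above (weights and values are illustrative):
// with weights [2.0, 1.0] and values [1.0, -3.0] the weighted sum is -1.0, and
// sigmoid(-1.0) ≈ 0.269, so the processed output is [0.731, 0.269], i.e. the probability
// of class 0 followed by the probability of class 1.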
|
||||
|
|
|
@ -5,12 +5,14 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
|
||||
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface OutputAggregator extends NamedXContentObject, NamedWriteable {
|
||||
public interface OutputAggregator extends NamedXContentObject, NamedWriteable, Accountable {
|
||||
|
||||
/**
|
||||
* @return The expected size of the values array when aggregating. `null` implies there is no expected size.
|
||||
|
@ -44,4 +46,6 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable {
|
|||
* @return The name of the output aggregator
|
||||
*/
|
||||
String getName();
|
||||
|
||||
boolean compatibleWith(TargetType targetType);
|
||||
}
|
||||
|
|
|
@ -6,15 +6,18 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
|
||||
|
||||
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
@ -23,6 +26,7 @@ import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax
|
|||
|
||||
public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
|
||||
|
||||
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedMode.class);
|
||||
public static final ParseField NAME = new ParseField("weighted_mode");
|
||||
public static final ParseField WEIGHTS = new ParseField("weights");
|
||||
|
||||
|
@ -47,19 +51,23 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
return LENIENT_PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final List<Double> weights;
|
||||
private final double[] weights;
|
||||
|
||||
WeightedMode() {
|
||||
this.weights = null;
|
||||
this((List<Double>) null);
|
||||
}
|
||||
|
||||
public WeightedMode(List<Double> weights) {
|
||||
this.weights = weights == null ? null : Collections.unmodifiableList(weights);
|
||||
private WeightedMode(List<Double> weights) {
|
||||
this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
|
||||
}
|
||||
|
||||
public WeightedMode(double[] weights) {
|
||||
this.weights = weights;
|
||||
}
|
||||
|
||||
public WeightedMode(StreamInput in) throws IOException {
|
||||
if (in.readBoolean()) {
|
||||
this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
|
||||
this.weights = in.readDoubleArray();
|
||||
} else {
|
||||
this.weights = null;
|
||||
}
|
||||
|
@ -67,13 +75,13 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
|
||||
@Override
|
||||
public Integer expectedValueSize() {
|
||||
return this.weights == null ? null : this.weights.size();
|
||||
return this.weights == null ? null : this.weights.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> processValues(List<Double> values) {
|
||||
Objects.requireNonNull(values, "values must not be null");
|
||||
if (weights != null && values.size() != weights.size()) {
|
||||
if (weights != null && values.size() != weights.length) {
|
||||
throw new IllegalArgumentException("values must be the same length as weights.");
|
||||
}
|
||||
List<Integer> freqArray = new ArrayList<>();
|
||||
|
@ -93,7 +101,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
}
|
||||
List<Double> frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY));
|
||||
for (int i = 0; i < freqArray.size(); i++) {
|
||||
Double weight = weights == null ? 1.0 : weights.get(i);
|
||||
Double weight = weights == null ? 1.0 : weights[i];
|
||||
Integer value = freqArray.get(i);
|
||||
Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
|
||||
frequencies.set(value, frequency);
|
||||
|
@ -123,6 +131,11 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean compatibleWith(TargetType targetType) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -132,7 +145,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeBoolean(weights != null);
|
||||
if (weights != null) {
|
||||
out.writeCollection(weights, StreamOutput::writeDouble);
|
||||
out.writeDoubleArray(weights);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -151,11 +164,17 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
WeightedMode that = (WeightedMode) o;
|
||||
return Objects.equals(weights, that.weights);
|
||||
return Arrays.equals(weights, that.weights);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(weights);
|
||||
return Arrays.hashCode(weights);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
|
||||
return SHALLOW_SIZE + weightSize;
|
||||
}
|
||||
}
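// A worked example for processValues(..) above (values are illustrative): each incoming
// value is treated as a vote for an integer class, weighted by the matching entry in
// weights (or 1.0 when weights is null). With values [1.0, 1.0, 0.0] and weights
// [1.0, 1.0, 3.0], class 1 accumulates 2.0 and class 0 accumulates 3.0, so the weighted
// mode favours class 0.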
|
||||
|
|
|
@@ -6,15 +6,17 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.io.IOException;
import java.util.Collections;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@@ -23,6 +25,7 @@ import java.util.stream.IntStream;

public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {

private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedSum.class);
public static final ParseField NAME = new ParseField("weighted_sum");
public static final ParseField WEIGHTS = new ParseField("weights");

@@ -47,19 +50,23 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
return LENIENT_PARSER.apply(parser, null);
}

private final List<Double> weights;
private final double[] weights;

WeightedSum() {
this.weights = null;
this((List<Double>) null);
}

public WeightedSum(List<Double> weights) {
this.weights = weights == null ? null : Collections.unmodifiableList(weights);
private WeightedSum(List<Double> weights) {
this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
}

public WeightedSum(double[] weights) {
this.weights = weights;
}

public WeightedSum(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
this.weights = in.readDoubleArray();
} else {
this.weights = null;
}
@@ -71,10 +78,10 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
if (weights == null) {
return values;
}
if (values.size() != weights.size()) {
if (values.size() != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
return IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).boxed().collect(Collectors.toList());
return IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).boxed().collect(Collectors.toList());
}

@Override
@@ -104,7 +111,7 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(weights != null);
if (weights != null) {
out.writeCollection(weights, StreamOutput::writeDouble);
out.writeDoubleArray(weights);
}
}

@@ -123,16 +130,27 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedSum that = (WeightedSum) o;
return Objects.equals(weights, that.weights);
return Arrays.equals(weights, that.weights);
}

@Override
public int hashCode() {
return Objects.hash(weights);
return Arrays.hashCode(weights);
}

@Override
public Integer expectedValueSize() {
return weights == null ? null : this.weights.size();
return weights == null ? null : this.weights.length;
}

@Override
public boolean compatibleWith(TargetType targetType) {
return TargetType.REGRESSION.equals(targetType);
}

@Override
public long ramBytesUsed() {
long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
return SHALLOW_SIZE + weightSize;
}
}

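The WeightedSum change above replaces List<Double> with a primitive double[] but keeps the aggregation itself: each ensemble member's output is multiplied by its weight before the results are summed. A self-contained sketch of that element-wise weighting; the class and method names here are illustrative, not part of the change.

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class WeightedSumSketch {
    // Multiplies each value by its weight, mirroring the processValues logic shown in the diff.
    static List<Double> weighted(List<Double> values, double[] weights) {
        if (weights == null) {
            return values;
        }
        if (values.size() != weights.length) {
            throw new IllegalArgumentException("values must be the same length as weights.");
        }
        return IntStream.range(0, weights.length)
            .mapToDouble(i -> values.get(i) * weights[i])
            .boxed()
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Three member outputs weighted 0.5 / 0.3 / 0.2; the aggregated value is their sum.
        List<Double> weightedValues = weighted(Arrays.asList(2.0, 4.0, 6.0), new double[]{0.5, 0.3, 0.2});
        double sum = weightedValues.stream().mapToDouble(Double::doubleValue).sum(); // 1.0 + 1.2 + 1.2 = 3.4
        System.out.println(weightedValues + " -> " + sum);
    }
}
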
@@ -5,6 +5,9 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;

import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@@ -13,7 +16,15 @@ import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -22,6 +33,7 @@ import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
@@ -31,8 +43,11 @@ import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;

public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;

public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel, Accountable {

private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Tree.class);
// TODO should we have regression/classification sub-classes that accept the builder?
public static final ParseField NAME = new ParseField("tree");

@@ -72,7 +87,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM

Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE));
if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) {
throw new IllegalArgumentException("[tree_structure] must not be empty");
}
this.nodes = Collections.unmodifiableList(nodes);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
@@ -105,20 +123,42 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}

@Override
public double infer(Map<String, Object> fields) {
List<Double> features = featureNames.stream().map(f ->
fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null
).collect(Collectors.toList());
return infer(features);
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}

List<Double> features = featureNames.stream().map(f -> InferenceHelpers.toDouble(fields.get(f))).collect(Collectors.toList());
return infer(features, config);
}

@Override
public double infer(List<Double> features) {
private InferenceResults infer(List<Double> features, InferenceConfig config) {
TreeNode node = nodes.get(0);
while(node.isLeaf() == false) {
node = nodes.get(node.compare(features));
}
return node.getLeafValue();
return buildResult(node.getLeafValue(), config);
}

private InferenceResults buildResult(Double value, InferenceConfig config) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(value);
}
switch (targetType) {
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
classificationProbability(value),
classificationLabels,
classificationConfig.getNumTopClasses());
return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses);
case REGRESSION:
return new RegressionInferenceResults(value);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}
}

/**
@@ -142,34 +182,15 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return targetType;
}

@Override
public List<Double> classificationProbability(Map<String, Object> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
List<Double> features = featureNames.stream().map(f ->
fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null)
.collect(Collectors.toList());

return classificationProbability(features);
}

@Override
public List<Double> classificationProbability(List<Double> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
double label = infer(fields);
private List<Double> classificationProbability(double inferenceValue) {
// If we are classification, we should assume that the inference return value is whole.
assert label == Math.rint(label);
assert inferenceValue == Math.rint(inferenceValue);
double maxCategory = this.highestOrderCategory.get();
// If we are classification, we should assume that the largest leaf value is whole.
assert maxCategory == Math.rint(maxCategory);
List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
// TODO, eventually have TreeNodes contain confidence levels
list.set(Double.valueOf(label).intValue(), 1.0);
list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
return list;
}

@@ -239,6 +260,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
detectCycle();
}

@Override
public long estimatedNumOperations() {
// Grabbing the features from the doc + the depth of the tree
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
}

private void checkTargetType() {
if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
throw ExceptionsHelper.badRequestException(
@@ -247,9 +274,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}

private void detectCycle() {
if (nodes.isEmpty()) {
return;
}
Set<Integer> visited = new HashSet<>(nodes.size());
Queue<Integer> toVisit = new ArrayDeque<>(nodes.size());
toVisit.add(0);
@@ -270,10 +294,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}

private void detectMissingNodes() {
if (nodes.isEmpty()) {
return;
}

List<Integer> missingNodes = new ArrayList<>();
for (int i = 0; i < nodes.size(); i++) {
TreeNode currentNode = nodes.get(i);
@@ -302,6 +322,24 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
null;
}

@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOfCollection(classificationLabels);
size += RamUsageEstimator.sizeOfCollection(featureNames);
size += RamUsageEstimator.sizeOfCollection(nodes);
return size;
}

@Override
public Collection<Accountable> getChildResources() {
List<Accountable> accountables = new ArrayList<>(nodes.size());
for (TreeNode node : nodes) {
accountables.add(Accountables.namedAccountable("tree_node_" + node.getNodeIndex(), node));
}
return Collections.unmodifiableCollection(accountables);
}

public static class Builder {
private List<String> featureNames;
private ArrayList<TreeNode.Builder> nodes;

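The reworked Tree.infer(...) resolves feature values by name, walks the node list from the root until it reaches a leaf, and then wraps the leaf value in a typed inference result. Below is a simplified, self-contained traversal over a flat node array; SimpleNode and its fields are illustrative stand-ins for the richer TreeNode, not the actual class.

import java.util.Arrays;
import java.util.List;

public class TreeTraversalSketch {
    // A deliberately simplified node: a split on one feature, or a leaf value.
    static final class SimpleNode {
        final int splitFeature;   // index into the feature vector, -1 for a leaf
        final double threshold;   // go left when feature <= threshold
        final int leftChild;
        final int rightChild;
        final double leafValue;

        SimpleNode(int splitFeature, double threshold, int leftChild, int rightChild, double leafValue) {
            this.splitFeature = splitFeature;
            this.threshold = threshold;
            this.leftChild = leftChild;
            this.rightChild = rightChild;
            this.leafValue = leafValue;
        }

        boolean isLeaf() {
            return splitFeature < 0;
        }
    }

    // Mirrors the loop shown in the diff: start at node 0 and follow children until a leaf is reached.
    static double infer(List<SimpleNode> nodes, double[] features) {
        SimpleNode node = nodes.get(0);
        while (node.isLeaf() == false) {
            node = features[node.splitFeature] <= node.threshold
                ? nodes.get(node.leftChild)
                : nodes.get(node.rightChild);
        }
        return node.leafValue;
    }

    public static void main(String[] args) {
        // Root splits on feature 0 at 0.5; the left leaf returns 10.0, the right leaf returns 20.0.
        List<SimpleNode> nodes = Arrays.asList(
            new SimpleNode(0, 0.5, 1, 2, Double.NaN),
            new SimpleNode(-1, Double.NaN, -1, -1, 10.0),
            new SimpleNode(-1, Double.NaN, -1, -1, 20.0));
        System.out.println(infer(nodes, new double[]{0.3})); // 10.0
        System.out.println(infer(nodes, new double[]{0.9})); // 20.0
    }
}
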
@@ -5,6 +5,9 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;

import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@@ -15,14 +18,15 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public class TreeNode implements ToXContentObject, Writeable {
public class TreeNode implements ToXContentObject, Writeable, Accountable {

private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TreeNode.class);
public static final String NAME = "tree_node";

public static final ParseField DECISION_TYPE = new ParseField("decision_type");
@@ -63,31 +67,31 @@ public class TreeNode implements ToXContentObject, Writeable {
}

private final Operator operator;
private final Double threshold;
private final Integer splitFeature;
private final double threshold;
private final int splitFeature;
private final int nodeIndex;
private final Double splitGain;
private final Double leafValue;
private final double splitGain;
private final double leafValue;
private final boolean defaultLeft;
private final int leftChild;
private final int rightChild;

TreeNode(Operator operator,
Double threshold,
Integer splitFeature,
Integer nodeIndex,
Double splitGain,
Double leafValue,
Boolean defaultLeft,
Integer leftChild,
Integer rightChild) {
private TreeNode(Operator operator,
Double threshold,
Integer splitFeature,
int nodeIndex,
Double splitGain,
Double leafValue,
Boolean defaultLeft,
Integer leftChild,
Integer rightChild) {
this.operator = operator == null ? Operator.LTE : operator;
this.threshold = threshold;
this.splitFeature = splitFeature;
this.nodeIndex = ExceptionsHelper.requireNonNull(nodeIndex, NODE_INDEX.getPreferredName());
this.splitGain = splitGain;
this.leafValue = leafValue;
this.threshold = threshold == null ? Double.NaN : threshold;
this.splitFeature = splitFeature == null ? -1 : splitFeature;
this.nodeIndex = nodeIndex;
this.splitGain = splitGain == null ? Double.NaN : splitGain;
this.leafValue = leafValue == null ? Double.NaN : leafValue;
this.defaultLeft = defaultLeft == null ? false : defaultLeft;
this.leftChild = leftChild == null ? -1 : leftChild;
this.rightChild = rightChild == null ? -1 : rightChild;
@@ -95,11 +99,11 @@ public class TreeNode implements ToXContentObject, Writeable {

public TreeNode(StreamInput in) throws IOException {
operator = Operator.readFromStream(in);
threshold = in.readOptionalDouble();
splitFeature = in.readOptionalInt();
splitGain = in.readOptionalDouble();
nodeIndex = in.readInt();
leafValue = in.readOptionalDouble();
threshold = in.readDouble();
splitFeature = in.readInt();
splitGain = in.readDouble();
nodeIndex = in.readVInt();
leafValue = in.readDouble();
defaultLeft = in.readBoolean();
leftChild = in.readInt();
rightChild = in.readInt();
@@ -110,23 +114,23 @@ public class TreeNode implements ToXContentObject, Writeable {
return operator;
}

public Double getThreshold() {
public double getThreshold() {
return threshold;
}

public Integer getSplitFeature() {
public int getSplitFeature() {
return splitFeature;
}

public Integer getNodeIndex() {
public int getNodeIndex() {
return nodeIndex;
}

public Double getSplitGain() {
public double getSplitGain() {
return splitGain;
}

public Double getLeafValue() {
public double getLeafValue() {
return leafValue;
}

@@ -164,11 +168,11 @@ public class TreeNode implements ToXContentObject, Writeable {
@Override
public void writeTo(StreamOutput out) throws IOException {
operator.writeTo(out);
out.writeOptionalDouble(threshold);
out.writeOptionalInt(splitFeature);
out.writeOptionalDouble(splitGain);
out.writeInt(nodeIndex);
out.writeOptionalDouble(leafValue);
out.writeDouble(threshold);
out.writeInt(splitFeature);
out.writeDouble(splitGain);
out.writeVInt(nodeIndex);
out.writeDouble(leafValue);
out.writeBoolean(defaultLeft);
out.writeInt(leftChild);
out.writeInt(rightChild);
@@ -177,12 +181,14 @@ public class TreeNode implements ToXContentObject, Writeable {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
addOptionalField(builder, DECISION_TYPE, operator);
addOptionalField(builder, THRESHOLD, threshold);
addOptionalField(builder, SPLIT_FEATURE, splitFeature);
addOptionalField(builder, SPLIT_GAIN, splitGain);
builder.field(DECISION_TYPE.getPreferredName(), operator);
addOptionalDouble(builder, THRESHOLD, threshold);
if (splitFeature > -1) {
builder.field(SPLIT_FEATURE.getPreferredName(), splitFeature);
}
addOptionalDouble(builder, SPLIT_GAIN, splitGain);
builder.field(NODE_INDEX.getPreferredName(), nodeIndex);
addOptionalField(builder, LEAF_VALUE, leafValue);
addOptionalDouble(builder, LEAF_VALUE, leafValue);
builder.field(DEFAULT_LEFT.getPreferredName(), defaultLeft);
if (leftChild >= 0) {
builder.field(LEFT_CHILD.getPreferredName(), leftChild);
@@ -194,8 +200,8 @@ public class TreeNode implements ToXContentObject, Writeable {
return builder;
}

private void addOptionalField(XContentBuilder builder, ParseField field, Object value) throws IOException {
if (value != null) {
private void addOptionalDouble(XContentBuilder builder, ParseField field, double value) throws IOException {
if (Numbers.isValidDouble(value)) {
builder.field(field.getPreferredName(), value);
}
}
@@ -237,7 +243,12 @@ public class TreeNode implements ToXContentObject, Writeable {
public static Builder builder(int nodeIndex) {
return new Builder(nodeIndex);
}

@Override
public long ramBytesUsed() {
return SHALLOW_SIZE;
}

public static class Builder {
private Operator operator;
private Double threshold;

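TreeNode above moves from nullable boxed fields to primitives with in-band sentinels, Double.NaN for an unset double and -1 for a missing child or split feature, which trims per-node heap and allows fixed-width serialization. A small sketch of that sentinel convention follows; the class and its fields are illustrative, not the actual TreeNode.

public class SentinelFieldSketch {
    // Absent values are encoded in-band instead of with boxed nulls.
    private final double threshold;   // Double.NaN means "not set"
    private final int leftChild;      // -1 means "no child", i.e. this is a leaf

    SentinelFieldSketch(Double threshold, Integer leftChild) {
        this.threshold = threshold == null ? Double.NaN : threshold;
        this.leftChild = leftChild == null ? -1 : leftChild;
    }

    boolean hasThreshold() {
        // NaN never equals itself, so Double.isNaN is the reliable "unset" check.
        return Double.isNaN(threshold) == false;
    }

    boolean isLeaf() {
        return leftChild < 0;
    }

    public static void main(String[] args) {
        SentinelFieldSketch leaf = new SentinelFieldSketch(null, null);
        SentinelFieldSketch split = new SentinelFieldSketch(0.25, 1);
        System.out.println(leaf.isLeaf() + " " + leaf.hasThreshold());   // true false
        System.out.println(split.isLeaf() + " " + split.hasThreshold()); // false true
    }
}
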
@@ -5,6 +5,8 @@
*/
package org.elasticsearch.xpack.core.ml.inference.utils;

import org.elasticsearch.common.Numbers;

import java.util.List;
import java.util.stream.Collectors;

@@ -22,24 +24,24 @@ public final class Statistics {
*/
public static List<Double> softMax(List<Double> values) {
Double expSum = 0.0;
Double max = values.stream().filter(v -> isInvalid(v) == false).max(Double::compareTo).orElse(null);
Double max = values.stream().filter(Statistics::isValid).max(Double::compareTo).orElse(null);
if (max == null) {
throw new IllegalArgumentException("no valid values present");
}
List<Double> exps = values.stream().map(v -> isInvalid(v) ? Double.NEGATIVE_INFINITY : v - max)
List<Double> exps = values.stream().map(v -> isValid(v) ? v - max : Double.NEGATIVE_INFINITY)
.collect(Collectors.toList());
for (int i = 0; i < exps.size(); i++) {
if (isInvalid(exps.get(i)) == false) {
if (isValid(exps.get(i))) {
Double exp = Math.exp(exps.get(i));
expSum += exp;
exps.set(i, exp);
}
}
for (int i = 0; i < exps.size(); i++) {
if (isInvalid(exps.get(i))) {
exps.set(i, 0.0);
} else {
if (isValid(exps.get(i))) {
exps.set(i, exps.get(i)/expSum);
} else {
exps.set(i, 0.0);
}
}
return exps;
@@ -49,8 +51,8 @@ public final class Statistics {
return 1/(1 + Math.exp(-value));
}

public static boolean isInvalid(Double v) {
return v == null || Double.isInfinite(v) || Double.isNaN(v);
private static boolean isValid(Double v) {
return v != null && Numbers.isValidDouble(v);
}

}

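The softMax rewrite above is the standard numerically stable formulation: subtract the maximum before exponentiating so large scores cannot overflow, then normalize by the sum of exponentials. A compact sketch over a primitive array, as a generic restatement rather than the exact method from the change:

import java.util.Arrays;

public class SoftMaxSketch {
    // Numerically stable softmax: shift by the max, exponentiate, normalize.
    static double[] softMax(double[] values) {
        double max = Arrays.stream(values).max().orElseThrow(IllegalArgumentException::new);
        double[] exps = Arrays.stream(values).map(v -> Math.exp(v - max)).toArray();
        double sum = Arrays.stream(exps).sum();
        return Arrays.stream(exps).map(e -> e / sum).toArray();
    }

    public static void main(String[] args) {
        // The resulting probabilities sum to 1.0 and preserve the ordering of the inputs.
        System.out.println(Arrays.toString(softMax(new double[]{1.0, 2.0, 3.0})));
    }
}
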
@@ -83,7 +83,13 @@ public final class Messages {
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
"Getting model definition is not supported when getting more than one model";

public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created";

@@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.notifications;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage;
import org.elasticsearch.xpack.core.common.notifications.Level;
import org.elasticsearch.xpack.core.ml.job.config.Job;

import java.util.Date;

public class InferenceAuditMessage extends AbstractAuditMessage {

//TODO this should be MODEL_ID...
private static final ParseField JOB_ID = Job.ID;
public static final ConstructingObjectParser<InferenceAuditMessage, Void> PARSER =
createParser("ml_inference_audit_message", InferenceAuditMessage::new, JOB_ID);

public InferenceAuditMessage(String resourceId, String message, Level level, Date timestamp, String nodeName) {
super(resourceId, message, level, timestamp, nodeName);
}

@Override
public final String getJobType() {
return "inference";
}

@Override
protected String getResourceField() {
return JOB_ID.getPreferredName();
}
}

@@ -43,6 +43,10 @@ public class ExceptionsHelper {
return new ResourceAlreadyExistsException("A data frame analytics with id [{}] already exists", id);
}

public static ResourceNotFoundException missingTrainedModel(String modelId) {
return new ResourceNotFoundException("No known trained model with model_id [{}]", modelId);
}

public static ElasticsearchException serverError(String msg) {
return new ElasticsearchException(msg);
}
@@ -51,6 +55,10 @@ public class ExceptionsHelper {
return new ElasticsearchException(msg, cause);
}

public static ElasticsearchException serverError(String msg, Throwable cause, Object... args) {
return new ElasticsearchException(msg, cause, args);
}

public static ElasticsearchStatusException conflictStatusException(String msg, Throwable cause, Object... args) {
return new ElasticsearchStatusException(msg, RestStatus.CONFLICT, cause, args);
}

@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction.Request;
|
||||
|
||||
public class DeleteTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {
|
||||
|
||||
@Override
|
||||
protected Request createTestInstance() {
|
||||
return new Request(randomAlphaOfLengthBetween(1, 20));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Request> instanceReader() {
|
||||
return Request::new;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.action.util.PageParams;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
|
||||
|
||||
public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {
|
||||
|
||||
@Override
|
||||
protected Request createTestInstance() {
|
||||
Request request = new Request(randomAlphaOfLength(20), randomBoolean());
|
||||
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
|
||||
return request;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Request> instanceReader() {
|
||||
return Request::new;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.ingest.IngestStats;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.action.util.QueryPage;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
||||
public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
|
||||
|
||||
@Override
|
||||
protected Response createTestInstance() {
|
||||
int listSize = randomInt(10);
|
||||
List<Response.TrainedModelStats> trainedModelStats = Stream.generate(() -> randomAlphaOfLength(10))
|
||||
.limit(listSize).map(id ->
|
||||
new Response.TrainedModelStats(id,
|
||||
randomBoolean() ? randomIngestStats() : null,
|
||||
randomIntBetween(0, 10))
|
||||
)
|
||||
.collect(Collectors.toList());
|
||||
return new Response(new QueryPage<>(trainedModelStats, randomLongBetween(listSize, 1000), Response.RESULTS_FIELD));
|
||||
}
|
||||
|
||||
private IngestStats randomIngestStats() {
|
||||
List<String> pipelineIds = Stream.generate(()-> randomAlphaOfLength(10))
|
||||
.limit(randomIntBetween(0, 10))
|
||||
.collect(Collectors.toList());
|
||||
return new IngestStats(
|
||||
new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
|
||||
pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()),
|
||||
pipelineIds.stream().collect(Collectors.toMap(Function.identity(), (v) -> randomProcessorStats())));
|
||||
}
|
||||
|
||||
private IngestStats.Stats randomStats(){
|
||||
return new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong());
|
||||
}
|
||||
|
||||
private List<IngestStats.ProcessorStat> randomProcessorStats() {
|
||||
return Stream.generate(() -> randomAlphaOfLength(10))
|
||||
.limit(randomIntBetween(0, 10))
|
||||
.map(name -> new IngestStats.ProcessorStat(name, "inference", randomStats()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Response> instanceReader() {
|
||||
return Response::new;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
||||
public class InferModelActionRequestTests extends AbstractWireSerializingTestCase<Request> {
|
||||
|
||||
@Override
|
||||
protected Request createTestInstance() {
|
||||
return randomBoolean() ?
|
||||
new Request(
|
||||
randomAlphaOfLength(10),
|
||||
Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()),
|
||||
randomInferenceConfig()) :
|
||||
new Request(
|
||||
randomAlphaOfLength(10),
|
||||
randomMap(),
|
||||
randomInferenceConfig());
|
||||
}
|
||||
|
||||
private static InferenceConfig randomInferenceConfig() {
|
||||
return randomFrom(RegressionConfigTests.randomRegressionConfig(), ClassificationConfigTests.randomClassificationConfig());
|
||||
}
|
||||
|
||||
private static Map<String, Object> randomMap() {
|
||||
return Stream.generate(()-> randomAlphaOfLength(10))
|
||||
.limit(randomInt(10))
|
||||
.collect(Collectors.toMap(Function.identity(), (v) -> randomAlphaOfLength(10)));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Request> instanceReader() {
|
||||
return Request::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
|
||||
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(entries);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.action;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class InferModelActionResponseTests extends AbstractWireSerializingTestCase<Response> {
|
||||
|
||||
@Override
|
||||
protected Response createTestInstance() {
|
||||
String resultType = randomFrom(ClassificationInferenceResults.NAME, RegressionInferenceResults.NAME);
|
||||
return new Response(
|
||||
Stream.generate(() -> randomInferenceResult(resultType))
|
||||
.limit(randomIntBetween(0, 10))
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
private static InferenceResults randomInferenceResult(String resultType) {
|
||||
if (resultType.equals(ClassificationInferenceResults.NAME)) {
|
||||
return ClassificationInferenceResultsTests.createRandomResults();
|
||||
} else if (resultType.equals(RegressionInferenceResults.NAME)) {
|
||||
return RegressionInferenceResultsTests.createRandomResults();
|
||||
} else {
|
||||
fail("unexpected result type [" + resultType + "]");
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Response> instanceReader() {
|
||||
return Response::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
|
||||
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(entries);
|
||||
}
|
||||
|
||||
}
|
|
@ -33,6 +33,7 @@ import java.util.function.Predicate;
|
|||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
|
@ -73,7 +74,9 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
|||
null, // is not parsed so should not be provided
|
||||
tags,
|
||||
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
|
||||
TrainedModelInputTests.createRandomInput());
|
||||
TrainedModelInputTests.createRandomInput(),
|
||||
randomNonNegativeLong(),
|
||||
randomNonNegativeLong());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -96,6 +99,16 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
|||
return new NamedWriteableRegistry(entries);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ToXContent.Params getToXContentParams() {
|
||||
return lenient ? ToXContent.EMPTY_PARAMS : new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean assertToXContentEquivalence() {
|
||||
return false;
|
||||
}
|
||||
|
||||
public void testToXContentWithParams() throws IOException {
|
||||
TrainedModelConfig config = new TrainedModelConfig(
|
||||
randomAlphaOfLength(10),
|
||||
|
@ -106,7 +119,9 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
|||
TrainedModelDefinitionTests.createRandomBuilder(randomAlphaOfLength(10)).build(),
|
||||
Collections.emptyList(),
|
||||
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
|
||||
TrainedModelInputTests.createRandomInput());
|
||||
TrainedModelInputTests.createRandomInput(),
|
||||
randomNonNegativeLong(),
|
||||
randomNonNegativeLong());
|
||||
|
||||
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
|
||||
assertThat(reference.utf8ToString(), containsString("definition"));
|
||||
|
|
|
@ -5,12 +5,16 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference;
|
||||
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.DeprecationHandler;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
|
@ -19,9 +23,9 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding
|
|||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
|
@ -31,26 +35,22 @@ import java.util.function.Predicate;
|
|||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
|
||||
|
||||
public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
|
||||
return TrainedModelDefinition.fromXContent(parser, lenient).build();
|
||||
return TrainedModelDefinition.fromXContent(parser, true).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -58,6 +58,16 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
|
|||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ToXContent.Params getToXContentParams() {
|
||||
return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean assertToXContentEquivalence() {
|
||||
return false;
|
||||
}
|
||||
|
||||
public static TrainedModelDefinition.Builder createRandomBuilder(String modelId) {
|
||||
int numberOfProcessors = randomIntBetween(1, 10);
|
||||
return new TrainedModelDefinition.Builder()
|
||||
|
@ -69,7 +79,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
|
|||
TargetMeanEncodingTests.createRandom()))
|
||||
.limit(numberOfProcessors)
|
||||
.collect(Collectors.toList()))
|
||||
.setTrainedModel(randomFrom(TreeTests.createRandom()));
|
||||
.setTrainedModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom()));
|
||||
}
|
||||
|
||||
private static final String ENSEMBLE_MODEL = "" +
|
||||
|
@ -273,9 +283,27 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
|
|||
assertThat(definition.getTrainedModel().getClass(), equalTo(Tree.class));
|
||||
}
|
||||
|
||||
public void testStrictParser() throws IOException {
|
||||
TrainedModelDefinition.Builder builder = createRandomBuilder("asdf");
|
||||
BytesReference reference = XContentHelper.toXContent(builder.build(),
|
||||
XContentType.JSON,
|
||||
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")),
|
||||
false);
|
||||
|
||||
XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
||||
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
|
||||
reference,
|
||||
XContentType.JSON);
|
||||
|
||||
XContentParseException exception = expectThrows(XContentParseException.class,
|
||||
() -> TrainedModelDefinition.fromXContent(parser, false));
|
||||
|
||||
assertThat(exception.getMessage(), containsString("[trained_model_definition] unknown field [doc_type]"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TrainedModelDefinition createTestInstance() {
|
||||
return createRandomBuilder(null).build();
|
||||
return createRandomBuilder(randomAlphaOfLength(10)).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -298,4 +326,9 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
|
|||
return new NamedWriteableRegistry(entries);
|
||||
}
|
||||
|
||||
public void testRamUsageEstimation() {
|
||||
TrainedModelDefinition test = createTestInstance();
|
||||
assertThat(test.ramBytesUsed(), greaterThan(0L));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase<ClassificationInferenceResults> {
|
||||
|
||||
public static ClassificationInferenceResults createRandomResults() {
|
||||
return new ClassificationInferenceResults(randomDouble(),
|
||||
randomBoolean() ? null : randomAlphaOfLength(10),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry)
|
||||
.limit(randomIntBetween(0, 10))
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() {
|
||||
return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble());
|
||||
}
|
||||
|
||||
public void testWriteResultsWithClassificationLabel() {
|
||||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, "foo", Collections.emptyList());
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
|
||||
assertThat(document.getFieldValue("result_field", String.class), equalTo("foo"));
|
||||
}
|
||||
|
||||
public void testWriteResultsWithoutClassificationLabel() {
|
||||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, Collections.emptyList());
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
|
||||
assertThat(document.getFieldValue("result_field", String.class), equalTo("1.0"));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testWriteResultsWithTopClasses() {
|
||||
List<ClassificationInferenceResults.TopClassEntry> entries = Arrays.asList(
|
||||
new ClassificationInferenceResults.TopClassEntry("foo", 0.7),
|
||||
new ClassificationInferenceResults.TopClassEntry("bar", 0.2),
|
||||
new ClassificationInferenceResults.TopClassEntry("baz", 0.1));
|
||||
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
|
||||
"foo",
|
||||
entries);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
|
||||
List<?> list = document.getFieldValue("result_field", List.class);
|
||||
assertThat(list.size(), equalTo(3));
|
||||
|
||||
for(int i = 0; i < 3; i++) {
|
||||
Map<String, Object> map = (Map<String, Object>)list.get(i);
|
||||
assertThat(map, equalTo(entries.get(i).asValueMap()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClassificationInferenceResults createTestInstance() {
|
||||
return createRandomResults();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<ClassificationInferenceResults> instanceReader() {
|
||||
return ClassificationInferenceResults::new;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
|
||||
public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {
|
||||
|
||||
public static RawInferenceResults createRandomResults() {
|
||||
return new RawInferenceResults(randomDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RawInferenceResults createTestInstance() {
|
||||
return createRandomResults();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<RawInferenceResults> instanceReader() {
|
||||
return RawInferenceResults::new;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
|
||||
public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase<RegressionInferenceResults> {
|
||||
|
||||
public static RegressionInferenceResults createRandomResults() {
|
||||
return new RegressionInferenceResults(randomDouble());
|
||||
}
|
||||
|
||||
public void testWriteResults() {
|
||||
RegressionInferenceResults result = new RegressionInferenceResults(0.3);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
|
||||
assertThat(document.getFieldValue("result_field", Double.class), equalTo(0.3));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RegressionInferenceResults createTestInstance() {
|
||||
return createRandomResults();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<RegressionInferenceResults> instanceReader() {
|
||||
return RegressionInferenceResults::new;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class ClassificationConfigTests extends AbstractWireSerializingTestCase<ClassificationConfig> {
|
||||
|
||||
public static ClassificationConfig randomClassificationConfig() {
|
||||
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10));
|
||||
}
|
||||
|
||||
public void testFromMap() {
|
||||
ClassificationConfig expected = new ClassificationConfig(0);
|
||||
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));
|
||||
|
||||
expected = new ClassificationConfig(3);
|
||||
assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)),
|
||||
equalTo(expected));
|
||||
}
|
||||
|
||||
public void testFromMapWithUnknownField() {
|
||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||
() -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1)));
|
||||
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClassificationConfig createTestInstance() {
|
||||
return randomClassificationConfig();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<ClassificationConfig> instanceReader() {
|
||||
return ClassificationConfig::new;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
||||
|
||||
public class InferenceHelpersTests extends ESTestCase {
|
||||
|
||||
public void testToDoubleFromNumbers() {
|
||||
assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5)));
|
||||
assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5)));
|
||||
assertThat(5.0, equalTo(InferenceHelpers.toDouble(5L)));
|
||||
assertThat(5.0, equalTo(InferenceHelpers.toDouble(5)));
|
||||
assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5f)));
|
||||
}
|
||||
|
||||
public void testToDoubleFromString() {
|
||||
assertThat(0.5, equalTo(InferenceHelpers.toDouble("0.5")));
|
||||
assertThat(-0.5, equalTo(InferenceHelpers.toDouble("-0.5")));
|
||||
assertThat(5.0, equalTo(InferenceHelpers.toDouble("5")));
|
||||
assertThat(-5.0, equalTo(InferenceHelpers.toDouble("-5")));
|
||||
|
||||
// if ae are turned off, then we should get a null value
|
||||
// otherwise, we should expect an assertion failure telling us that the string is improperly formatted
|
||||
try {
|
||||
assertThat(InferenceHelpers.toDouble(""), is(nullValue()));
|
||||
} catch (AssertionError ae) {
|
||||
assertThat(ae.getMessage(), equalTo("value is not properly formatted double []"));
|
||||
}
|
||||
try {
|
||||
assertThat(InferenceHelpers.toDouble("notADouble"), is(nullValue()));
|
||||
} catch (AssertionError ae) {
|
||||
assertThat(ae.getMessage(), equalTo("value is not properly formatted double [notADouble]"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testToDoubleFromNull() {
|
||||
assertThat(InferenceHelpers.toDouble(null), is(nullValue()));
|
||||
}
|
||||
|
||||
public void testDoubleFromUnknownObj() {
|
||||
assertThat(InferenceHelpers.toDouble(new HashMap<>()), is(nullValue()));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class RegressionConfigTests extends AbstractWireSerializingTestCase<RegressionConfig> {
|
||||
|
||||
public static RegressionConfig randomRegressionConfig() {
|
||||
return new RegressionConfig();
|
||||
}
|
||||
|
||||
public void testFromMap() {
|
||||
RegressionConfig expected = new RegressionConfig();
|
||||
assertThat(RegressionConfig.fromMap(Collections.emptyMap()), equalTo(expected));
|
||||
}
|
||||
|
||||
public void testFromMapWithUnknownField() {
|
||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||
() -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1)));
|
||||
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RegressionConfig createTestInstance() {
|
||||
return randomRegressionConfig();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<RegressionConfig> instanceReader() {
|
||||
return RegressionConfig::new;
|
||||
}
|
||||
|
||||
}
|
|
@ -15,13 +15,16 @@ import org.elasticsearch.search.SearchModule;
|
|||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@@ -68,9 +71,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
        List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
            .limit(numberOfModels)
            .collect(Collectors.toList());
        List<Double> weights = randomBoolean() ?
        double[] weights = randomBoolean() ?
            null :
            Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
            Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).mapToDouble(Double::valueOf).toArray();
        OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights),
            new WeightedSum(weights),
            new LogisticRegression(weights));
@@ -114,9 +117,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
    public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        int numberOfModels = 5;
        List<Double> weights = new ArrayList<>(numberOfModels + 2);
        double[] weights = new double[numberOfModels + 2];
        for (int i = 0; i < numberOfModels + 2; i++) {
            weights.add(randomDouble());
            weights[i] = randomDouble();
        }
        OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));

@@ -158,6 +161,27 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
        });
    }

    public void testEnsembleWithAggregatorOutputNotSupportingTargetType() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder()
                .setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(
                    Tree.builder()
                        .setNodes(TreeNode.builder(0)
                            .setLeftChild(1)
                            .setSplitFeature(1)
                            .setThreshold(randomDouble()))
                        .setFeatureNames(featureNames)
                        .build()))
                .setClassificationLabels(Arrays.asList("label1", "label2"))
                .setTargetType(TargetType.CLASSIFICATION)
                .setOutputAggregator(new WeightedSum())
                .build()
                .validate();
        });
    }

    public void testEnsembleWithTargetTypeAndLabelsMismatch() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
@@ -189,6 +213,7 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
                        .setFeatureNames(featureNames)
                        .build()))
                .setTargetType(TargetType.CLASSIFICATION)
                .setOutputAggregator(new WeightedMode())
                .build()
                .validate();
        });
@@ -236,32 +261,35 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
            .setTargetType(TargetType.CLASSIFICATION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
            .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
            .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
            .build();

        List<Double> featureVector = Arrays.asList(0.4, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        List<Double> expected = Arrays.asList(0.231475216, 0.768524783);
        List<Double> expected = Arrays.asList(0.768524783, 0.231475216);
        double eps = 0.000001;
        List<Double> probabilities = ensemble.classificationProbability(featureMap);
        List<ClassificationInferenceResults.TopClassEntry> probabilities =
            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
        for(int i = 0; i < expected.size(); i++) {
            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
            assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
        }

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        expected = Arrays.asList(0.3100255188, 0.689974481);
        probabilities = ensemble.classificationProbability(featureMap);
        expected = Arrays.asList(0.689974481, 0.3100255188);
        probabilities =
            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
        for(int i = 0; i < expected.size(); i++) {
            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
            assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
        }

        featureVector = Arrays.asList(0.0, 1.0);
        featureMap = zipObjMap(featureNames, featureVector);
        expected = Arrays.asList(0.231475216, 0.768524783);
        probabilities = ensemble.classificationProbability(featureMap);
        expected = Arrays.asList(0.768524783, 0.231475216);
        probabilities =
            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
        for(int i = 0; i < expected.size(); i++) {
            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
            assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
        }

        // This should handle missing values and take the default_left path
@@ -270,9 +298,10 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
            put("bar", null);
        }};
        expected = Arrays.asList(0.6899744811, 0.3100255188);
        probabilities = ensemble.classificationProbability(featureMap);
        probabilities =
            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
        for(int i = 0; i < expected.size(); i++) {
            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
            assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
        }
    }

@@ -292,7 +321,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
                .setLeftChild(3)
                .setRightChild(4))
            .addNode(TreeNode.builder(3).setLeafValue(0.0))
            .addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
            .addNode(TreeNode.builder(4).setLeafValue(1.0))
            .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
            .build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
@@ -302,6 +333,7 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(0.0))
            .addNode(TreeNode.builder(2).setLeafValue(1.0))
            .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
            .build();
        Tree tree3 = Tree.builder()
            .setFeatureNames(featureNames)
@@ -312,31 +344,36 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
                .setThreshold(1.0))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2).setLeafValue(0.0))
            .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
            .build();
        Ensemble ensemble = Ensemble.builder()
            .setTargetType(TargetType.CLASSIFICATION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
            .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
            .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
            .build();

        List<Double> featureVector = Arrays.asList(0.4, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
        assertThat(1.0,
            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
        assertThat(1.0,
            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));

        featureVector = Arrays.asList(0.0, 1.0);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
        assertThat(1.0,
            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));

        featureMap = new HashMap<String, Object>(2) {{
            put("foo", 0.3);
            put("bar", null);
        }};
        assertEquals(0.0, ensemble.infer(featureMap), 0.00001);
        assertThat(0.0,
            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
    }

    public void testRegressionInference() {
@@ -370,16 +407,18 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
            .setTargetType(TargetType.REGRESSION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2))
            .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5)))
            .setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5}))
            .build();

        List<Double> featureVector = Arrays.asList(0.4, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(0.9, ensemble.infer(featureMap), 0.00001);
        assertThat(0.9,
            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001));

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(0.5, ensemble.infer(featureMap), 0.00001);
        assertThat(0.5,
            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001));

        // Test with NO aggregator supplied, verifies default behavior of non-weighted sum
        ensemble = Ensemble.builder()
@@ -390,17 +429,32 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {

        featureVector = Arrays.asList(0.4, 0.0);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.8, ensemble.infer(featureMap), 0.00001);
        assertThat(1.8,
            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001));

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
        assertThat(1.0,
            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001));

        featureMap = new HashMap<String, Object>(2) {{
            put("foo", 0.3);
            put("bar", null);
        }};
        assertEquals(1.8, ensemble.infer(featureMap), 0.00001);
        assertThat(1.8,
            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001));
    }

    public void testOperationsEstimations() {
        Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2);
        Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
        Tree tree3 = TreeTests.buildRandomTree(Arrays.asList("foo", "baz"), 3);
        Ensemble ensemble = Ensemble.builder().setTrainedModels(Arrays.asList(tree1, tree2, tree3))
            .setTargetType(TargetType.CLASSIFICATION)
            .setFeatureNames(Arrays.asList("foo", "bar", "baz"))
            .setOutputAggregator(new LogisticRegression(new double[]{0.1, 0.4, 1.0}))
            .build();
        assertThat(ensemble.estimatedNumOperations(), equalTo(9L));
    }

    private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
@@ -8,20 +8,21 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;

public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticRegression> {

    @Override
    LogisticRegression createTestInstance(int numberOfWeights) {
        List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
        double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray();
        return new LogisticRegression(weights);
    }

@@ -41,13 +42,13 @@ public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticReg
    }

    public void testAggregate() {
        List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
        double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);

        LogisticRegression logisticRegression = new LogisticRegression(ones);
        assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));

        List<Double> variedWeights = Arrays.asList(.01, -1.0, .1, 0.0, 0.0);
        double[] variedWeights = new double[]{.01, -1.0, .1, 0.0, 0.0};

        logisticRegression = new LogisticRegression(variedWeights);
        assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(0.0));
@@ -56,4 +57,9 @@ public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticReg
        assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));
    }

    public void testCompatibleWith() {
        LogisticRegression logisticRegression = createTestInstance();
        assertThat(logisticRegression.compatibleWith(TargetType.CLASSIFICATION), is(true));
        assertThat(logisticRegression.compatibleWith(TargetType.REGRESSION), is(true));
    }
}
@@ -8,20 +8,21 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;

public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {

    @Override
    WeightedMode createTestInstance(int numberOfWeights) {
        List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
        double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray();
        return new WeightedMode(weights);
    }

@@ -41,13 +42,13 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
    }

    public void testAggregate() {
        List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
        double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);

        WeightedMode weightedMode = new WeightedMode(ones);
        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));

        List<Double> variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0);
        double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};

        weightedMode = new WeightedMode(variedWeights);
        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0));
@@ -55,4 +56,10 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
        weightedMode = new WeightedMode();
        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
    }

    public void testCompatibleWith() {
        WeightedMode weightedMode = createTestInstance();
        assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true));
        assertThat(weightedMode.compatibleWith(TargetType.REGRESSION), is(true));
    }
}
@@ -8,20 +8,21 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;

public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {

    @Override
    WeightedSum createTestInstance(int numberOfWeights) {
        List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
        double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray();
        return new WeightedSum(weights);
    }

@@ -41,13 +42,13 @@ public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {
    }

    public void testAggregate() {
        List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
        double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);

        WeightedSum weightedSum = new WeightedSum(ones);
        assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));

        List<Double> variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0);
        double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};

        weightedSum = new WeightedSum(variedWeights);
        assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(28.0));
@@ -55,4 +56,10 @@ public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {
        weightedSum = new WeightedSum();
        assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));
    }

    public void testCompatibleWith() {
        WeightedSum weightedSum = createTestInstance();
        assertThat(weightedSum.compatibleWith(TargetType.CLASSIFICATION), is(false));
        assertThat(weightedSum.compatibleWith(TargetType.REGRESSION), is(true));
    }
}
@@ -10,6 +10,10 @@ import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.junit.Before;

@@ -104,6 +108,19 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
        return Tree::new;
    }

    public void testInferWithStump() {
        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
        builder.setRoot(TreeNode.builder(0).setLeafValue(42.0));
        builder.setFeatureNames(Collections.emptyList());

        Tree tree = builder.build();
        List<String> featureNames = Arrays.asList("foo", "bar");
        List<Double> featureVector = Arrays.asList(0.6, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump
        assertThat(42.0,
            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));
    }

    public void testInfer() {
        // Build a tree with 2 nodes and 3 leaves using 2 features
        // The leaves have unique values 0.1, 0.2, 0.3
@@ -120,26 +137,36 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
        // This feature vector should hit the right child of the root node
        List<Double> featureVector = Arrays.asList(0.6, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        assertThat(0.3, closeTo(tree.infer(featureMap), 0.00001));
        assertThat(0.3,
            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));

        // This should hit the left child of the left child of the root node
        // i.e. it takes the path left, left
        featureVector = Arrays.asList(0.3, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001));
        assertThat(0.1,
            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));

        // This should hit the right child of the left child of the root node
        // i.e. it takes the path left, right
        featureVector = Arrays.asList(0.3, 0.9);
        featureMap = zipObjMap(featureNames, featureVector);
        assertThat(0.2, closeTo(tree.infer(featureMap), 0.00001));
        assertThat(0.2,
            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));

        // This should still work if the internal values are strings
        List<String> featureVectorStrings = Arrays.asList("0.3", "0.9");
        featureMap = zipObjMap(featureNames, featureVectorStrings);
        assertThat(0.2,
            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));

        // This should handle missing values and take the default_left path
        featureMap = new HashMap<String, Object>(2) {{
            put("foo", 0.3);
            put("bar", null);
        }};
        assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001));
        assertThat(0.1,
            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));
    }

    public void testTreeClassificationProbability() {
@@ -153,31 +180,43 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
        builder.addLeaf(leftChildNode.getRightChild(), 0.0);

        List<String> featureNames = Arrays.asList("foo", "bar");
        Tree tree = builder.setFeatureNames(featureNames).build();
        Tree tree = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build();

        double eps = 0.000001;
        // This feature vector should hit the right child of the root node
        List<Double> featureVector = Arrays.asList(0.6, 0.0);
        List<Double> expectedProbs = Arrays.asList(1.0, 0.0);
        List<String> expectedFields = Arrays.asList("dog", "cat");
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));
        List<ClassificationInferenceResults.TopClassEntry> probabilities =
            ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
        for(int i = 0; i < expectedProbs.size(); i++) {
            assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
            assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
        }

        // This should hit the left child of the left child of the root node
        // i.e. it takes the path left, left
        featureVector = Arrays.asList(0.3, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));

        // This should hit the right child of the left child of the root node
        // i.e. it takes the path left, right
        featureVector = Arrays.asList(0.3, 0.9);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap));
        probabilities =
            ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
        for(int i = 0; i < expectedProbs.size(); i++) {
            assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
            assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
        }

        // This should handle missing values and take the default_left path
        featureMap = new HashMap<String, Object>(2) {{
            put("foo", 0.3);
            put("bar", null);
        }};
        assertEquals(1.0, tree.infer(featureMap), 0.00001);
        probabilities =
            ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
        for(int i = 0; i < expectedProbs.size(); i++) {
            assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
            assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
        }
    }

    public void testTreeWithNullRoot() {
@@ -261,7 +300,12 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
        assertThat(ex.getMessage(), equalTo(msg));
    }

    private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
    public void testOperationsEstimations() {
        Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
        assertThat(tree.estimatedNumOperations(), equalTo(7L));
    }

    private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
        return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
    }
}
@@ -6,19 +6,16 @@
package org.elasticsearch.xpack.core.ml.notifications;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.common.notifications.Level;
import org.elasticsearch.xpack.core.ml.job.config.Job;

import java.util.Date;

import static org.hamcrest.Matchers.equalTo;
public class AnomalyDetectionAuditMessageTests extends AuditMessageTests<AnomalyDetectionAuditMessage> {

public class AnomalyDetectionAuditMessageTests extends AbstractXContentTestCase<AnomalyDetectionAuditMessage> {

    public void testGetJobType() {
        AnomalyDetectionAuditMessage message = createTestInstance();
        assertThat(message.getJobType(), equalTo(Job.ANOMALY_DETECTOR_JOB_TYPE));
    @Override
    public String getJobType() {
        return Job.ANOMALY_DETECTOR_JOB_TYPE;
    }

    @Override
@@ -26,11 +23,6 @@ public class AnomalyDetectionAuditMessageTests extends AbstractXContentTestCase<
        return AnomalyDetectionAuditMessage.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected AnomalyDetectionAuditMessage createTestInstance() {
        return new AnomalyDetectionAuditMessage(
@@ -0,0 +1,27 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.notifications;

import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage;


import static org.hamcrest.Matchers.equalTo;

public abstract class AuditMessageTests<T extends AbstractAuditMessage> extends AbstractXContentTestCase<T> {

    public abstract String getJobType();

    public void testGetJobType() {
        AbstractAuditMessage message = createTestInstance();
        assertThat(message.getJobType(), equalTo(getJobType()));
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }
}
@@ -6,30 +6,22 @@
package org.elasticsearch.xpack.core.ml.notifications;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.common.notifications.Level;

import java.util.Date;

import static org.hamcrest.Matchers.equalTo;
public class DataFrameAnalyticsAuditMessageTests extends AuditMessageTests<DataFrameAnalyticsAuditMessage> {

public class DataFrameAnalyticsAuditMessageTests extends AbstractXContentTestCase<DataFrameAnalyticsAuditMessage> {

    public void testGetJobType() {
        DataFrameAnalyticsAuditMessage message = createTestInstance();
        assertThat(message.getJobType(), equalTo("data_frame_analytics"));
    @Override
    public String getJobType() {
        return "data_frame_analytics";
    }


    @Override
    protected DataFrameAnalyticsAuditMessage doParseInstance(XContentParser parser) {
        return DataFrameAnalyticsAuditMessage.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected DataFrameAnalyticsAuditMessage createTestInstance() {
        return new DataFrameAnalyticsAuditMessage(
@@ -0,0 +1,35 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.notifications;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.common.notifications.Level;

import java.util.Date;

public class InferenceAuditMessageTests extends AuditMessageTests<InferenceAuditMessage> {

    @Override
    public String getJobType() {
        return "inference";
    }

    @Override
    protected InferenceAuditMessage doParseInstance(XContentParser parser) {
        return InferenceAuditMessage.PARSER.apply(parser, null);
    }

    @Override
    protected InferenceAuditMessage createTestInstance() {
        return new InferenceAuditMessage(
            randomBoolean() ? null : randomAlphaOfLength(10),
            randomAlphaOfLengthBetween(1, 20),
            randomFrom(Level.values()),
            new Date(),
            randomBoolean() ? null : randomAlphaOfLengthBetween(1, 20)
        );
    }
}
@@ -126,6 +126,13 @@ integTest.runner {
      'ml/filter_crud/Test get all filter given index exists but no mapping for filter_id',
      'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id',
      'ml/get_datafeeds/Test get datafeed given missing datafeed_id',
      'ml/inference_crud/Test delete given used trained model',
      'ml/inference_crud/Test delete given unused trained model',
      'ml/inference_crud/Test delete with missing model',
      'ml/inference_crud/Test get given missing trained model',
      'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
      'ml/inference_stats_crud/Test get stats given missing trained model',
      'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
      'ml/jobs_crud/Test cannot create job with existing categorizer state document',
      'ml/jobs_crud/Test cannot create job with existing quantiles document',
      'ml/jobs_crud/Test cannot create job with existing result document',
@ -0,0 +1,544 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
||||
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.index.mapper.MapperService;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
|
||||
|
||||
@Before
|
||||
public void createBothModels() {
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId("test_classification")
|
||||
.setSource(CLASSIFICATION_CONFIG, XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId(TrainedModelDefinition.docId("test_classification"))
|
||||
.setSource(CLASSIFICATION_DEFINITION, XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId("test_regression")
|
||||
.setSource(REGRESSION_CONFIG, XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId(TrainedModelDefinition.docId("test_regression"))
|
||||
.setSource(REGRESSION_DEFINITION, XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
}
|
||||
|
||||
public void testPipelineCreationAndDeletion() throws Exception {
|
||||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
|
||||
new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
|
||||
XContentType.JSON).get().isAcknowledged(), is(true));
|
||||
|
||||
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
|
||||
.setSource(new HashMap<String, Object>(){{
|
||||
put("col1", randomFrom("female", "male"));
|
||||
put("col2", randomFrom("S", "M", "L", "XL"));
|
||||
put("col3", randomFrom("true", "false", "none", "other"));
|
||||
put("col4", randomIntBetween(0, 10));
|
||||
}})
|
||||
.setPipeline("simple_classification_pipeline")
|
||||
.get();
|
||||
|
||||
assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
|
||||
is(true));
|
||||
|
||||
assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
|
||||
new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
|
||||
XContentType.JSON).get().isAcknowledged(), is(true));
|
||||
|
||||
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
|
||||
.setSource(new HashMap<String, Object>(){{
|
||||
put("col1", randomFrom("female", "male"));
|
||||
put("col2", randomFrom("S", "M", "L", "XL"));
|
||||
put("col3", randomFrom("true", "false", "none", "other"));
|
||||
put("col4", randomIntBetween(0, 10));
|
||||
}})
|
||||
.setPipeline("simple_regression_pipeline")
|
||||
.get();
|
||||
|
||||
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
|
||||
is(true));
|
||||
}
|
||||
|
||||
assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
|
||||
new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
|
||||
XContentType.JSON).get().isAcknowledged(), is(true));
|
||||
|
||||
assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
|
||||
new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
|
||||
XContentType.JSON).get().isAcknowledged(), is(true));
|
||||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
|
||||
.setSource(generateSourceDoc())
|
||||
.setPipeline("simple_classification_pipeline")
|
||||
.get();
|
||||
|
||||
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
|
||||
.setSource(generateSourceDoc())
|
||||
.setPipeline("simple_regression_pipeline")
|
||||
.get();
|
||||
}
|
||||
|
||||
assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
|
||||
is(true));
|
||||
|
||||
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
|
||||
is(true));
|
||||
|
||||
client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get();
|
||||
|
||||
assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
|
||||
.source(new SearchSourceBuilder()
|
||||
.size(0)
|
||||
.trackTotalHits(true)
|
||||
.query(QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery("regression_value"))))).get().getHits().getTotalHits().value,
|
||||
equalTo(20L));
|
||||
|
||||
assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
|
||||
.source(new SearchSourceBuilder()
|
||||
.size(0)
|
||||
.trackTotalHits(true)
|
||||
.query(QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery("result_class"))))).get().getHits().getTotalHits().value,
|
||||
equalTo(20L));
|
||||
|
||||
}
|
||||
|
||||
public void testSimulate() {
|
||||
String source = "{\n" +
|
||||
" \"pipeline\": {\n" +
|
||||
" \"processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"result_class\",\n" +
|
||||
" \"inference_config\": {\"classification\":{}},\n" +
|
||||
" \"model_id\": \"test_classification\",\n" +
|
||||
" \"field_mappings\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"result_class_prob\",\n" +
|
||||
" \"inference_config\": {\"classification\": {\"num_top_classes\":2}},\n" +
|
||||
" \"model_id\": \"test_classification\",\n" +
|
||||
" \"field_mappings\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"regression_value\",\n" +
|
||||
" \"model_id\": \"test_regression\",\n" +
|
||||
" \"inference_config\": {\"regression\":{}},\n" +
|
||||
" \"field_mappings\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
" },\n" +
|
||||
" \"docs\": [\n" +
|
||||
" {\"_source\": {\n" +
|
||||
" \"col1\": \"female\",\n" +
|
||||
" \"col2\": \"M\",\n" +
|
||||
" \"col3\": \"none\",\n" +
|
||||
" \"col4\": 10\n" +
|
||||
" }}]\n" +
|
||||
"}";
|
||||
|
||||
SimulatePipelineResponse response = client().admin().cluster()
|
||||
.prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
|
||||
XContentType.JSON).get();
|
||||
SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
|
||||
assertThat(baseResult.getIngestDocument().getFieldValue("regression_value", Double.class), equalTo(1.0));
|
||||
assertThat(baseResult.getIngestDocument().getFieldValue("result_class", String.class), equalTo("second"));
|
||||
assertThat(baseResult.getIngestDocument().getFieldValue("result_class_prob", List.class).size(), equalTo(2));
|
||||
|
||||
String sourceWithMissingModel = "{\n" +
|
||||
" \"pipeline\": {\n" +
|
||||
" \"processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"result_class\",\n" +
|
||||
" \"model_id\": \"test_classification_missing\",\n" +
|
||||
" \"inference_config\": {\"classification\":{}},\n" +
|
||||
" \"field_mappings\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
" },\n" +
|
||||
" \"docs\": [\n" +
|
||||
" {\"_source\": {\n" +
|
||||
" \"col1\": \"female\",\n" +
|
||||
" \"col2\": \"M\",\n" +
|
||||
" \"col3\": \"none\",\n" +
|
||||
" \"col4\": 10\n" +
|
||||
" }}]\n" +
|
||||
"}";
|
||||
|
||||
response = client().admin().cluster()
|
||||
.prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)),
|
||||
XContentType.JSON).get();
|
||||
|
||||
assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(),
|
||||
containsString("Could not find trained model [test_classification_missing]"));
|
||||
}
|
||||
|
||||
private Map<String, Object> generateSourceDoc() {
|
||||
return new HashMap<String, Object>(){{
|
||||
put("col1", randomFrom("female", "male"));
|
||||
put("col2", randomFrom("S", "M", "L", "XL"));
|
||||
put("col3", randomFrom("true", "false", "none", "other"));
|
||||
put("col4", randomIntBetween(0, 10));
|
||||
}};
|
||||
}
|
||||
|
||||
private static final String REGRESSION_DEFINITION = "{" +
|
||||
" \"preprocessors\": [\n" +
|
||||
" {\n" +
|
||||
" \"one_hot_encoding\": {\n" +
|
||||
" \"field\": \"col1\",\n" +
|
||||
" \"hot_map\": {\n" +
|
||||
" \"male\": \"col1_male\",\n" +
|
||||
" \"female\": \"col1_female\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"target_mean_encoding\": {\n" +
|
||||
" \"field\": \"col2\",\n" +
|
||||
" \"feature_name\": \"col2_encoded\",\n" +
|
||||
" \"target_map\": {\n" +
|
||||
" \"S\": 5.0,\n" +
|
||||
" \"M\": 10.0,\n" +
|
||||
" \"L\": 20\n" +
|
||||
" },\n" +
|
||||
" \"default_value\": 5.0\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"frequency_encoding\": {\n" +
|
||||
" \"field\": \"col3\",\n" +
|
||||
" \"feature_name\": \"col3_encoded\",\n" +
|
||||
" \"frequency_map\": {\n" +
|
||||
" \"none\": 0.75,\n" +
|
||||
" \"true\": 0.10,\n" +
|
||||
" \"false\": 0.15\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"trained_model\": {\n" +
|
||||
" \"ensemble\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col1_male\",\n" +
|
||||
" \"col1_female\",\n" +
|
||||
" \"col2_encoded\",\n" +
|
||||
" \"col3_encoded\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"aggregate_output\": {\n" +
|
||||
" \"weighted_sum\": {\n" +
|
||||
" \"weights\": [\n" +
|
||||
" 0.5,\n" +
|
||||
" 0.5\n" +
|
||||
" ]\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" \"target_type\": \"regression\",\n" +
|
||||
" \"trained_models\": [\n" +
|
||||
" {\n" +
|
||||
" \"tree\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col1_male\",\n" +
|
||||
" \"col1_female\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"tree_structure\": [\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 0,\n" +
|
||||
" \"split_feature\": 0,\n" +
|
||||
" \"split_gain\": 12.0,\n" +
|
||||
" \"threshold\": 10.0,\n" +
|
||||
" \"decision_type\": \"lte\",\n" +
|
||||
" \"default_left\": true,\n" +
|
||||
" \"left_child\": 1,\n" +
|
||||
" \"right_child\": 2\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 1,\n" +
|
||||
" \"leaf_value\": 1\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 2,\n" +
|
||||
" \"leaf_value\": 2\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"target_type\": \"regression\"\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"tree\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col2_encoded\",\n" +
|
||||
" \"col3_encoded\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"tree_structure\": [\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 0,\n" +
|
||||
" \"split_feature\": 0,\n" +
|
||||
" \"split_gain\": 12.0,\n" +
|
||||
" \"threshold\": 10.0,\n" +
|
||||
" \"decision_type\": \"lte\",\n" +
|
||||
" \"default_left\": true,\n" +
|
||||
" \"left_child\": 1,\n" +
|
||||
" \"right_child\": 2\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 1,\n" +
|
||||
" \"leaf_value\": 1\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 2,\n" +
|
||||
" \"leaf_value\": 2\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"target_type\": \"regression\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" \"model_id\": \"test_regression\"\n" +
|
||||
"}";
|
||||
|
||||
private static final String REGRESSION_CONFIG = "{" +
|
||||
" \"model_id\": \"test_regression\",\n" +
|
||||
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
|
||||
" \"description\": \"test model for regression\",\n" +
|
||||
" \"version\": \"7.6.0\",\n" +
|
||||
" \"created_by\": \"ml_test\",\n" +
|
||||
" \"estimated_heap_memory_usage_bytes\": 0," +
|
||||
" \"estimated_operations\": 0," +
|
||||
" \"created_time\": 0" +
|
||||
"}";
|
||||
|
||||
private static final String CLASSIFICATION_DEFINITION = "{" +
|
||||
" \"preprocessors\": [\n" +
|
||||
" {\n" +
|
||||
" \"one_hot_encoding\": {\n" +
|
||||
" \"field\": \"col1\",\n" +
|
||||
" \"hot_map\": {\n" +
|
||||
" \"male\": \"col1_male\",\n" +
|
||||
" \"female\": \"col1_female\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"target_mean_encoding\": {\n" +
|
||||
" \"field\": \"col2\",\n" +
|
||||
" \"feature_name\": \"col2_encoded\",\n" +
|
||||
" \"target_map\": {\n" +
|
||||
" \"S\": 5.0,\n" +
|
||||
" \"M\": 10.0,\n" +
|
||||
" \"L\": 20\n" +
|
||||
" },\n" +
|
||||
" \"default_value\": 5.0\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"frequency_encoding\": {\n" +
|
||||
" \"field\": \"col3\",\n" +
|
||||
" \"feature_name\": \"col3_encoded\",\n" +
|
||||
" \"frequency_map\": {\n" +
|
||||
" \"none\": 0.75,\n" +
|
||||
" \"true\": 0.10,\n" +
|
||||
" \"false\": 0.15\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"trained_model\": {\n" +
|
||||
" \"ensemble\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col1_male\",\n" +
|
||||
" \"col1_female\",\n" +
|
||||
" \"col2_encoded\",\n" +
|
||||
" \"col3_encoded\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"aggregate_output\": {\n" +
|
||||
" \"weighted_mode\": {\n" +
|
||||
" \"weights\": [\n" +
|
||||
" 0.5,\n" +
|
||||
" 0.5\n" +
|
||||
" ]\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" \"target_type\": \"classification\",\n" +
|
||||
" \"classification_labels\": [\"first\", \"second\"],\n" +
|
||||
" \"trained_models\": [\n" +
|
||||
" {\n" +
|
||||
" \"tree\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col1_male\",\n" +
|
||||
" \"col1_female\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"tree_structure\": [\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 0,\n" +
|
||||
" \"split_feature\": 0,\n" +
|
||||
" \"split_gain\": 12.0,\n" +
|
||||
" \"threshold\": 10.0,\n" +
|
||||
" \"decision_type\": \"lte\",\n" +
|
||||
" \"default_left\": true,\n" +
|
||||
" \"left_child\": 1,\n" +
|
||||
" \"right_child\": 2\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 1,\n" +
|
||||
" \"leaf_value\": 1\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 2,\n" +
|
||||
" \"leaf_value\": 2\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"target_type\": \"regression\"\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"tree\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col2_encoded\",\n" +
|
||||
" \"col3_encoded\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"tree_structure\": [\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 0,\n" +
|
||||
" \"split_feature\": 0,\n" +
|
||||
" \"split_gain\": 12.0,\n" +
|
||||
" \"threshold\": 10.0,\n" +
|
||||
" \"decision_type\": \"lte\",\n" +
|
||||
" \"default_left\": true,\n" +
|
||||
" \"left_child\": 1,\n" +
|
||||
" \"right_child\": 2\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 1,\n" +
|
||||
" \"leaf_value\": 1\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 2,\n" +
|
||||
" \"leaf_value\": 2\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"target_type\": \"regression\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" \"model_id\": \"test_classification\"\n" +
|
||||
"}";
|
||||
|
||||
private static final String CLASSIFICATION_CONFIG = "" +
|
||||
"{\n" +
|
||||
" \"model_id\": \"test_classification\",\n" +
|
||||
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
|
||||
" \"description\": \"test model for classification\",\n" +
|
||||
" \"version\": \"7.6.0\",\n" +
|
||||
" \"created_by\": \"benwtrent\",\n" +
|
||||
" \"estimated_heap_memory_usage_bytes\": 0," +
|
||||
" \"estimated_operations\": 0," +
|
||||
" \"created_time\": 0\n" +
|
||||
"}";
|
||||
|
||||
private static final String CLASSIFICATION_PIPELINE = "{" +
|
||||
" \"processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"result_class\",\n" +
|
||||
" \"model_id\": \"test_classification\",\n" +
|
||||
" \"inference_config\": {\"classification\": {}},\n" +
|
||||
" \"field_mappings\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }]}\n";
|
||||
|
||||
private static final String REGRESSION_PIPELINE = "{" +
|
||||
" \"processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"regression_value\",\n" +
|
||||
" \"model_id\": \"test_regression\",\n" +
|
||||
" \"inference_config\": {\"regression\": {}},\n" +
|
||||
" \"field_mappings\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }]}\n";
|
||||
|
||||
}
|
|
@ -0,0 +1,221 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.apache.http.util.EntityUtils;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.client.Request;
|
||||
import org.elasticsearch.client.Response;
|
||||
import org.elasticsearch.client.ResponseException;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.test.SecuritySettingsSourceField;
|
||||
import org.elasticsearch.test.rest.ESRestTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests;
|
||||
import org.junit.After;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
|
||||
public class TrainedModelIT extends ESRestTestCase {
|
||||
|
||||
private static final String BASIC_AUTH_VALUE = basicAuthHeaderValue("x_pack_rest_user",
|
||||
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
|
||||
|
||||
@Override
|
||||
protected Settings restClientSettings() {
|
||||
return Settings.builder().put(super.restClientSettings()).put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean preserveTemplatesUponCompletion() {
|
||||
return true;
|
||||
}
|
||||
|
||||
public void testGetTrainedModels() throws IOException {
|
||||
String modelId = "test_regression_model";
|
||||
String modelId2 = "test_regression_model-2";
|
||||
Request model1 = new Request("PUT",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
|
||||
model1.setJsonEntity(buildRegressionModel(modelId));
|
||||
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
|
||||
|
||||
Request modelDefinition1 = new Request("PUT",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId));
|
||||
modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId));
|
||||
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
|
||||
|
||||
Request model2 = new Request("PUT",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2);
|
||||
model2.setJsonEntity(buildRegressionModel(modelId2));
|
||||
assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201));
|
||||
|
||||
adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
|
||||
Response getModel = client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH + "inference/" + modelId));
|
||||
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
String response = EntityUtils.toString(getModel.getEntity());
|
||||
|
||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
||||
assertThat(response, containsString("\"count\":1"));
|
||||
|
||||
getModel = client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH + "inference/test_regression*"));
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
|
||||
response = EntityUtils.toString(getModel.getEntity());
|
||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
||||
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
|
||||
assertThat(response, not(containsString("\"definition\"")));
|
||||
assertThat(response, containsString("\"count\":2"));
|
||||
|
||||
getModel = client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true"));
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
|
||||
response = EntityUtils.toString(getModel.getEntity());
|
||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
||||
assertThat(response, containsString("\"heap_memory_estimation_bytes\""));
|
||||
assertThat(response, containsString("\"heap_memory_estimation\""));
|
||||
assertThat(response, containsString("\"definition\""));
|
||||
assertThat(response, containsString("\"count\":1"));
|
||||
|
||||
ResponseException responseException = expectThrows(ResponseException.class, () ->
|
||||
client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true")));
|
||||
assertThat(EntityUtils.toString(responseException.getResponse().getEntity()),
|
||||
containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED));
|
||||
|
||||
getModel = client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2"));
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
|
||||
response = EntityUtils.toString(getModel.getEntity());
|
||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
||||
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
|
||||
assertThat(response, containsString("\"count\":2"));
|
||||
|
||||
getModel = client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH + "inference/classification*?allow_no_match=true"));
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
|
||||
response = EntityUtils.toString(getModel.getEntity());
|
||||
assertThat(response, containsString("\"count\":0"));
|
||||
|
||||
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH + "inference/classification*?allow_no_match=false")));
|
||||
assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(404));
|
||||
|
||||
getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=0&size=1"));
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
|
||||
response = EntityUtils.toString(getModel.getEntity());
|
||||
assertThat(response, containsString("\"count\":2"));
|
||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
||||
assertThat(response, not(containsString("\"model_id\":\"test_regression_model-2\"")));
|
||||
|
||||
getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=1&size=1"));
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
|
||||
response = EntityUtils.toString(getModel.getEntity());
|
||||
assertThat(response, containsString("\"count\":2"));
|
||||
assertThat(response, not(containsString("\"model_id\":\"test_regression_model\"")));
|
||||
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
|
||||
}
|
||||
|
||||
public void testDeleteTrainedModels() throws IOException {
|
||||
String modelId = "test_delete_regression_model";
|
||||
Request model1 = new Request("PUT",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
|
||||
model1.setJsonEntity(buildRegressionModel(modelId));
|
||||
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
|
||||
|
||||
Request modelDefinition1 = new Request("PUT",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId));
|
||||
modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId));
|
||||
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
|
||||
|
||||
adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
|
||||
|
||||
Response delModel = client().performRequest(new Request("DELETE",
|
||||
MachineLearning.BASE_PATH + "inference/" + modelId));
|
||||
String response = EntityUtils.toString(delModel.getEntity());
|
||||
assertThat(response, containsString("\"acknowledged\":true"));
|
||||
|
||||
ResponseException responseException = expectThrows(ResponseException.class,
|
||||
() -> client().performRequest(new Request("DELETE", MachineLearning.BASE_PATH + "inference/" + modelId)));
|
||||
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
|
||||
|
||||
responseException = expectThrows(ResponseException.class,
|
||||
() -> client().performRequest(
|
||||
new Request("GET",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId))));
|
||||
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
|
||||
|
||||
responseException = expectThrows(ResponseException.class,
|
||||
() -> client().performRequest(
|
||||
new Request("GET",
|
||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId)));
|
||||
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
|
||||
}
|
||||
|
||||
private static String buildRegressionModel(String modelId) throws IOException {
|
||||
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||
TrainedModelConfig.builder()
|
||||
.setModelId(modelId)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3")))
|
||||
.setCreatedBy("ml_test")
|
||||
.setVersion(Version.CURRENT)
|
||||
.setCreateTime(Instant.now())
|
||||
.setEstimatedOperations(0)
|
||||
.setEstimatedHeapMemory(0)
|
||||
.build()
|
||||
.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
|
||||
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
|
||||
}
|
||||
}
|
||||
|
||||
private static String buildRegressionModelDefinition(String modelId) throws IOException {
|
||||
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||
new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Collections.emptyList())
|
||||
.setTrainedModel(LocalModelTests.buildRegression())
|
||||
.setModelId(modelId)
|
||||
.build()
|
||||
.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
|
||||
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@After
|
||||
public void clearMlState() throws Exception {
|
||||
new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
|
||||
ESRestTestCase.waitForPendingTasks(adminClient());
|
||||
}
|
||||
}
|
|
@ -43,12 +43,14 @@ import org.elasticsearch.env.NodeEnvironment;
|
|||
import org.elasticsearch.index.IndexSettings;
|
||||
import org.elasticsearch.index.analysis.TokenizerFactory;
|
||||
import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider;
|
||||
import org.elasticsearch.ingest.Processor;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.monitor.os.OsProbe;
|
||||
import org.elasticsearch.monitor.os.OsStats;
|
||||
import org.elasticsearch.persistent.PersistentTasksExecutor;
|
||||
import org.elasticsearch.plugins.ActionPlugin;
|
||||
import org.elasticsearch.plugins.AnalysisPlugin;
|
||||
import org.elasticsearch.plugins.IngestPlugin;
|
||||
import org.elasticsearch.plugins.PersistentTaskPlugin;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.rest.RestController;
|
||||
|
@ -72,6 +74,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
|
||||
|
@ -93,7 +96,10 @@ import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
|
||||
|
@ -139,6 +145,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteFilterAction;
|
|||
import org.elasticsearch.xpack.ml.action.TransportDeleteForecastAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportDeleteJobAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportEstimateMemoryUsageAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportFinalizeJobExecutionAction;
|
||||
|
@ -160,6 +167,9 @@ import org.elasticsearch.xpack.ml.action.TransportGetJobsStatsAction;
|
|||
import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsStatsAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportInferModelAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportKillProcessAction;
|
||||
import org.elasticsearch.xpack.ml.action.TransportMlInfoAction;
|
||||
|
@ -199,6 +209,8 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactor
|
|||
import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
import org.elasticsearch.xpack.ml.job.JobManager;
|
||||
|
@ -221,6 +233,7 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory;
|
|||
import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory;
|
||||
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
|
||||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
||||
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
|
||||
import org.elasticsearch.xpack.ml.process.NativeController;
|
||||
import org.elasticsearch.xpack.ml.process.NativeControllerHolder;
|
||||
|
@ -257,6 +270,9 @@ import org.elasticsearch.xpack.ml.rest.filter.RestDeleteFilterAction;
|
|||
import org.elasticsearch.xpack.ml.rest.filter.RestGetFiltersAction;
|
||||
import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction;
|
||||
import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction;
|
||||
import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction;
|
||||
|
@ -297,7 +313,7 @@ import java.util.function.UnaryOperator;
|
|||
import static java.util.Collections.emptyList;
|
||||
import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME;
|
||||
|
||||
public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, PersistentTaskPlugin {
|
||||
public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin {
|
||||
public static final String NAME = "ml";
|
||||
public static final String BASE_PATH = "/_ml/";
|
||||
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
|
||||
|
@ -321,6 +337,22 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
|
|||
|
||||
};
|
||||
|
||||
@Override
|
||||
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
|
||||
if (this.enabled == false) {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client,
|
||||
parameters.ingestService.getClusterService(),
|
||||
this.settings,
|
||||
parameters.ingestService,
|
||||
getLicenseState());
|
||||
getLicenseState().addListener(inferenceFactory);
|
||||
parameters.ingestService.addIngestClusterStateListener(inferenceFactory);
|
||||
return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<DiscoveryNodeRole> getRoles() {
|
||||
return Collections.singleton(ML_ROLE);
|
||||
|
@ -401,18 +433,21 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
|
|||
|
||||
public List<Setting<?>> getSettings() {
|
||||
return Collections.unmodifiableList(
|
||||
Arrays.asList(MachineLearningField.AUTODETECT_PROCESS,
|
||||
PROCESS_CONNECT_TIMEOUT,
|
||||
ML_ENABLED,
|
||||
CONCURRENT_JOB_ALLOCATIONS,
|
||||
MachineLearningField.MAX_MODEL_MEMORY_LIMIT,
|
||||
MAX_LAZY_ML_NODES,
|
||||
MAX_MACHINE_MEMORY_PERCENT,
|
||||
AutodetectBuilder.DONT_PERSIST_MODEL_STATE_SETTING,
|
||||
AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC,
|
||||
MAX_OPEN_JOBS_PER_NODE,
|
||||
MIN_DISK_SPACE_OFF_HEAP,
|
||||
MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION));
|
||||
Arrays.asList(MachineLearningField.AUTODETECT_PROCESS,
|
||||
PROCESS_CONNECT_TIMEOUT,
|
||||
ML_ENABLED,
|
||||
CONCURRENT_JOB_ALLOCATIONS,
|
||||
MachineLearningField.MAX_MODEL_MEMORY_LIMIT,
|
||||
MAX_LAZY_ML_NODES,
|
||||
MAX_MACHINE_MEMORY_PERCENT,
|
||||
AutodetectBuilder.DONT_PERSIST_MODEL_STATE_SETTING,
|
||||
AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC,
|
||||
MAX_OPEN_JOBS_PER_NODE,
|
||||
MIN_DISK_SPACE_OFF_HEAP,
|
||||
MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION,
|
||||
InferenceProcessor.MAX_INFERENCE_PROCESSORS,
|
||||
ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE,
|
||||
ModelLoadingService.INFERENCE_MODEL_CACHE_TTL));
|
||||
}
|
||||
|
||||
public Settings additionalSettings() {
|
||||
|
@ -483,6 +518,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
|
|||
|
||||
AnomalyDetectionAuditor anomalyDetectionAuditor = new AnomalyDetectionAuditor(client, clusterService.getNodeName());
|
||||
DataFrameAnalyticsAuditor dataFrameAnalyticsAuditor = new DataFrameAnalyticsAuditor(client, clusterService.getNodeName());
|
||||
InferenceAuditor inferenceAuditor = new InferenceAuditor(client, clusterService.getNodeName());
|
||||
this.dataFrameAnalyticsAuditor.set(dataFrameAnalyticsAuditor);
|
||||
JobResultsProvider jobResultsProvider = new JobResultsProvider(client, settings);
|
||||
JobResultsPersister jobResultsPersister = new JobResultsPersister(client);
|
||||
|
@ -569,7 +605,12 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
|
|||
this.datafeedManager.set(datafeedManager);
|
||||
|
||||
// Inference components
|
||||
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
|
||||
final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
|
||||
final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
|
||||
inferenceAuditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
settings);
|
||||
|
||||
// Data frame analytics components
|
||||
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
|
||||
|
@ -614,12 +655,14 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
|
|||
datafeedManager,
|
||||
anomalyDetectionAuditor,
|
||||
dataFrameAnalyticsAuditor,
|
||||
inferenceAuditor,
|
||||
mlAssignmentNotifier,
|
||||
memoryTracker,
|
||||
analyticsProcessManager,
|
||||
memoryEstimationProcessManager,
|
||||
dataFrameAnalyticsConfigProvider,
|
||||
nativeStorageProvider,
|
||||
modelLoadingService,
|
||||
trainedModelProvider
|
||||
);
|
||||
}
|
||||
|
@ -717,7 +760,10 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
|
|||
new RestStartDataFrameAnalyticsAction(restController),
|
||||
new RestStopDataFrameAnalyticsAction(restController),
|
||||
new RestEvaluateDataFrameAction(restController),
|
||||
new RestEstimateMemoryUsageAction(restController)
|
||||
new RestEstimateMemoryUsageAction(restController),
|
||||
new RestGetTrainedModelsAction(restController),
|
||||
new RestDeleteTrainedModelAction(restController),
|
||||
new RestGetTrainedModelsStatsAction(restController)
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -784,7 +830,11 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
|
|||
new ActionHandler<>(StartDataFrameAnalyticsAction.INSTANCE, TransportStartDataFrameAnalyticsAction.class),
|
||||
new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class),
|
||||
new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class),
|
||||
new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class)
|
||||
new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class),
|
||||
new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class),
|
||||
new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class),
|
||||
new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class),
|
||||
new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class)
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,12 @@ import org.apache.lucene.util.Constants;
|
|||
import org.apache.lucene.util.Counter;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
|
||||
import org.elasticsearch.action.search.SearchRequestBuilder;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.metadata.MetaData;
|
||||
|
@ -18,8 +24,10 @@ import org.elasticsearch.cluster.service.ClusterService;
|
|||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.env.Environment;
|
||||
import org.elasticsearch.ingest.IngestStats;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.plugins.Platforms;
|
||||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
import org.elasticsearch.xpack.core.XPackFeatureSet;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
import org.elasticsearch.xpack.core.XPackPlugin;
|
||||
|
@ -31,11 +39,13 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.Job;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.JobState;
|
||||
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;
|
||||
import org.elasticsearch.xpack.core.ml.stats.ForecastStats;
|
||||
import org.elasticsearch.xpack.core.ml.stats.StatsAccumulator;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
|
||||
import org.elasticsearch.xpack.ml.process.NativeController;
|
||||
import org.elasticsearch.xpack.ml.process.NativeControllerHolder;
|
||||
|
@ -44,11 +54,14 @@ import java.io.IOException;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
@ -135,10 +148,10 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
|
|||
@Override
|
||||
public void usage(ActionListener<XPackFeatureSet.Usage> listener) {
|
||||
ClusterState state = clusterService.state();
|
||||
new Retriever(client, jobManagerHolder, available(), enabled(), mlNodeCount(state)).execute(listener);
|
||||
new Retriever(client, jobManagerHolder, available(), enabled(), state).execute(listener);
|
||||
}
|
||||
|
||||
private int mlNodeCount(final ClusterState clusterState) {
|
||||
private static int mlNodeCount(boolean enabled, final ClusterState clusterState) {
|
||||
if (enabled == false) {
|
||||
return 0;
|
||||
}
|
||||
|
@ -161,9 +174,11 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
|
|||
private Map<String, Object> jobsUsage;
|
||||
private Map<String, Object> datafeedsUsage;
|
||||
private Map<String, Object> analyticsUsage;
|
||||
private Map<String, Object> inferenceUsage;
|
||||
private final ClusterState state;
|
||||
private int nodeCount;
|
||||
|
||||
public Retriever(Client client, JobManagerHolder jobManagerHolder, boolean available, boolean enabled, int nodeCount) {
|
||||
public Retriever(Client client, JobManagerHolder jobManagerHolder, boolean available, boolean enabled, ClusterState state) {
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.jobManagerHolder = jobManagerHolder;
|
||||
this.available = available;
|
||||
|
@ -171,61 +186,15 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
|
|||
this.jobsUsage = new LinkedHashMap<>();
|
||||
this.datafeedsUsage = new LinkedHashMap<>();
|
||||
this.analyticsUsage = new LinkedHashMap<>();
|
||||
this.nodeCount = nodeCount;
|
||||
this.inferenceUsage = new LinkedHashMap<>();
|
||||
this.nodeCount = mlNodeCount(enabled, state);
|
||||
this.state = state;
|
||||
}
|
||||
|
||||
public void execute(ActionListener<Usage> listener) {
|
||||
// empty holder means either ML disabled or transport client mode
|
||||
if (jobManagerHolder.isEmpty()) {
|
||||
listener.onResponse(
|
||||
new MachineLearningFeatureSetUsage(available,
|
||||
enabled,
|
||||
Collections.emptyMap(),
|
||||
Collections.emptyMap(),
|
||||
Collections.emptyMap(),
|
||||
0));
|
||||
return;
|
||||
}
|
||||
|
||||
// Step 3. Extract usage from data frame analytics and return usage response
|
||||
ActionListener<GetDataFrameAnalyticsStatsAction.Response> dataframeAnalyticsListener = ActionListener.wrap(
|
||||
response -> {
|
||||
addDataFrameAnalyticsUsage(response, analyticsUsage);
|
||||
listener.onResponse(new MachineLearningFeatureSetUsage(available,
|
||||
enabled,
|
||||
jobsUsage,
|
||||
datafeedsUsage,
|
||||
analyticsUsage,
|
||||
nodeCount));
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
// Step 2. Extract usage from datafeeds stats and return usage response
|
||||
ActionListener<GetDatafeedsStatsAction.Response> datafeedStatsListener =
|
||||
ActionListener.wrap(response -> {
|
||||
addDatafeedsUsage(response);
|
||||
GetDataFrameAnalyticsStatsAction.Request dataframeAnalyticsStatsRequest =
|
||||
new GetDataFrameAnalyticsStatsAction.Request(GetDatafeedsStatsAction.ALL);
|
||||
dataframeAnalyticsStatsRequest.setPageParams(new PageParams(0, 10_000));
|
||||
client.execute(GetDataFrameAnalyticsStatsAction.INSTANCE, dataframeAnalyticsStatsRequest, dataframeAnalyticsListener);
|
||||
},
|
||||
listener::onFailure);
|
||||
|
||||
// Step 1. Extract usage from jobs stats and then request stats for all datafeeds
|
||||
GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request(MetaData.ALL);
|
||||
ActionListener<GetJobsStatsAction.Response> jobStatsListener = ActionListener.wrap(
|
||||
response -> {
|
||||
jobManagerHolder.getJobManager().expandJobs(MetaData.ALL, true, ActionListener.wrap(jobs -> {
|
||||
addJobsUsage(response, jobs.results());
|
||||
GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request(
|
||||
GetDatafeedsStatsAction.ALL);
|
||||
client.execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest, datafeedStatsListener);
|
||||
}, listener::onFailure));
|
||||
}, listener::onFailure);
|
||||
|
||||
// Step 0. Kick off the chain of callbacks by requesting jobs stats
|
||||
client.execute(GetJobsStatsAction.INSTANCE, jobStatsRequest, jobStatsListener);
|
||||
private static void initializeStats(Map<String, Long> emptyStatsMap) {
|
||||
emptyStatsMap.put("sum", 0L);
|
||||
emptyStatsMap.put("min", 0L);
|
||||
emptyStatsMap.put("max", 0L);
|
||||
}
|
||||
|
||||
private void addJobsUsage(GetJobsStatsAction.Response response, List<Job> jobs) {
|
||||
|
@ -322,7 +291,7 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
|
|||
}
|
||||
|
||||
private void addDataFrameAnalyticsUsage(GetDataFrameAnalyticsStatsAction.Response response,
|
||||
Map<String, Object> dataframeAnalyticsUsage) {
|
||||
Map<String, Object> dataframeAnalyticsUsage) {
|
||||
Map<DataFrameAnalyticsState, Counter> dataFrameAnalyticsStateCounterMap = new HashMap<>();
|
||||
|
||||
for(GetDataFrameAnalyticsStatsAction.Response.Stats stats : response.getResponse().results()) {
|
||||
|
@ -334,5 +303,149 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
|
|||
createCountUsageEntry(dataFrameAnalyticsStateCounterMap.get(state).get()));
|
||||
}
|
||||
}
|
||||
|
||||
private static void updateStats(Map<String, Long> statsMap, Long value) {
|
||||
statsMap.compute("sum", (k, v) -> v + value);
|
||||
statsMap.compute("min", (k, v) -> Math.min(v, value));
|
||||
statsMap.compute("max", (k, v) -> Math.max(v, value));
|
||||
}
|
||||
|
||||
private static String[] ingestNodes(final ClusterState clusterState) {
|
||||
String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()];
|
||||
Iterator<String> nodeIterator = clusterState.nodes().getIngestNodes().keysIt();
|
||||
int i = 0;
|
||||
while (nodeIterator.hasNext()) {
|
||||
ingestNodes[i++] = nodeIterator.next();
|
||||
}
|
||||
return ingestNodes;
|
||||
}
|
||||
|
||||
public void execute(ActionListener<Usage> listener) {
|
||||
// empty holder means either ML disabled or transport client mode
|
||||
if (jobManagerHolder.isEmpty()) {
|
||||
listener.onResponse(
|
||||
new MachineLearningFeatureSetUsage(available,
|
||||
enabled,
|
||||
Collections.emptyMap(),
|
||||
Collections.emptyMap(),
|
||||
Collections.emptyMap(),
|
||||
Collections.emptyMap(),
|
||||
0));
|
||||
return;
|
||||
}
|
||||
|
||||
// Step 5. extract trained model config count and then return results
|
||||
ActionListener<SearchResponse> trainedModelConfigCountListener = ActionListener.wrap(
|
||||
response -> {
|
||||
addTrainedModelStats(response, inferenceUsage);
|
||||
MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(available,
|
||||
enabled, jobsUsage, datafeedsUsage, analyticsUsage, inferenceUsage, nodeCount);
|
||||
listener.onResponse(usage);
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
// Step 4. Extract usage from ingest statistics and gather trained model config count
|
||||
ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(
|
||||
response -> {
|
||||
addInferenceIngestUsage(response, inferenceUsage);
|
||||
SearchRequestBuilder requestBuilder = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
||||
.setSize(0)
|
||||
.setTrackTotalHits(true);
|
||||
ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(),
|
||||
ClientHelper.ML_ORIGIN,
|
||||
requestBuilder.request(),
|
||||
trainedModelConfigCountListener,
|
||||
client::search);
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
// Step 3. Extract usage from data frame analytics stats and then request ingest node stats
|
||||
ActionListener<GetDataFrameAnalyticsStatsAction.Response> dataframeAnalyticsListener = ActionListener.wrap(
|
||||
response -> {
|
||||
addDataFrameAnalyticsUsage(response, analyticsUsage);
|
||||
String[] ingestNodes = ingestNodes(state);
|
||||
NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().ingest(true);
|
||||
client.execute(NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener);
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
// Step 2. Extract usage from datafeeds stats and return usage response
|
||||
ActionListener<GetDatafeedsStatsAction.Response> datafeedStatsListener =
|
||||
ActionListener.wrap(response -> {
|
||||
addDatafeedsUsage(response);
|
||||
GetDataFrameAnalyticsStatsAction.Request dataframeAnalyticsStatsRequest =
|
||||
new GetDataFrameAnalyticsStatsAction.Request(GetDatafeedsStatsAction.ALL);
|
||||
dataframeAnalyticsStatsRequest.setPageParams(new PageParams(0, 10_000));
|
||||
client.execute(GetDataFrameAnalyticsStatsAction.INSTANCE, dataframeAnalyticsStatsRequest, dataframeAnalyticsListener);
|
||||
},
|
||||
listener::onFailure);
|
||||
|
||||
// Step 1. Extract usage from jobs stats and then request stats for all datafeeds
|
||||
GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request(MetaData.ALL);
|
||||
ActionListener<GetJobsStatsAction.Response> jobStatsListener = ActionListener.wrap(
|
||||
response -> {
|
||||
jobManagerHolder.getJobManager().expandJobs(MetaData.ALL, true, ActionListener.wrap(jobs -> {
|
||||
addJobsUsage(response, jobs.results());
|
||||
GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request(
|
||||
GetDatafeedsStatsAction.ALL);
|
||||
client.execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest, datafeedStatsListener);
|
||||
}, listener::onFailure));
|
||||
}, listener::onFailure);
|
||||
|
||||
// Step 0. Kick off the chain of callbacks by requesting jobs stats
|
||||
client.execute(GetJobsStatsAction.INSTANCE, jobStatsRequest, jobStatsListener);
|
||||
}
|
||||
|
||||
//TODO separate out ours and users models possibly regression vs classification
|
||||
private void addTrainedModelStats(SearchResponse response, Map<String, Object> inferenceUsage) {
|
||||
inferenceUsage.put("trained_models",
|
||||
Collections.singletonMap(MachineLearningFeatureSetUsage.ALL,
|
||||
createCountUsageEntry(response.getHits().getTotalHits().value)));
|
||||
}
|
||||
|
||||
//TODO separate out ours and users models possibly regression vs classification
|
||||
private void addInferenceIngestUsage(NodesStatsResponse response, Map<String, Object> inferenceUsage) {
|
||||
Set<String> pipelines = new HashSet<>();
|
||||
Map<String, Long> docCountStats = new HashMap<>(3);
|
||||
Map<String, Long> timeStats = new HashMap<>(3);
|
||||
Map<String, Long> failureStats = new HashMap<>(3);
|
||||
initializeStats(docCountStats);
|
||||
initializeStats(timeStats);
|
||||
initializeStats(failureStats);
|
||||
|
||||
response.getNodes()
|
||||
.stream()
|
||||
.map(NodeStats::getIngestStats)
|
||||
.map(IngestStats::getProcessorStats)
|
||||
.forEach(map ->
|
||||
map.forEach((pipelineId, processors) -> {
|
||||
boolean containsInference = false;
|
||||
for (IngestStats.ProcessorStat stats : processors) {
|
||||
if (stats.getName().equals(InferenceProcessor.TYPE)) {
|
||||
containsInference = true;
|
||||
long ingestCount = stats.getStats().getIngestCount();
|
||||
long ingestTime = stats.getStats().getIngestTimeInMillis();
|
||||
long failureCount = stats.getStats().getIngestFailedCount();
|
||||
updateStats(docCountStats, ingestCount);
|
||||
updateStats(timeStats, ingestTime);
|
||||
updateStats(failureStats, failureCount);
|
||||
}
|
||||
}
|
||||
if (containsInference) {
|
||||
pipelines.add(pipelineId);
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
Map<String, Object> ingestUsage = new HashMap<>(6);
|
||||
ingestUsage.put("pipelines", createCountUsageEntry(pipelines.size()));
|
||||
ingestUsage.put("num_docs_processed", docCountStats);
|
||||
ingestUsage.put("time_ms", timeStats);
|
||||
ingestUsage.put("num_failures", failureStats);
|
||||
inferenceUsage.put("ingest_processors", Collections.singletonMap(MachineLearningFeatureSetUsage.ALL, ingestUsage));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.action;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
||||
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.block.ClusterBlockException;
|
||||
import org.elasticsearch.cluster.block.ClusterBlockLevel;
|
||||
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.ingest.IngestMetadata;
|
||||
import org.elasticsearch.ingest.IngestService;
|
||||
import org.elasticsearch.ingest.Pipeline;
|
||||
import org.elasticsearch.ingest.PipelineConfiguration;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
|
||||
/**
|
||||
* The action is a master node action to ensure it reads an up-to-date cluster
|
||||
* state in order to determine if there is a processor referencing the trained model
|
||||
*/
|
||||
public class TransportDeleteTrainedModelAction
|
||||
extends TransportMasterNodeAction<DeleteTrainedModelAction.Request, AcknowledgedResponse> {
|
||||
|
||||
private static final Logger LOGGER = LogManager.getLogger(TransportDeleteTrainedModelAction.class);
|
||||
|
||||
private final TrainedModelProvider trainedModelProvider;
|
||||
private final InferenceAuditor auditor;
|
||||
private final IngestService ingestService;
|
||||
|
||||
@Inject
|
||||
public TransportDeleteTrainedModelAction(TransportService transportService, ClusterService clusterService,
|
||||
ThreadPool threadPool, ActionFilters actionFilters,
|
||||
IndexNameExpressionResolver indexNameExpressionResolver,
|
||||
TrainedModelProvider configProvider, InferenceAuditor auditor,
|
||||
IngestService ingestService) {
|
||||
super(DeleteTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters,
|
||||
DeleteTrainedModelAction.Request::new, indexNameExpressionResolver);
|
||||
this.trainedModelProvider = configProvider;
|
||||
this.ingestService = ingestService;
|
||||
this.auditor = Objects.requireNonNull(auditor);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String executor() {
|
||||
return ThreadPool.Names.SAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AcknowledgedResponse read(StreamInput in) throws IOException {
|
||||
return new AcknowledgedResponse(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void masterOperation(DeleteTrainedModelAction.Request request,
|
||||
ClusterState state,
|
||||
ActionListener<AcknowledgedResponse> listener) {
|
||||
String id = request.getId();
|
||||
IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE);
|
||||
Set<String> referencedModels = getReferencedModelKeys(currentIngestMetadata);
|
||||
|
||||
if (referencedModels.contains(id)) {
|
||||
listener.onFailure(new ElasticsearchStatusException("Cannot delete model [{}] as it is still referenced by ingest processors",
|
||||
RestStatus.CONFLICT,
|
||||
id));
|
||||
return;
|
||||
}
|
||||
|
||||
trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap(
|
||||
r -> {
|
||||
auditor.info(request.getId(), "trained model deleted");
|
||||
listener.onResponse(new AcknowledgedResponse(true));
|
||||
},
|
||||
listener::onFailure
|
||||
));
|
||||
}
|
||||
|
||||
private Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata) {
|
||||
Set<String> allReferencedModelKeys = new HashSet<>();
|
||||
if (ingestMetadata == null) {
|
||||
return allReferencedModelKeys;
|
||||
}
|
||||
for(Map.Entry<String, PipelineConfiguration> entry : ingestMetadata.getPipelines().entrySet()) {
|
||||
String pipelineId = entry.getKey();
|
||||
Map<String, Object> config = entry.getValue().getConfigAsMap();
|
||||
try {
|
||||
Pipeline pipeline = Pipeline.create(pipelineId,
|
||||
config,
|
||||
ingestService.getProcessorFactories(),
|
||||
ingestService.getScriptService());
|
||||
pipeline.getProcessors().stream()
|
||||
.filter(p -> p instanceof InferenceProcessor)
|
||||
.map(p -> (InferenceProcessor) p)
|
||||
.map(InferenceProcessor::getModelId)
|
||||
.forEach(allReferencedModelKeys::add);
|
||||
} catch (Exception ex) {
|
||||
LOGGER.warn(new ParameterizedMessage("failed to load pipeline [{}]", pipelineId), ex);
|
||||
}
|
||||
}
|
||||
return allReferencedModelKeys;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected ClusterBlockException checkBlock(DeleteTrainedModelAction.Request request, ClusterState state) {
|
||||
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.action;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.HandledTransportAction;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Set;
|
||||
|
||||
|
||||
public class TransportGetTrainedModelsAction extends HandledTransportAction<Request, Response> {
|
||||
|
||||
private final TrainedModelProvider provider;
|
||||
@Inject
|
||||
public TransportGetTrainedModelsAction(TransportService transportService,
|
||||
ActionFilters actionFilters,
|
||||
TrainedModelProvider trainedModelProvider) {
|
||||
super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new);
|
||||
this.provider = trainedModelProvider;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
|
||||
|
||||
Response.Builder responseBuilder = Response.builder();
|
||||
|
||||
ActionListener<Tuple<Long, Set<String>>> idExpansionListener = ActionListener.wrap(
|
||||
totalAndIds -> {
|
||||
responseBuilder.setTotalCount(totalAndIds.v1());
|
||||
|
||||
if (totalAndIds.v2().isEmpty()) {
|
||||
listener.onResponse(responseBuilder.build());
|
||||
return;
|
||||
}
|
||||
|
||||
if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) {
|
||||
listener.onFailure(
|
||||
ExceptionsHelper.badRequestException(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (request.isIncludeModelDefinition()) {
|
||||
provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap(
|
||||
config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
|
||||
listener::onFailure
|
||||
));
|
||||
} else {
|
||||
provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap(
|
||||
configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
|
||||
listener::onFailure
|
||||
));
|
||||
}
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,249 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.action;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.HandledTransportAction;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.common.metrics.CounterMetric;
|
||||
import org.elasticsearch.ingest.IngestMetadata;
|
||||
import org.elasticsearch.ingest.IngestService;
|
||||
import org.elasticsearch.ingest.IngestStats;
|
||||
import org.elasticsearch.ingest.Pipeline;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
|
||||
|
||||
|
||||
public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<GetTrainedModelsStatsAction.Request,
|
||||
GetTrainedModelsStatsAction.Response> {
|
||||
|
||||
private final Client client;
|
||||
private final ClusterService clusterService;
|
||||
private final IngestService ingestService;
|
||||
private final TrainedModelProvider trainedModelProvider;
|
||||
|
||||
@Inject
|
||||
public TransportGetTrainedModelsStatsAction(TransportService transportService,
|
||||
ActionFilters actionFilters,
|
||||
ClusterService clusterService,
|
||||
IngestService ingestService,
|
||||
TrainedModelProvider trainedModelProvider,
|
||||
Client client) {
|
||||
super(GetTrainedModelsStatsAction.NAME, transportService, actionFilters, GetTrainedModelsStatsAction.Request::new);
|
||||
this.client = client;
|
||||
this.clusterService = clusterService;
|
||||
this.ingestService = ingestService;
|
||||
this.trainedModelProvider = trainedModelProvider;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(Task task,
|
||||
GetTrainedModelsStatsAction.Request request,
|
||||
ActionListener<GetTrainedModelsStatsAction.Response> listener) {
|
||||
|
||||
GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();
|
||||
|
||||
ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(
|
||||
nodesStatsResponse -> {
|
||||
Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByPipelineId(nodesStatsResponse,
|
||||
pipelineIdsByModelIds(clusterService.state(),
|
||||
ingestService,
|
||||
responseBuilder.getExpandedIds()));
|
||||
listener.onResponse(responseBuilder.setIngestStatsByModelId(modelIdIngestStats).build());
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
ActionListener<Tuple<Long, Set<String>>> idsListener = ActionListener.wrap(
|
||||
tuple -> {
|
||||
responseBuilder.setExpandedIds(tuple.v2())
|
||||
.setTotalModelCount(tuple.v1());
|
||||
String[] ingestNodes = ingestNodes(clusterService.state());
|
||||
NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().ingest(true);
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener);
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener);
|
||||
}
|
||||
|
||||
static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse response,
|
||||
Map<String, Set<String>> modelIdToPipelineId) {
|
||||
|
||||
Map<String, IngestStats> ingestStatsMap = new HashMap<>();
|
||||
|
||||
modelIdToPipelineId.forEach((modelId, pipelineIds) -> {
|
||||
List<IngestStats> collectedStats = response.getNodes()
|
||||
.stream()
|
||||
.map(nodeStats -> ingestStatsForPipelineIds(nodeStats, pipelineIds))
|
||||
.collect(Collectors.toList());
|
||||
ingestStatsMap.put(modelId, mergeStats(collectedStats));
|
||||
});
|
||||
|
||||
return ingestStatsMap;
|
||||
}
|
||||
|
||||
static String[] ingestNodes(final ClusterState clusterState) {
|
||||
String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()];
|
||||
Iterator<String> nodeIterator = clusterState.nodes().getIngestNodes().keysIt();
|
||||
int i = 0;
|
||||
while(nodeIterator.hasNext()) {
|
||||
ingestNodes[i++] = nodeIterator.next();
|
||||
}
|
||||
return ingestNodes;
|
||||
}
|
||||
|
||||
static Map<String, Set<String>> pipelineIdsByModelIds(ClusterState state, IngestService ingestService, Set<String> modelIds) {
|
||||
IngestMetadata ingestMetadata = state.metaData().custom(IngestMetadata.TYPE);
|
||||
Map<String, Set<String>> pipelineIdsByModelIds = new HashMap<>();
|
||||
if (ingestMetadata == null) {
|
||||
return pipelineIdsByModelIds;
|
||||
}
|
||||
|
||||
ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
|
||||
try {
|
||||
Pipeline pipeline = Pipeline.create(pipelineId,
|
||||
pipelineConfiguration.getConfigAsMap(),
|
||||
ingestService.getProcessorFactories(),
|
||||
ingestService.getScriptService());
|
||||
pipeline.getProcessors().forEach(processor -> {
|
||||
if (processor instanceof InferenceProcessor) {
|
||||
InferenceProcessor inferenceProcessor = (InferenceProcessor) processor;
|
||||
if (modelIds.contains(inferenceProcessor.getModelId())) {
|
||||
pipelineIdsByModelIds.computeIfAbsent(inferenceProcessor.getModelId(),
|
||||
m -> new LinkedHashSet<>()).add(pipelineId);
|
||||
}
|
||||
}
|
||||
});
|
||||
} catch (Exception ex) {
|
||||
throw new ElasticsearchException("unexpected failure gathering pipeline information", ex);
|
||||
}
|
||||
});
|
||||
|
||||
return pipelineIdsByModelIds;
|
||||
}
|
||||
|
||||
static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set<String> pipelineIds) {
|
||||
IngestStats fullNodeStats = nodeStats.getIngestStats();
|
||||
Map<String, List<IngestStats.ProcessorStat>> filteredProcessorStats = new HashMap<>(fullNodeStats.getProcessorStats());
|
||||
filteredProcessorStats.keySet().retainAll(pipelineIds);
|
||||
List<IngestStats.PipelineStat> filteredPipelineStats = fullNodeStats.getPipelineStats()
|
||||
.stream()
|
||||
.filter(pipelineStat -> pipelineIds.contains(pipelineStat.getPipelineId()))
|
||||
.collect(Collectors.toList());
|
||||
CounterMetric ingestCount = new CounterMetric();
|
||||
CounterMetric ingestTimeInMillis = new CounterMetric();
|
||||
CounterMetric ingestCurrent = new CounterMetric();
|
||||
CounterMetric ingestFailedCount = new CounterMetric();
|
||||
|
||||
filteredPipelineStats.forEach(pipelineStat -> {
|
||||
IngestStats.Stats stats = pipelineStat.getStats();
|
||||
ingestCount.inc(stats.getIngestCount());
|
||||
ingestTimeInMillis.inc(stats.getIngestTimeInMillis());
|
||||
ingestCurrent.inc(stats.getIngestCurrent());
|
||||
ingestFailedCount.inc(stats.getIngestFailedCount());
|
||||
});
|
||||
|
||||
return new IngestStats(
|
||||
new IngestStats.Stats(ingestCount.count(), ingestTimeInMillis.count(), ingestCurrent.count(), ingestFailedCount.count()),
|
||||
filteredPipelineStats,
|
||||
filteredProcessorStats);
|
||||
}
|
||||
|
||||
private static IngestStats mergeStats(List<IngestStats> ingestStatsList) {
|
||||
|
||||
Map<String, IngestStatsAccumulator> pipelineStatsAcc = new LinkedHashMap<>(ingestStatsList.size());
|
||||
Map<String, Map<String, IngestStatsAccumulator>> processorStatsAcc = new LinkedHashMap<>(ingestStatsList.size());
|
||||
IngestStatsAccumulator totalStats = new IngestStatsAccumulator();
|
||||
ingestStatsList.forEach(ingestStats -> {
|
||||
|
||||
ingestStats.getPipelineStats()
|
||||
.forEach(pipelineStat ->
|
||||
pipelineStatsAcc.computeIfAbsent(pipelineStat.getPipelineId(),
|
||||
p -> new IngestStatsAccumulator()).inc(pipelineStat.getStats()));
|
||||
|
||||
ingestStats.getProcessorStats()
|
||||
.forEach((pipelineId, processorStat) -> {
|
||||
Map<String, IngestStatsAccumulator> processorAcc = processorStatsAcc.computeIfAbsent(pipelineId,
|
||||
k -> new LinkedHashMap<>());
|
||||
processorStat.forEach(p ->
|
||||
processorAcc.computeIfAbsent(p.getName(),
|
||||
k -> new IngestStatsAccumulator(p.getType())).inc(p.getStats()));
|
||||
});
|
||||
|
||||
totalStats.inc(ingestStats.getTotalStats());
|
||||
});
|
||||
|
||||
List<IngestStats.PipelineStat> pipelineStatList = new ArrayList<>(pipelineStatsAcc.size());
|
||||
pipelineStatsAcc.forEach((pipelineId, accumulator) ->
|
||||
pipelineStatList.add(new IngestStats.PipelineStat(pipelineId, accumulator.build())));
|
||||
|
||||
Map<String, List<IngestStats.ProcessorStat>> processorStatList = new LinkedHashMap<>(processorStatsAcc.size());
|
||||
processorStatsAcc.forEach((pipelineId, accumulatorMap) -> {
|
||||
List<IngestStats.ProcessorStat> processorStats = new ArrayList<>(accumulatorMap.size());
|
||||
accumulatorMap.forEach((processorName, acc) ->
|
||||
processorStats.add(new IngestStats.ProcessorStat(processorName, acc.type, acc.build())));
|
||||
processorStatList.put(pipelineId, processorStats);
|
||||
});
|
||||
|
||||
return new IngestStats(totalStats.build(), pipelineStatList, processorStatList);
|
||||
}
|
||||
|
||||
private static class IngestStatsAccumulator {
|
||||
CounterMetric ingestCount = new CounterMetric();
|
||||
CounterMetric ingestTimeInMillis = new CounterMetric();
|
||||
CounterMetric ingestCurrent = new CounterMetric();
|
||||
CounterMetric ingestFailedCount = new CounterMetric();
|
||||
|
||||
String type;
|
||||
|
||||
IngestStatsAccumulator() {}
|
||||
|
||||
IngestStatsAccumulator(String type) {
|
||||
this.type = type;
|
||||
}
|
||||
|
||||
IngestStatsAccumulator inc(IngestStats.Stats s) {
|
||||
ingestCount.inc(s.getIngestCount());
|
||||
ingestTimeInMillis.inc(s.getIngestTimeInMillis());
|
||||
ingestCurrent.inc(s.getIngestCurrent());
|
||||
ingestFailedCount.inc(s.getIngestFailedCount());
|
||||
return this;
|
||||
}
|
||||
|
||||
IngestStats.Stats build() {
|
||||
return new IngestStats.Stats(ingestCount.count(), ingestTimeInMillis.count(), ingestCurrent.count(), ingestFailedCount.count());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.action;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.HandledTransportAction;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.license.LicenseUtils;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
|
||||
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
|
||||
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
|
||||
|
||||
|
||||
public class TransportInferModelAction extends HandledTransportAction<InferModelAction.Request, InferModelAction.Response> {
|
||||
|
||||
private final ModelLoadingService modelLoadingService;
|
||||
private final Client client;
|
||||
private final XPackLicenseState licenseState;
|
||||
|
||||
@Inject
|
||||
public TransportInferModelAction(TransportService transportService,
|
||||
ActionFilters actionFilters,
|
||||
ModelLoadingService modelLoadingService,
|
||||
Client client,
|
||||
XPackLicenseState licenseState) {
|
||||
super(InferModelAction.NAME, transportService, actionFilters, InferModelAction.Request::new);
|
||||
this.modelLoadingService = modelLoadingService;
|
||||
this.client = client;
|
||||
this.licenseState = licenseState;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(Task task, InferModelAction.Request request, ActionListener<InferModelAction.Response> listener) {
|
||||
|
||||
if (licenseState.isMachineLearningAllowed() == false) {
|
||||
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
|
||||
return;
|
||||
}
|
||||
|
||||
ActionListener<Model> getModelListener = ActionListener.wrap(
|
||||
model -> {
|
||||
TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor =
|
||||
new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME),
|
||||
// run through all tasks
|
||||
r -> true,
|
||||
// Always fail immediately and return an error
|
||||
ex -> true);
|
||||
request.getObjectsToInfer().forEach(stringObjectMap ->
|
||||
typedChainTaskExecutor.add(chainedTask ->
|
||||
model.infer(stringObjectMap, request.getConfig(), chainedTask)));
|
||||
|
||||
typedChainTaskExecutor.execute(ActionListener.wrap(
|
||||
inferenceResultsInterfaces ->
|
||||
listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces)),
|
||||
listener::onFailure
|
||||
));
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
|
||||
}
|
||||
}
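For orientation, a minimal sketch (not part of this commit) of how a caller could invoke this action through a Client. The request constructor and response accessor mirror the ones used by the ingest processor later in this change; the model id and feature values are placeholders.

// Minimal sketch, assuming a Client and the usual imports are available.
InferModelAction.Request inferRequest = new InferModelAction.Request(
    "example-model",                                  // hypothetical model id
    Collections.singletonMap("feature_1", 42.0),      // document fields to infer against
    new RegressionConfig());                          // or a ClassificationConfig
client.execute(InferModelAction.INSTANCE, inferRequest, ActionListener.wrap(
    response -> response.getInferenceResults().forEach(System.out::println),
    e -> { /* handle the failure */ }));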
|
|
@ -150,6 +150,8 @@ public class AnalyticsResultProcessor {
|
|||
.setMetadata(Collections.singletonMap("analytics_config",
|
||||
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
|
||||
.setDefinition(definition)
|
||||
.setEstimatedHeapMemory(definition.ramBytesUsed())
|
||||
.setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
|
||||
.setInput(new TrainedModelInput(fieldNames))
|
||||
.build();
|
||||
}
|
||||
|
|
|
@ -7,14 +7,17 @@ package org.elasticsearch.xpack.ml.dataframe.process.results;
|
|||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
|
||||
|
||||
public class AnalyticsResult implements ToXContentObject {
|
||||
|
||||
|
@ -67,7 +70,9 @@ public class AnalyticsResult implements ToXContentObject {
|
|||
builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent);
|
||||
}
|
||||
if (inferenceModel != null) {
|
||||
builder.field(INFERENCE_MODEL.getPreferredName(), inferenceModel);
|
||||
builder.field(INFERENCE_MODEL.getPreferredName(),
|
||||
inferenceModel,
|
||||
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")));
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
|
|
@ -0,0 +1,297 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.inference.ingest;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.metadata.MetaData;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.settings.Setting;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.ingest.AbstractProcessor;
|
||||
import org.elasticsearch.ingest.ConfigurationUtils;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.ingest.IngestMetadata;
|
||||
import org.elasticsearch.ingest.IngestService;
|
||||
import org.elasticsearch.ingest.Pipeline;
|
||||
import org.elasticsearch.ingest.PipelineConfiguration;
|
||||
import org.elasticsearch.ingest.Processor;
|
||||
import org.elasticsearch.license.LicenseStateListener;
|
||||
import org.elasticsearch.license.LicenseUtils;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
|
||||
|
||||
public class InferenceProcessor extends AbstractProcessor {
|
||||
|
||||
// How many total inference processors are allowed to be used in the cluster.
|
||||
public static final Setting<Integer> MAX_INFERENCE_PROCESSORS = Setting.intSetting("xpack.ml.max_inference_processors",
|
||||
50,
|
||||
1,
|
||||
Setting.Property.Dynamic,
|
||||
Setting.Property.NodeScope);
|
||||
|
||||
public static final String TYPE = "inference";
|
||||
public static final String MODEL_ID = "model_id";
|
||||
public static final String INFERENCE_CONFIG = "inference_config";
|
||||
public static final String TARGET_FIELD = "target_field";
|
||||
public static final String FIELD_MAPPINGS = "field_mappings";
|
||||
public static final String MODEL_INFO_FIELD = "model_info_field";
|
||||
public static final String INCLUDE_MODEL_METADATA = "include_model_metadata";
|
||||
|
||||
private final Client client;
|
||||
private final String modelId;
|
||||
|
||||
private final String targetField;
|
||||
private final String modelInfoField;
|
||||
private final InferenceConfig inferenceConfig;
|
||||
private final Map<String, String> fieldMapping;
|
||||
private final boolean includeModelMetadata;
|
||||
|
||||
public InferenceProcessor(Client client,
|
||||
String tag,
|
||||
String targetField,
|
||||
String modelId,
|
||||
InferenceConfig inferenceConfig,
|
||||
Map<String, String> fieldMapping,
|
||||
String modelInfoField,
|
||||
boolean includeModelMetadata) {
|
||||
super(tag);
|
||||
this.client = ExceptionsHelper.requireNonNull(client, "client");
|
||||
this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
|
||||
this.modelInfoField = ExceptionsHelper.requireNonNull(modelInfoField, MODEL_INFO_FIELD);
|
||||
this.includeModelMetadata = includeModelMetadata;
|
||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
||||
this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
|
||||
this.fieldMapping = ExceptionsHelper.requireNonNull(fieldMapping, FIELD_MAPPINGS);
|
||||
}
|
||||
|
||||
public String getModelId() {
|
||||
return modelId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
|
||||
executeAsyncWithOrigin(client,
|
||||
ML_ORIGIN,
|
||||
InferModelAction.INSTANCE,
|
||||
this.buildRequest(ingestDocument),
|
||||
ActionListener.wrap(
|
||||
r -> {
|
||||
try {
|
||||
mutateDocument(r, ingestDocument);
|
||||
handler.accept(ingestDocument, null);
|
||||
} catch(ElasticsearchException ex) {
|
||||
handler.accept(ingestDocument, ex);
|
||||
}
|
||||
},
|
||||
e -> handler.accept(ingestDocument, e)
|
||||
));
|
||||
}
|
||||
|
||||
InferModelAction.Request buildRequest(IngestDocument ingestDocument) {
|
||||
Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
|
||||
if (fieldMapping != null) {
|
||||
fieldMapping.forEach((src, dest) -> {
|
||||
Object srcValue = fields.remove(src);
|
||||
if (srcValue != null) {
|
||||
fields.put(dest, srcValue);
|
||||
}
|
||||
});
|
||||
}
|
||||
return new InferModelAction.Request(modelId, fields, inferenceConfig);
|
||||
}
|
||||
|
||||
void mutateDocument(InferModelAction.Response response, IngestDocument ingestDocument) {
|
||||
if (response.getInferenceResults().isEmpty()) {
|
||||
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
|
||||
if (includeModelMetadata) {
|
||||
ingestDocument.setFieldValue(modelInfoField + "." + MODEL_ID, modelId);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public IngestDocument execute(IngestDocument ingestDocument) {
|
||||
throw new UnsupportedOperationException("should never be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getType() {
|
||||
return TYPE;
|
||||
}
|
||||
|
||||
public static final class Factory implements Processor.Factory, Consumer<ClusterState>, LicenseStateListener {
|
||||
|
||||
private static final Logger logger = LogManager.getLogger(Factory.class);
|
||||
|
||||
private final Client client;
|
||||
private final IngestService ingestService;
|
||||
private final XPackLicenseState licenseState;
|
||||
private volatile int currentInferenceProcessors;
|
||||
private volatile int maxIngestProcessors;
|
||||
private volatile Version minNodeVersion = Version.CURRENT;
|
||||
private volatile boolean inferenceAllowed;
|
||||
|
||||
public Factory(Client client,
|
||||
ClusterService clusterService,
|
||||
Settings settings,
|
||||
IngestService ingestService,
|
||||
XPackLicenseState licenseState) {
|
||||
this.client = client;
|
||||
this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings);
|
||||
this.ingestService = ingestService;
|
||||
this.licenseState = licenseState;
|
||||
this.inferenceAllowed = licenseState.isMachineLearningAllowed();
|
||||
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void accept(ClusterState state) {
|
||||
minNodeVersion = state.nodes().getMinNodeVersion();
|
||||
MetaData metaData = state.getMetaData();
|
||||
if (metaData == null) {
|
||||
currentInferenceProcessors = 0;
|
||||
return;
|
||||
}
|
||||
IngestMetadata ingestMetadata = metaData.custom(IngestMetadata.TYPE);
|
||||
if (ingestMetadata == null) {
|
||||
currentInferenceProcessors = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
int count = 0;
|
||||
for (PipelineConfiguration configuration : ingestMetadata.getPipelines().values()) {
|
||||
try {
|
||||
Pipeline pipeline = Pipeline.create(configuration.getId(),
|
||||
configuration.getConfigAsMap(),
|
||||
ingestService.getProcessorFactories(),
|
||||
ingestService.getScriptService());
|
||||
count += pipeline.getProcessors().stream().filter(processor -> processor instanceof InferenceProcessor).count();
|
||||
} catch (Exception ex) {
|
||||
logger.warn(new ParameterizedMessage("failure parsing pipeline config [{}]", configuration.getId()), ex);
|
||||
}
|
||||
}
|
||||
currentInferenceProcessors = count;
|
||||
}
|
||||
|
||||
// Used for testing
|
||||
int numInferenceProcessors() {
|
||||
return currentInferenceProcessors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config)
|
||||
throws Exception {
|
||||
|
||||
if (inferenceAllowed == false) {
|
||||
throw LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING);
|
||||
}
|
||||
|
||||
if (this.maxIngestProcessors <= currentInferenceProcessors) {
|
||||
throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " +
|
||||
"Adjust the setting [{}]: [{}] if a greater number is desired.",
|
||||
RestStatus.CONFLICT,
|
||||
currentInferenceProcessors,
|
||||
MAX_INFERENCE_PROCESSORS.getKey(),
|
||||
maxIngestProcessors);
|
||||
}
|
||||
|
||||
boolean includeModelMetadata = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, INCLUDE_MODEL_METADATA, true);
|
||||
String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID);
|
||||
String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD);
|
||||
Map<String, String> fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
|
||||
InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));
|
||||
String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "ml");
|
||||
// If multiple inference processors are in the same pipeline, it is wise to tag them
|
||||
// The tag will keep metadata entries from stepping on each other
|
||||
if (tag != null) {
|
||||
modelInfoField += "." + tag;
|
||||
}
|
||||
return new InferenceProcessor(client,
|
||||
tag,
|
||||
targetField,
|
||||
modelId,
|
||||
inferenceConfig,
|
||||
fieldMapping,
|
||||
modelInfoField,
|
||||
includeModelMetadata);
|
||||
}
|
||||
|
||||
// Package private for testing
|
||||
void setMaxIngestProcessors(int maxIngestProcessors) {
|
||||
logger.trace("updating setting maxIngestProcessors from [{}] to [{}]", this.maxIngestProcessors, maxIngestProcessors);
|
||||
this.maxIngestProcessors = maxIngestProcessors;
|
||||
}
|
||||
|
||||
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) throws IOException {
|
||||
ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
|
||||
|
||||
if (inferenceConfig.size() != 1) {
|
||||
throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
|
||||
INFERENCE_CONFIG);
|
||||
}
|
||||
Object value = inferenceConfig.values().iterator().next();
|
||||
|
||||
if ((value instanceof Map<?, ?>) == false) {
|
||||
throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
|
||||
INFERENCE_CONFIG);
|
||||
}
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> valueMap = (Map<String, Object>)value;
|
||||
|
||||
if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
|
||||
checkSupportedVersion(new ClassificationConfig(0));
|
||||
return ClassificationConfig.fromMap(valueMap);
|
||||
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
|
||||
checkSupportedVersion(new RegressionConfig());
|
||||
return RegressionConfig.fromMap(valueMap);
|
||||
} else {
|
||||
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
|
||||
inferenceConfig.keySet(),
|
||||
Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME));
|
||||
}
|
||||
}
|
||||
|
||||
void checkSupportedVersion(InferenceConfig config) {
|
||||
if (config.getMinimalSupportedVersion().after(minNodeVersion)) {
|
||||
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION,
|
||||
config.getName(),
|
||||
config.getMinimalSupportedVersion(),
|
||||
minNodeVersion));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void licenseStateChanged() {
|
||||
this.inferenceAllowed = licenseState.isMachineLearningAllowed();
|
||||
}
|
||||
}
|
||||
}
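A minimal sketch (not part of this commit) of the processor options the Factory above consumes when a pipeline is parsed. The map keys are the constants read in create(...); the model id, tag, and field names are placeholders, and the factory instance is assumed to have been wired up with a client, cluster service, and license state as in the constructor above.

// Minimal sketch, assuming `factory` is an InferenceProcessor.Factory instance.
Map<String, Object> processorConfig = new HashMap<>();
processorConfig.put(InferenceProcessor.MODEL_ID, "example-model");             // hypothetical model id
processorConfig.put(InferenceProcessor.TARGET_FIELD, "prediction");            // where the result is written
processorConfig.put(InferenceProcessor.FIELD_MAPPINGS,
    Collections.singletonMap("source_field", "model_feature"));                // rename fields before inference
processorConfig.put(InferenceProcessor.INFERENCE_CONFIG,
    Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));  // one inference type -> options
InferenceProcessor processor = factory.create(Collections.emptyMap(), "example-tag", processorConfig);

Because the factory appends the tag to model_info_field, tagging each inference processor keeps the model metadata of multiple processors in one pipeline from overwriting each other.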
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.inference.loadingservice;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class LocalModel implements Model {
|
||||
|
||||
private final TrainedModelDefinition trainedModelDefinition;
|
||||
private final String modelId;
|
||||
|
||||
public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) {
|
||||
this.trainedModelDefinition = trainedModelDefinition;
|
||||
this.modelId = modelId;
|
||||
}
|
||||
|
||||
long ramBytesUsed() {
|
||||
return trainedModelDefinition.ramBytesUsed();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModelId() {
|
||||
return modelId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getResultsType() {
|
||||
switch (trainedModelDefinition.getTrainedModel().targetType()) {
|
||||
case CLASSIFICATION:
|
||||
return ClassificationInferenceResults.NAME;
|
||||
case REGRESSION:
|
||||
return RegressionInferenceResults.NAME;
|
||||
default:
|
||||
throw ExceptionsHelper.badRequestException("Model [{}] has unsupported target type [{}]",
|
||||
modelId,
|
||||
trainedModelDefinition.getTrainedModel().targetType());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
|
||||
try {
|
||||
listener.onResponse(trainedModelDefinition.infer(fields, config));
|
||||
} catch (Exception e) {
|
||||
listener.onFailure(e);
|
||||
}
|
||||
}
|
||||
|
||||
}
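A minimal sketch (not part of this commit) of exercising LocalModel directly, assuming a TrainedModelDefinition has already been parsed or loaded elsewhere; the field name and value are placeholders.

// Minimal sketch, assuming `trainedModelDefinition` exists.
LocalModel model = new LocalModel("example-model", trainedModelDefinition);
model.infer(
    Collections.singletonMap("feature_1", 42.0),   // hypothetical feature map
    new RegressionConfig(),
    ActionListener.wrap(
        inferenceResults -> System.out.println(inferenceResults),
        e -> { /* handle the failure */ }));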
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.inference.loadingservice;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public interface Model {
|
||||
|
||||
String getResultsType();
|
||||
|
||||
void infer(Map<String, Object> fields, InferenceConfig inferenceConfig, ActionListener<InferenceResults> listener);
|
||||
|
||||
String getModelId();
|
||||
}
|
|
@ -0,0 +1,370 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.inference.loadingservice;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.cluster.ClusterChangedEvent;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.ClusterStateListener;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.cache.Cache;
|
||||
import org.elasticsearch.common.cache.CacheBuilder;
|
||||
import org.elasticsearch.common.cache.RemovalNotification;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.settings.Setting;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.ingest.IngestMetadata;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
||||
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Queue;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* This is a thread-safe model loading service.
|
||||
*
|
||||
* It caches, in memory, local models that are referenced by ingest processors (as long as it is instantiated on an ingest node).
|
||||
*
|
||||
* If more than one processor references the same model, that model will only be cached once.
|
||||
*/
|
||||
public class ModelLoadingService implements ClusterStateListener {
|
||||
|
||||
/**
|
||||
* The maximum size of the local model cache here in the loading service
|
||||
*
|
||||
* Once the limit is reached, the least recently used models are evicted in favor of new models
|
||||
*/
|
||||
public static final Setting<ByteSizeValue> INFERENCE_MODEL_CACHE_SIZE =
|
||||
Setting.byteSizeSetting("xpack.ml.inference_model.cache_size",
|
||||
new ByteSizeValue(1, ByteSizeUnit.GB),
|
||||
Setting.Property.NodeScope);
|
||||
|
||||
/**
|
||||
* How long should a model stay in the cache since its last access
|
||||
*
|
||||
* If nothing references a model via getModel for this configured timeValue, it will be evicted.
|
||||
*
|
||||
* Specifically, in the ingest scenario, a processor will call getModel whenever it needs to run inference. So, if a processor is not
|
||||
* executed for an extended period of time, the model will be evicted and will have to be loaded again when getModel is called.
|
||||
*
|
||||
*/
|
||||
public static final Setting<TimeValue> INFERENCE_MODEL_CACHE_TTL =
|
||||
Setting.timeSetting("xpack.ml.inference_model.time_to_live",
|
||||
new TimeValue(5, TimeUnit.MINUTES),
|
||||
new TimeValue(1, TimeUnit.MILLISECONDS),
|
||||
Setting.Property.NodeScope);
|
||||
|
||||
private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
|
||||
private final Cache<String, LocalModel> localModelCache;
|
||||
private final Set<String> referencedModels = new HashSet<>();
|
||||
private final Map<String, Queue<ActionListener<Model>>> loadingListeners = new HashMap<>();
|
||||
private final TrainedModelProvider provider;
|
||||
private final Set<String> shouldNotAudit;
|
||||
private final ThreadPool threadPool;
|
||||
private final InferenceAuditor auditor;
|
||||
private final ByteSizeValue maxCacheSize;
|
||||
|
||||
public ModelLoadingService(TrainedModelProvider trainedModelProvider,
|
||||
InferenceAuditor auditor,
|
||||
ThreadPool threadPool,
|
||||
ClusterService clusterService,
|
||||
Settings settings) {
|
||||
this.provider = trainedModelProvider;
|
||||
this.threadPool = threadPool;
|
||||
this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings);
|
||||
this.auditor = auditor;
|
||||
this.shouldNotAudit = new HashSet<>();
|
||||
this.localModelCache = CacheBuilder.<String, LocalModel>builder()
|
||||
.setMaximumWeight(this.maxCacheSize.getBytes())
|
||||
.weigher((id, localModel) -> localModel.ramBytesUsed())
|
||||
.removalListener(this::cacheEvictionListener)
|
||||
.setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
|
||||
.build();
|
||||
clusterService.addListener(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the model referenced by `modelId` and responds to the listener.
|
||||
*
|
||||
* This method first checks the local LRU cache for the model. If it is present, it is returned from cache.
|
||||
*
|
||||
* If it is not present, one of the following occurs:
|
||||
*
|
||||
* - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener`
|
||||
* is added to the list of listeners to be alerted when the model is fully loaded.
|
||||
* - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the resulting
|
||||
* model is cached for future reference
|
||||
* - If the model is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener.
|
||||
* It is not cached.
|
||||
*
|
||||
* @param modelId the model to get
|
||||
* @param modelActionListener the listener to alert when the model has been retrieved.
|
||||
*/
|
||||
public void getModel(String modelId, ActionListener<Model> modelActionListener) {
|
||||
LocalModel cachedModel = localModelCache.get(modelId);
|
||||
if (cachedModel != null) {
|
||||
modelActionListener.onResponse(cachedModel);
|
||||
logger.trace("[{}] loaded from cache", modelId);
|
||||
return;
|
||||
}
|
||||
if (loadModelIfNecessary(modelId, modelActionListener) == false) {
|
||||
// If the model is not loaded and we did not kick off a new loading attempt, we may be getting called
|
||||
// by a simulated pipeline
|
||||
logger.trace("[{}] not actively loading, eager loading without cache", modelId);
|
||||
provider.getTrainedModel(modelId, true, ActionListener.wrap(
|
||||
trainedModelConfig ->
|
||||
modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition())),
|
||||
modelActionListener::onFailure
|
||||
));
|
||||
} else {
|
||||
logger.trace("[{}] is loading or loaded, added new listener to queue", modelId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the model is loaded and the listener has been given the cached model
|
||||
* Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded
|
||||
* Returns false if the model is neither loaded nor actively being loaded
|
||||
*/
|
||||
private boolean loadModelIfNecessary(String modelId, ActionListener<Model> modelActionListener) {
|
||||
synchronized (loadingListeners) {
|
||||
Model cachedModel = localModelCache.get(modelId);
|
||||
if (cachedModel != null) {
|
||||
modelActionListener.onResponse(cachedModel);
|
||||
return true;
|
||||
}
|
||||
// It is referenced by a pipeline, but the cache does not contain it
|
||||
if (referencedModels.contains(modelId)) {
|
||||
// If the model is referenced but not present in the cache,
|
||||
// that means the previous load attempt failed or the model has been evicted
|
||||
// Attempt to load and cache the model if necessary
|
||||
if (loadingListeners.computeIfPresent(
|
||||
modelId,
|
||||
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) {
|
||||
logger.trace("[{}] attempting to load and cache", modelId);
|
||||
loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
|
||||
loadModel(modelId);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// if the cachedModel entry is null, but there are listeners present, that means it is being loaded
|
||||
return loadingListeners.computeIfPresent(modelId,
|
||||
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null;
|
||||
} // synchronized (loadingListeners)
|
||||
}
|
||||
|
||||
private void loadModel(String modelId) {
|
||||
provider.getTrainedModel(modelId, true, ActionListener.wrap(
|
||||
trainedModelConfig -> {
|
||||
logger.debug("[{}] successfully loaded model", modelId);
|
||||
handleLoadSuccess(modelId, trainedModelConfig);
|
||||
},
|
||||
failure -> {
|
||||
logger.warn(new ParameterizedMessage("[{}] failed to load model", modelId), failure);
|
||||
handleLoadFailure(modelId, failure);
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) {
|
||||
Queue<ActionListener<Model>> listeners;
|
||||
LocalModel loadedModel = new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition());
|
||||
synchronized (loadingListeners) {
|
||||
listeners = loadingListeners.remove(modelId);
|
||||
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such
|
||||
// Consequently, we should not store the retrieved model
|
||||
if (listeners == null) {
|
||||
return;
|
||||
}
|
||||
localModelCache.put(modelId, loadedModel);
|
||||
shouldNotAudit.remove(modelId);
|
||||
} // synchronized (loadingListeners)
|
||||
for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
|
||||
listener.onResponse(loadedModel);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleLoadFailure(String modelId, Exception failure) {
|
||||
Queue<ActionListener<Model>> listeners;
|
||||
synchronized (loadingListeners) {
|
||||
listeners = loadingListeners.remove(modelId);
|
||||
if (listeners == null) {
|
||||
return;
|
||||
}
|
||||
} // synchronized (loadingListeners)
|
||||
// If we failed to load and there were listeners present, that means that this model is referenced by a processor
|
||||
// Alert the listeners to the failure
|
||||
for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
|
||||
listener.onFailure(failure);
|
||||
}
|
||||
}
|
||||
|
||||
private void cacheEvictionListener(RemovalNotification<String, LocalModel> notification) {
|
||||
if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
|
||||
String msg = new ParameterizedMessage(
|
||||
"model cache entry evicted. " +
|
||||
"current cache [{}] current max [{}] model size [{}]. " +
|
||||
"If this is undesired, consider updating setting [{}] or [{}].",
|
||||
new ByteSizeValue(localModelCache.weight()).getStringRep(),
|
||||
maxCacheSize.getStringRep(),
|
||||
new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(),
|
||||
INFERENCE_MODEL_CACHE_SIZE.getKey(),
|
||||
INFERENCE_MODEL_CACHE_TTL.getKey()).getFormattedMessage();
|
||||
auditIfNecessary(notification.getKey(), msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clusterChanged(ClusterChangedEvent event) {
|
||||
// If ingest data has not changed or if the current node is not an ingest node, don't bother caching models
|
||||
if (event.changedCustomMetaDataSet().contains(IngestMetadata.TYPE) == false ||
|
||||
event.state().nodes().getLocalNode().isIngestNode() == false) {
|
||||
return;
|
||||
}
|
||||
|
||||
ClusterState state = event.state();
|
||||
IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE);
|
||||
Set<String> allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata);
|
||||
if (allReferencedModelKeys.equals(referencedModels)) {
|
||||
return;
|
||||
}
|
||||
// Listeners that are still waiting for a model whose load we are about to cancel
|
||||
List<Tuple<String, List<ActionListener<Model>>>> drainWithFailure = new ArrayList<>();
|
||||
Set<String> referencedModelsBeforeClusterState = null;
|
||||
Set<String> loadingModelBeforeClusterState = null;
|
||||
Set<String> removedModels = null;
|
||||
synchronized (loadingListeners) {
|
||||
referencedModelsBeforeClusterState = new HashSet<>(referencedModels);
|
||||
if (logger.isTraceEnabled()) {
|
||||
loadingModelBeforeClusterState = new HashSet<>(loadingListeners.keySet());
|
||||
}
|
||||
// If we had models still loading here but are no longer referenced
|
||||
// we should remove them from loadingListeners and alert the listeners
|
||||
for (String modelId : loadingListeners.keySet()) {
|
||||
if (allReferencedModelKeys.contains(modelId) == false) {
|
||||
drainWithFailure.add(Tuple.tuple(modelId, new ArrayList<>(loadingListeners.remove(modelId))));
|
||||
}
|
||||
}
|
||||
removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);
|
||||
|
||||
// Remove all cached models that are not referenced by any processors
|
||||
removedModels.forEach(localModelCache::invalidate);
|
||||
// Remove the models that are no longer referenced
|
||||
referencedModels.removeAll(removedModels);
|
||||
shouldNotAudit.removeAll(removedModels);
|
||||
|
||||
// Remove the models that are already referenced (the intersection of allReferencedModelKeys and referencedModels) so only newly referenced models remain
|
||||
allReferencedModelKeys.removeAll(referencedModels);
|
||||
referencedModels.addAll(allReferencedModelKeys);
|
||||
|
||||
// Populate loadingListeners key so we know that we are currently loading the model
|
||||
for (String modelId : allReferencedModelKeys) {
|
||||
loadingListeners.put(modelId, new ArrayDeque<>());
|
||||
}
|
||||
} // synchronized (loadingListeners)
|
||||
if (logger.isTraceEnabled()) {
|
||||
if (loadingListeners.keySet().equals(loadingModelBeforeClusterState) == false) {
|
||||
logger.trace("cluster state event changed loading models: before {} after {}", loadingModelBeforeClusterState,
|
||||
loadingListeners.keySet());
|
||||
}
|
||||
if (referencedModels.equals(referencedModelsBeforeClusterState) == false) {
|
||||
logger.trace("cluster state event changed referenced models: before {} after {}", referencedModelsBeforeClusterState,
|
||||
referencedModels);
|
||||
}
|
||||
}
|
||||
for (Tuple<String, List<ActionListener<Model>>> modelAndListeners : drainWithFailure) {
|
||||
final String msg = new ParameterizedMessage(
|
||||
"Cancelling load of model [{}] as it is no longer referenced by a pipeline",
|
||||
modelAndListeners.v1()).getFormat();
|
||||
for (ActionListener<Model> listener : modelAndListeners.v2()) {
|
||||
listener.onFailure(new ElasticsearchException(msg));
|
||||
}
|
||||
}
|
||||
removedModels.forEach(this::auditUnreferencedModel);
|
||||
loadModels(allReferencedModelKeys);
|
||||
}
|
||||
|
||||
private void auditIfNecessary(String modelId, String msg) {
|
||||
if (shouldNotAudit.contains(modelId)) {
|
||||
logger.trace("[{}] {}", modelId, msg);
|
||||
return;
|
||||
}
|
||||
auditor.warning(modelId, msg);
|
||||
shouldNotAudit.add(modelId);
|
||||
logger.warn("[{}] {}", modelId, msg);
|
||||
}
|
||||
|
||||
private void loadModels(Set<String> modelIds) {
|
||||
if (modelIds.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
// Execute this on a utility thread because we don't want the resulting callbacks tying up the cluster state listener thread pool
|
||||
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
|
||||
for (String modelId : modelIds) {
|
||||
auditNewReferencedModel(modelId);
|
||||
this.loadModel(modelId);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void auditNewReferencedModel(String modelId) {
|
||||
auditor.info(modelId, "referenced by ingest processors. Attempting to load model into cache");
|
||||
}
|
||||
|
||||
private void auditUnreferencedModel(String modelId) {
|
||||
auditor.info(modelId, "no longer referenced by any processors");
|
||||
}
|
||||
|
||||
private static <T> Queue<T> addFluently(Queue<T> queue, T object) {
|
||||
queue.add(object);
|
||||
return queue;
|
||||
}
|
||||
|
||||
private static Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata) {
|
||||
Set<String> allReferencedModelKeys = new HashSet<>();
|
||||
if (ingestMetadata == null) {
|
||||
return allReferencedModelKeys;
|
||||
}
|
||||
ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
|
||||
Object processors = pipelineConfiguration.getConfigAsMap().get("processors");
|
||||
if (processors instanceof List<?>) {
|
||||
for(Object processor : (List<?>)processors) {
|
||||
if (processor instanceof Map<?, ?>) {
|
||||
Object processorConfig = ((Map<?, ?>)processor).get(InferenceProcessor.TYPE);
|
||||
if (processorConfig instanceof Map<?, ?>) {
|
||||
Object modelId = ((Map<?, ?>)processorConfig).get(InferenceProcessor.MODEL_ID);
|
||||
if (modelId != null) {
|
||||
assert modelId instanceof String;
|
||||
allReferencedModelKeys.add(modelId.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
return allReferencedModelKeys;
|
||||
}
|
||||
|
||||
}
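The two node-scoped settings above govern the cache. A minimal sketch (not part of this commit) of overriding them and fetching a model: the setting keys come from INFERENCE_MODEL_CACHE_SIZE and INFERENCE_MODEL_CACHE_TTL, the values and model id are placeholders, and the constructor dependencies (provider, auditor, thread pool, cluster service) are assumed to be wired up as in the node's plugin setup.

// Minimal sketch, assuming trainedModelProvider, auditor, threadPool and clusterService exist.
Settings nodeSettings = Settings.builder()
    .put("xpack.ml.inference_model.cache_size", "512mb")     // default is 1gb
    .put("xpack.ml.inference_model.time_to_live", "15m")     // default is 5 minutes
    .build();
ModelLoadingService service = new ModelLoadingService(
    trainedModelProvider, auditor, threadPool, clusterService, nodeSettings);
service.getModel("example-model", ActionListener.wrap(
    model -> { /* run model.infer(...) */ },
    e -> { /* handle the failure */ }));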
|
|
@ -24,6 +24,7 @@ import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappi
|
|||
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DYNAMIC;
|
||||
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.ENABLED;
|
||||
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.KEYWORD;
|
||||
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.LONG;
|
||||
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.PROPERTIES;
|
||||
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TEXT;
|
||||
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TYPE;
|
||||
|
@ -103,6 +104,12 @@ public final class InferenceInternalIndex {
|
|||
.endObject()
|
||||
.startObject(TrainedModelConfig.METADATA.getPreferredName())
|
||||
.field(ENABLED, false)
|
||||
.endObject()
|
||||
.startObject(TrainedModelConfig.ESTIMATED_OPERATIONS.getPreferredName())
|
||||
.field(TYPE, LONG)
|
||||
.endObject()
|
||||
.startObject(TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName())
|
||||
.field(TYPE, LONG)
|
||||
.endObject();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,10 +20,19 @@ import org.elasticsearch.action.index.IndexRequest;
|
|||
import org.elasticsearch.action.search.MultiSearchAction;
|
||||
import org.elasticsearch.action.search.MultiSearchRequestBuilder;
|
||||
import org.elasticsearch.action.search.MultiSearchResponse;
|
||||
import org.elasticsearch.action.search.SearchAction;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.IndicesOptions;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.common.CheckedBiFunction;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.regex.Regex;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
|
@ -32,12 +41,21 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
|
|||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.index.IndexNotFoundException;
|
||||
import org.elasticsearch.index.engine.VersionConflictEngineException;
|
||||
import org.elasticsearch.index.mapper.MapperService;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.index.reindex.DeleteByQueryAction;
|
||||
import org.elasticsearch.index.reindex.DeleteByQueryRequest;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.search.sort.SortBuilders;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
|
||||
import org.elasticsearch.xpack.core.action.util.PageParams;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
|
@ -47,10 +65,17 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
|
||||
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_FAILED_TO_DESERIALIZE;
|
||||
|
||||
public class TrainedModelProvider {
|
||||
|
||||
|
@ -189,6 +214,166 @@ public class TrainedModelProvider {
|
|||
multiSearchResponseActionListener);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the trained model config objects for the provided model ids
|
||||
*
|
||||
* NOTE:
|
||||
* This does no expansion on the ids.
|
||||
* It assumes that there are fewer than 10k.
|
||||
*/
|
||||
public void getTrainedModels(Set<String> modelIds, boolean allowNoResources, final ActionListener<List<TrainedModelConfig>> listener) {
|
||||
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
|
||||
|
||||
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
||||
.addSort(TrainedModelConfig.MODEL_ID.getPreferredName(), SortOrder.ASC)
|
||||
.addSort("_index", SortOrder.DESC)
|
||||
.setQuery(queryBuilder)
|
||||
.request();
|
||||
|
||||
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
|
||||
searchResponse -> {
|
||||
Set<String> observedIds = new HashSet<>(searchResponse.getHits().getHits().length, 1.0f);
|
||||
List<TrainedModelConfig> configs = new ArrayList<>(searchResponse.getHits().getHits().length);
|
||||
for(SearchHit searchHit : searchResponse.getHits().getHits()) {
|
||||
try {
|
||||
if (observedIds.contains(searchHit.getId()) == false) {
|
||||
configs.add(
|
||||
parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build()
|
||||
);
|
||||
observedIds.add(searchHit.getId());
|
||||
}
|
||||
} catch (IOException ex) {
|
||||
listener.onFailure(
|
||||
ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId()));
|
||||
return;
|
||||
}
|
||||
}
|
||||
// We previously expanded the IDs.
|
||||
// If the config has gone missing between then and now we should throw if allowNoResources is false
|
||||
// Otherwise, treat it as if it was never expanded to begin with.
|
||||
Set<String> missingConfigs = Sets.difference(modelIds, observedIds);
|
||||
if (missingConfigs.isEmpty() == false && allowNoResources == false) {
|
||||
listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
|
||||
return;
|
||||
}
|
||||
listener.onResponse(configs);
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler);
|
||||
}
|
||||
|
||||
public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener) {
|
||||
DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);
|
||||
|
||||
request.indices(InferenceIndexConstants.INDEX_PATTERN);
|
||||
QueryBuilder query = QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
|
||||
request.setQuery(query);
|
||||
request.setRefresh(true);
|
||||
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(deleteResponse -> {
|
||||
if (deleteResponse.getDeleted() == 0) {
|
||||
listener.onFailure(new ResourceNotFoundException(
|
||||
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
||||
return;
|
||||
}
|
||||
listener.onResponse(true);
|
||||
}, e -> {
|
||||
if (e.getClass() == IndexNotFoundException.class) {
|
||||
listener.onFailure(new ResourceNotFoundException(
|
||||
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
||||
} else {
|
||||
listener.onFailure(e);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
public void expandIds(String idExpression,
|
||||
boolean allowNoResources,
|
||||
@Nullable PageParams pageParams,
|
||||
ActionListener<Tuple<Long, Set<String>>> idsListener) {
|
||||
String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
|
||||
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
|
||||
.sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
|
||||
// If there are no resources, there might be no mapping for the id field.
|
||||
// This makes sure we don't get an error if that happens.
|
||||
.unmappedType("long"))
|
||||
.query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
|
||||
if (pageParams != null) {
|
||||
sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
|
||||
}
|
||||
sourceBuilder.trackTotalHits(true)
|
||||
// we only care about the item ids
|
||||
.fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null);
|
||||
|
||||
IndicesOptions indicesOptions = SearchRequest.DEFAULT_INDICES_OPTIONS;
|
||||
SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN)
|
||||
.indicesOptions(IndicesOptions.fromOptions(true,
|
||||
indicesOptions.allowNoIndices(),
|
||||
indicesOptions.expandWildcardsOpen(),
|
||||
indicesOptions.expandWildcardsClosed(),
|
||||
indicesOptions))
|
||||
.source(sourceBuilder);
|
||||
|
||||
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
|
||||
ML_ORIGIN,
|
||||
searchRequest,
|
||||
ActionListener.<SearchResponse>wrap(
|
||||
response -> {
|
||||
Set<String> foundResourceIds = new LinkedHashSet<>();
|
||||
long totalHitCount = response.getHits().getTotalHits().value;
|
||||
for (SearchHit hit : response.getHits().getHits()) {
|
||||
Map<String, Object> docSource = hit.getSourceAsMap();
|
||||
if (docSource == null) {
|
||||
continue;
|
||||
}
|
||||
Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName());
|
||||
if (idValue instanceof String) {
|
||||
foundResourceIds.add(idValue.toString());
|
||||
}
|
||||
}
|
||||
ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources);
|
||||
requiredMatches.filterMatchedIds(foundResourceIds);
|
||||
if (requiredMatches.hasUnmatchedIds()) {
|
||||
idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString()));
|
||||
} else {
|
||||
idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds));
|
||||
}
|
||||
},
|
||||
idsListener::onFailure
|
||||
),
|
||||
client::search);
|
||||
|
||||
}
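A minimal sketch (not part of this commit) of calling expandIds with a comma-delimited, wildcard-capable expression, mirroring the tokenization and paging handled above; the ids are placeholders and the provider instance is assumed to exist.

// Minimal sketch, assuming `trainedModelProvider` is a TrainedModelProvider instance.
trainedModelProvider.expandIds(
    "model-a,model-b*",        // comma-delimited ids, wildcards allowed
    true,                      // allowNoResources
    new PageParams(0, 100),
    ActionListener.wrap(
        totalAndIds -> {
            long totalHitCount = totalAndIds.v1();
            Set<String> matchedIds = totalAndIds.v2();
        },
        e -> { /* handle the failure */ }));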
|
||||
|
||||
private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));
|
||||
|
||||
if (Strings.isAllOrWildcard(tokens)) {
|
||||
return boolQuery;
|
||||
}
|
||||
// If the resourceId is not _all or *, we should see if it is a comma delimited string with wild-cards
|
||||
// e.g. id1,id2*,id3
|
||||
BoolQueryBuilder shouldQueries = new BoolQueryBuilder();
|
||||
List<String> terms = new ArrayList<>();
|
||||
for (String token : tokens) {
|
||||
if (Regex.isSimpleMatchPattern(token)) {
|
||||
shouldQueries.should(QueryBuilders.wildcardQuery(resourceIdField, token));
|
||||
} else {
|
||||
terms.add(token);
|
||||
}
|
||||
}
|
||||
if (terms.isEmpty() == false) {
|
||||
shouldQueries.should(QueryBuilders.termsQuery(resourceIdField, terms));
|
||||
}
|
||||
|
||||
if (shouldQueries.should().isEmpty() == false) {
|
||||
boolQuery.filter(shouldQueries);
|
||||
}
|
||||
return boolQuery;
|
||||
}
|
||||
|
||||
private static <T> T handleSearchItem(MultiSearchResponse.Item item,
|
||||
String resourceId,
|
||||
|
@ -202,23 +387,23 @@ public class TrainedModelProvider {
|
|||
return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId);
|
||||
}
|
||||
|
||||
private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws Exception {
|
||||
private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws IOException {
|
||||
try (InputStream stream = source.streamInput();
|
||||
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
|
||||
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
|
||||
return TrainedModelConfig.fromXContent(parser, true);
|
||||
} catch (Exception e) {
|
||||
} catch (IOException e) {
|
||||
logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws Exception {
|
||||
private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws IOException {
|
||||
try (InputStream stream = source.streamInput();
|
||||
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
|
||||
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
|
||||
return TrainedModelDefinition.fromXContent(parser, true).build();
|
||||
} catch (Exception e) {
|
||||
} catch (IOException e) {
|
||||
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), e);
|
||||
throw e;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.notifications;
|
||||
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor;
|
||||
import org.elasticsearch.xpack.core.ml.notifications.AuditorField;
|
||||
import org.elasticsearch.xpack.core.ml.notifications.InferenceAuditMessage;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
||||
|
||||
public class InferenceAuditor extends AbstractAuditor<InferenceAuditMessage> {
|
||||
|
||||
public InferenceAuditor(Client client, String nodeName) {
|
||||
super(client, nodeName, AuditorField.NOTIFICATIONS_INDEX, ML_ORIGIN, InferenceAuditMessage::new);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.rest.inference;
|
||||
|
||||
import org.elasticsearch.client.node.NodeClient;
|
||||
import org.elasticsearch.rest.BaseRestHandler;
|
||||
import org.elasticsearch.rest.RestController;
|
||||
import org.elasticsearch.rest.RestRequest;
|
||||
import org.elasticsearch.rest.action.RestToXContentListener;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.elasticsearch.rest.RestRequest.Method.DELETE;
|
||||
|
||||
public class RestDeleteTrainedModelAction extends BaseRestHandler {
|
||||
|
||||
public RestDeleteTrainedModelAction(RestController controller) {
|
||||
controller.registerHandler(
|
||||
DELETE, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "ml_delete_trained_models_action";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
|
||||
String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
|
||||
DeleteTrainedModelAction.Request request = new DeleteTrainedModelAction.Request(modelId);
|
||||
return channel -> client.execute(DeleteTrainedModelAction.INSTANCE, request, new RestToXContentListener<>(channel));
|
||||
}
|
||||
}

@@ -0,0 +1,56 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.rest.inference;

import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Strings;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.MachineLearning;

import java.io.IOException;

import static org.elasticsearch.rest.RestRequest.Method.GET;
import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH;

public class RestGetTrainedModelsAction extends BaseRestHandler {

    public RestGetTrainedModelsAction(RestController controller) {
        controller.registerHandler(
            GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this);
        controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference", this);
    }

    @Override
    public String getName() {
        return "ml_get_trained_models_action";
    }

    @Override
    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
        String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
        if (Strings.isNullOrEmpty(modelId)) {
            modelId = MetaData.ALL;
        }
        boolean includeModelDefinition = restRequest.paramAsBoolean(
            GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
            false
        );
        GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition);
        if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
            request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
                restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
        }
        request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
        return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel));
    }
}

@@ -0,0 +1,52 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.rest.inference;

import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Strings;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.MachineLearning;

import java.io.IOException;

import static org.elasticsearch.rest.RestRequest.Method.GET;
import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH;

public class RestGetTrainedModelsStatsAction extends BaseRestHandler {

    public RestGetTrainedModelsStatsAction(RestController controller) {
        controller.registerHandler(
            GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_stats", this);
        controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference/_stats", this);
    }

    @Override
    public String getName() {
        return "ml_get_trained_models_stats_action";
    }

    @Override
    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
        String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
        if (Strings.isNullOrEmpty(modelId)) {
            modelId = MetaData.ALL;
        }
        GetTrainedModelsStatsAction.Request request = new GetTrainedModelsStatsAction.Request(modelId);
        if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
            request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
                restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
        }
        request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
        return channel -> client.execute(GetTrainedModelsStatsAction.INSTANCE, request, new RestToXContentListener<>(channel));
    }
}
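// Routing note: together with the two handlers above, this registers the trained model REST surface. Assuming
// MachineLearning.BASE_PATH resolves to "/_ml/" (an assumption, the constant is not shown in this diff), the
// stats URLs are GET /_ml/inference/{model_id}/_stats and GET /_ml/inference/_stats, with optional from/size
// paging and allow_no_match handled exactly as in prepareRequest above.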
|
|
@@ -6,13 +6,21 @@
|
|||
package org.elasticsearch.license;
|
||||
|
||||
import org.elasticsearch.ElasticsearchSecurityException;
|
||||
import org.elasticsearch.action.ingest.PutPipelineAction;
|
||||
import org.elasticsearch.action.ingest.PutPipelineRequest;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineAction;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
|
||||
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
||||
import org.elasticsearch.client.transport.TransportClient;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.index.mapper.MapperService;
|
||||
import org.elasticsearch.license.License.OperationMode;
|
||||
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
|
@@ -24,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteDatafeedAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
|
||||
|
@@ -31,17 +40,24 @@ import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction
|
|||
import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction;
|
||||
import org.elasticsearch.xpack.core.ml.client.MachineLearningClient;
|
||||
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.JobState;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
|
||||
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasItem;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
|
||||
public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
|
||||
|
||||
|
@@ -529,6 +545,216 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testMachineLearningCreateInferenceProcessorRestricted() {
|
||||
String modelId = "modelprocessorlicensetest";
|
||||
assertMLAllowed(true);
|
||||
putInferenceModel(modelId);
|
||||
|
||||
String pipeline = "{" +
|
||||
" \"processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"regression_value\",\n" +
|
||||
" \"model_id\": \"modelprocessorlicensetest\",\n" +
|
||||
" \"inference_config\": {\"regression\": {}},\n" +
|
||||
" \"field_mappings\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }]}\n";
|
||||
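// Rendered as JSON, the escaped pipeline string above describes (a sketch of the same content, nothing extra):
//   { "processors": [ { "inference": {
//       "target_field": "regression_value",            // where the predicted value is written
//       "model_id": "modelprocessorlicensetest",       // the model indexed by putInferenceModel() below
//       "inference_config": { "regression": {} },      // evaluate the model as a regression
//       "field_mappings": { "col1": "col1", ... }      // document fields mapped to the model's input fields
//   } } ] }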
// test that license restricted apis do now work
|
||||
PlainActionFuture<AcknowledgedResponse> putPipelineListener = PlainActionFuture.newFuture();
|
||||
client().execute(PutPipelineAction.INSTANCE,
|
||||
new PutPipelineRequest("test_infer_license_pipeline",
|
||||
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
|
||||
XContentType.JSON),
|
||||
putPipelineListener);
|
||||
AcknowledgedResponse putPipelineResponse = putPipelineListener.actionGet();
|
||||
assertTrue(putPipelineResponse.isAcknowledged());
|
||||
|
||||
String simulateSource = "{\n" +
|
||||
" \"pipeline\": \n" +
|
||||
pipeline +
|
||||
" ,\n" +
|
||||
" \"docs\": [\n" +
|
||||
" {\"_source\": {\n" +
|
||||
" \"col1\": \"female\",\n" +
|
||||
" \"col2\": \"M\",\n" +
|
||||
" \"col3\": \"none\",\n" +
|
||||
" \"col4\": 10\n" +
|
||||
" }}]\n" +
|
||||
"}";
|
||||
PlainActionFuture<SimulatePipelineResponse> simulatePipelineListener = PlainActionFuture.newFuture();
|
||||
client().execute(SimulatePipelineAction.INSTANCE,
|
||||
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
|
||||
simulatePipelineListener);
|
||||
|
||||
assertThat(simulatePipelineListener.actionGet().getResults(), is(not(empty())));
|
||||
|
||||
|
||||
// Pick a license that does not allow machine learning
|
||||
License.OperationMode mode = randomInvalidLicenseType();
|
||||
enableLicensing(mode);
|
||||
assertMLAllowed(false);
|
||||
|
||||
// creating a new pipeline should fail
|
||||
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
|
||||
PlainActionFuture<AcknowledgedResponse> listener = PlainActionFuture.newFuture();
|
||||
client().execute(PutPipelineAction.INSTANCE,
|
||||
new PutPipelineRequest("test_infer_license_pipeline_failure",
|
||||
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
|
||||
XContentType.JSON),
|
||||
listener);
|
||||
listener.actionGet();
|
||||
});
|
||||
assertThat(e.status(), is(RestStatus.FORBIDDEN));
|
||||
assertThat(e.getMessage(), containsString("non-compliant"));
|
||||
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
|
||||
|
||||
// Simulating the pipeline should fail
|
||||
e = expectThrows(ElasticsearchSecurityException.class, () -> {
|
||||
PlainActionFuture<SimulatePipelineResponse> listener = PlainActionFuture.newFuture();
|
||||
client().execute(SimulatePipelineAction.INSTANCE,
|
||||
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
|
||||
listener);
|
||||
listener.actionGet();
|
||||
});
|
||||
assertThat(e.status(), is(RestStatus.FORBIDDEN));
|
||||
assertThat(e.getMessage(), containsString("non-compliant"));
|
||||
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
|
||||
|
||||
// Pick a license that does allow machine learning
|
||||
mode = randomValidLicenseType();
|
||||
enableLicensing(mode);
|
||||
assertMLAllowed(true);
|
||||
// test that license restricted apis do now work
|
||||
PlainActionFuture<AcknowledgedResponse> putPipelineListenerNewLicense = PlainActionFuture.newFuture();
|
||||
client().execute(PutPipelineAction.INSTANCE,
|
||||
new PutPipelineRequest("test_infer_license_pipeline",
|
||||
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
|
||||
XContentType.JSON),
|
||||
putPipelineListenerNewLicense);
|
||||
AcknowledgedResponse putPipelineResponseNewLicense = putPipelineListenerNewLicense.actionGet();
|
||||
assertTrue(putPipelineResponseNewLicense.isAcknowledged());
|
||||
|
||||
PlainActionFuture<SimulatePipelineResponse> simulatePipelineListenerNewLicense = PlainActionFuture.newFuture();
|
||||
client().execute(SimulatePipelineAction.INSTANCE,
|
||||
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
|
||||
simulatePipelineListenerNewLicense);
|
||||
|
||||
assertThat(simulatePipelineListenerNewLicense.actionGet().getResults(), is(not(empty())));
|
||||
}
|
||||
|
||||
public void testMachineLearningInferModelRestricted() throws Exception {
|
||||
String modelId = "modelinfermodellicensetest";
|
||||
assertMLAllowed(true);
|
||||
putInferenceModel(modelId);
|
||||
|
||||
|
||||
PlainActionFuture<InferModelAction.Response> inferModelSuccess = PlainActionFuture.newFuture();
|
||||
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
|
||||
modelId,
|
||||
Collections.singletonList(Collections.emptyMap()),
|
||||
new RegressionConfig()
|
||||
), inferModelSuccess);
|
||||
assertThat(inferModelSuccess.actionGet().getInferenceResults(), is(not(empty())));
|
||||
|
||||
// Pick a license that does not allow machine learning
|
||||
License.OperationMode mode = randomInvalidLicenseType();
|
||||
enableLicensing(mode);
|
||||
assertMLAllowed(false);
|
||||
|
||||
// inferring against a model should now fail
|
||||
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
|
||||
PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
|
||||
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
|
||||
modelId,
|
||||
Collections.singletonList(Collections.emptyMap()),
|
||||
new RegressionConfig()
|
||||
), listener);
|
||||
listener.actionGet();
|
||||
});
|
||||
assertThat(e.status(), is(RestStatus.FORBIDDEN));
|
||||
assertThat(e.getMessage(), containsString("non-compliant"));
|
||||
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
|
||||
|
||||
// Pick a license that does allow machine learning
|
||||
mode = randomValidLicenseType();
|
||||
enableLicensing(mode);
|
||||
assertMLAllowed(true);
|
||||
|
||||
PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
|
||||
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
|
||||
modelId,
|
||||
Collections.singletonList(Collections.emptyMap()),
|
||||
new RegressionConfig()
|
||||
), listener);
|
||||
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
|
||||
}
|
||||
|
||||
private void putInferenceModel(String modelId) {
|
||||
String config = "" +
|
||||
"{\n" +
|
||||
" \"model_id\": \"" + modelId + "\",\n" +
|
||||
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
|
||||
" \"description\": \"test model for classification\",\n" +
|
||||
" \"version\": \"7.6.0\",\n" +
|
||||
" \"created_by\": \"benwtrent\",\n" +
|
||||
" \"estimated_heap_memory_usage_bytes\": 0,\n" +
|
||||
" \"estimated_operations\": 0,\n" +
|
||||
" \"created_time\": 0\n" +
|
||||
"}";
|
||||
String definition = "" +
|
||||
"{" +
|
||||
" \"trained_model\": {\n" +
|
||||
" \"tree\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col1_male\",\n" +
|
||||
" \"col1_female\",\n" +
|
||||
" \"col2_encoded\",\n" +
|
||||
" \"col3_encoded\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"tree_structure\": [\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 0,\n" +
|
||||
" \"split_feature\": 0,\n" +
|
||||
" \"split_gain\": 12.0,\n" +
|
||||
" \"threshold\": 10.0,\n" +
|
||||
" \"decision_type\": \"lte\",\n" +
|
||||
" \"default_left\": true,\n" +
|
||||
" \"left_child\": 1,\n" +
|
||||
" \"right_child\": 2\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 1,\n" +
|
||||
" \"leaf_value\": 1\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 2,\n" +
|
||||
" \"leaf_value\": 2\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"target_type\": \"regression\"\n" +
|
||||
" }\n" +
|
||||
" }," +
|
||||
" \"model_id\": \"" + modelId + "\"\n" +
|
||||
"}";
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId(modelId)
|
||||
.setSource(config, XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||
.setId(TrainedModelDefinition.docId(modelId))
|
||||
.setSource(definition, XContentType.JSON)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.get().status(), equalTo(RestStatus.CREATED));
|
||||
}
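// Note on putInferenceModel(): rather than going through a PUT trained model API, the test indexes the model
// config (doc id = modelId) and its definition (doc id = TrainedModelDefinition.docId(modelId)) directly into
// InferenceIndexConstants.LATEST_INDEX_NAME, which is the index the inference code loads models from, so the
// licensing paths can be exercised in isolation.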
|
||||
|
||||
private static OperationMode randomInvalidLicenseType() {
|
||||
return randomFrom(License.OperationMode.GOLD, License.OperationMode.STANDARD, License.OperationMode.BASIC);
|
||||
}
|
||||
|
|
|
@@ -8,8 +8,16 @@ package org.elasticsearch.xpack.ml;
|
|||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
|
||||
import org.elasticsearch.action.search.SearchAction;
|
||||
import org.elasticsearch.action.search.SearchRequestBuilder;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.ClusterName;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.metadata.MetaData;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNode;
|
||||
|
@@ -20,13 +28,18 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.transport.TransportAddress;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.env.Environment;
|
||||
import org.elasticsearch.env.TestEnvironment;
|
||||
import org.elasticsearch.ingest.IngestStats;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchHits;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.XPackFeatureSet;
|
||||
import org.elasticsearch.xpack.core.XPackFeatureSet.Usage;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
|
@@ -40,6 +53,7 @@ import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
|
|||
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
|
||||
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.Detector;
|
||||
|
@@ -49,10 +63,12 @@ import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeSta
|
|||
import org.elasticsearch.xpack.core.ml.stats.ForecastStats;
|
||||
import org.elasticsearch.xpack.core.ml.stats.ForecastStatsTests;
|
||||
import org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.elasticsearch.xpack.ml.job.JobManager;
|
||||
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Date;
|
||||
|
@@ -61,6 +77,8 @@ import java.util.HashSet;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.notNullValue;
|
||||
|
@@ -98,6 +116,8 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
|
|||
givenJobs(Collections.emptyList(), Collections.emptyList());
|
||||
givenDatafeeds(Collections.emptyList());
|
||||
givenDataFrameAnalytics(Collections.emptyList());
|
||||
givenProcessorStats(Collections.emptyList());
|
||||
givenTrainedModelConfigCount(0);
|
||||
}
|
||||
|
||||
public void testIsRunningOnMlPlatform() {
|
||||
|
@@ -175,14 +195,54 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
|
|||
buildDatafeedStats(DatafeedState.STARTED),
|
||||
buildDatafeedStats(DatafeedState.STOPPED)
|
||||
));
|
||||
|
||||
givenDataFrameAnalytics(Arrays.asList(
|
||||
buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STOPPED),
|
||||
buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STOPPED),
|
||||
buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STARTED)
|
||||
));
|
||||
|
||||
givenProcessorStats(Arrays.asList(
|
||||
buildNodeStats(
|
||||
Arrays.asList("pipeline1", "pipeline2", "pipeline3"),
|
||||
Arrays.asList(
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)),
|
||||
new IngestStats.ProcessorStat(
|
||||
InferenceProcessor.TYPE,
|
||||
InferenceProcessor.TYPE,
|
||||
new IngestStats.Stats(100, 10, 0, 1))
|
||||
),
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
|
||||
),
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
|
||||
)
|
||||
)),
|
||||
buildNodeStats(
|
||||
Arrays.asList("pipeline1", "pipeline2", "pipeline3"),
|
||||
Arrays.asList(
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(0, 0, 0, 0)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(0, 0, 0, 0)),
|
||||
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0))
|
||||
),
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
|
||||
),
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
|
||||
)
|
||||
))
|
||||
));
|
||||
givenTrainedModelConfigCount(100);
|
||||
|
||||
MachineLearningFeatureSet featureSet = new MachineLearningFeatureSet(TestEnvironment.newEnvironment(settings.build()),
|
||||
clusterService, client, licenseState, jobManagerHolder);
|
||||
PlainActionFuture<Usage> future = new PlainActionFuture<>();
|
||||
featureSet.usage(future);
|
||||
XPackFeatureSet.Usage mlUsage = future.get();
|
||||
|
@@ -258,6 +318,18 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
|
|||
|
||||
assertThat(source.getValue("jobs.opened.forecasts.total"), equalTo(11));
|
||||
assertThat(source.getValue("jobs.opened.forecasts.forecasted_jobs"), equalTo(2));
|
||||
|
||||
assertThat(source.getValue("inference.trained_models._all.count"), equalTo(100));
|
||||
assertThat(source.getValue("inference.ingest_processors._all.pipelines.count"), equalTo(2));
|
||||
assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.sum"), equalTo(130));
|
||||
assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.min"), equalTo(0));
|
||||
assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.max"), equalTo(100));
|
||||
assertThat(source.getValue("inference.ingest_processors._all.time_ms.sum"), equalTo(14));
|
||||
assertThat(source.getValue("inference.ingest_processors._all.time_ms.min"), equalTo(0));
|
||||
assertThat(source.getValue("inference.ingest_processors._all.time_ms.max"), equalTo(10));
|
||||
assertThat(source.getValue("inference.ingest_processors._all.num_failures.sum"), equalTo(1));
|
||||
assertThat(source.getValue("inference.ingest_processors._all.num_failures.min"), equalTo(0));
|
||||
assertThat(source.getValue("inference.ingest_processors._all.num_failures.max"), equalTo(1));
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -444,6 +516,34 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
|
|||
}).when(client).execute(same(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
|
||||
}
|
||||
|
||||
private void givenProcessorStats(List<NodeStats> stats) {
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("unchecked")
|
||||
ActionListener<NodesStatsResponse> listener =
|
||||
(ActionListener<NodesStatsResponse>) invocationOnMock.getArguments()[2];
|
||||
listener.onResponse(new NodesStatsResponse(new ClusterName("_name"), stats, Collections.emptyList()));
|
||||
return Void.TYPE;
|
||||
}).when(client).execute(same(NodesStatsAction.INSTANCE), any(), any());
|
||||
}
|
||||
|
||||
private void givenTrainedModelConfigCount(long count) {
|
||||
when(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN))
|
||||
.thenReturn(new SearchRequestBuilder(client, SearchAction.INSTANCE));
|
||||
ThreadPool pool = mock(ThreadPool.class);
|
||||
when(pool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
|
||||
when(client.threadPool()).thenReturn(pool);
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("unchecked")
|
||||
ActionListener<SearchResponse> listener =
|
||||
(ActionListener<SearchResponse>) invocationOnMock.getArguments()[1];
|
||||
SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(count, TotalHits.Relation.EQUAL_TO), (float)0.0);
|
||||
SearchResponse searchResponse = mock(SearchResponse.class);
|
||||
when(searchResponse.getHits()).thenReturn(searchHits);
|
||||
listener.onResponse(searchResponse);
|
||||
return Void.TYPE;
|
||||
}).when(client).search(any(), any());
|
||||
}
|
||||
|
||||
private static Detector buildMinDetector(String fieldName) {
|
||||
Detector.Builder detectorBuilder = new Detector.Builder();
|
||||
detectorBuilder.setFunction("min");
|
||||
|
@@ -490,6 +590,17 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
|
|||
return stats;
|
||||
}
|
||||
|
||||
private static NodeStats buildNodeStats(List<String> pipelineNames, List<List<IngestStats.ProcessorStat>> processorStats) {
|
||||
IngestStats ingestStats = new IngestStats(
|
||||
new IngestStats.Stats(0,0,0,0),
|
||||
Collections.emptyList(),
|
||||
IntStream.range(0, pipelineNames.size()).boxed().collect(Collectors.toMap(pipelineNames::get, processorStats::get)));
|
||||
return new NodeStats(mock(DiscoveryNode.class),
|
||||
Instant.now().toEpochMilli(), null, null, null, null, null, null, null, null,
|
||||
null, null, null, ingestStats, null);
|
||||
|
||||
}
|
||||
|
||||
private static ForecastStats buildForecastStats(long numberOfForecasts) {
|
||||
return new ForecastStatsTests().createForecastStats(numberOfForecasts, numberOfForecasts);
|
||||
}
|
||||
|
|
|
@@ -0,0 +1,289 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.action;
|
||||
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
|
||||
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.ClusterName;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.metadata.MetaData;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNode;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.settings.ClusterSettings;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.ingest.IngestMetadata;
|
||||
import org.elasticsearch.ingest.IngestService;
|
||||
import org.elasticsearch.ingest.IngestStats;
|
||||
import org.elasticsearch.ingest.PipelineConfiguration;
|
||||
import org.elasticsearch.ingest.Processor;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.plugins.IngestPlugin;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasEntry;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
|
||||
|
||||
private static class NotInferenceProcessor implements Processor {
|
||||
|
||||
@Override
|
||||
public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
|
||||
return ingestDocument;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getType() {
|
||||
return "not_inference";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getTag() {
|
||||
return null;
|
||||
}
|
||||
|
||||
static class Factory implements Processor.Factory {
|
||||
|
||||
@Override
|
||||
public Processor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config) {
|
||||
return new NotInferenceProcessor();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static final IngestPlugin SKINNY_INGEST_PLUGIN = new IngestPlugin() {
|
||||
@Override
|
||||
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
|
||||
Map<String, Processor.Factory> factoryMap = new HashMap<>();
|
||||
XPackLicenseState licenseState = mock(XPackLicenseState.class);
|
||||
when(licenseState.isMachineLearningAllowed()).thenReturn(true);
|
||||
factoryMap.put(InferenceProcessor.TYPE,
|
||||
new InferenceProcessor.Factory(parameters.client,
|
||||
parameters.ingestService.getClusterService(),
|
||||
Settings.EMPTY,
|
||||
parameters.ingestService,
|
||||
licenseState));
|
||||
|
||||
factoryMap.put("not_inference", new NotInferenceProcessor.Factory());
|
||||
|
||||
return factoryMap;
|
||||
}
|
||||
};
|
||||
|
||||
private ClusterService clusterService;
|
||||
private IngestService ingestService;
|
||||
private Client client;
|
||||
|
||||
@Before
|
||||
public void setUpVariables() {
|
||||
ThreadPool tp = mock(ThreadPool.class);
|
||||
client = mock(Client.class);
|
||||
clusterService = mock(ClusterService.class);
|
||||
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY,
|
||||
Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS));
|
||||
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
|
||||
ingestService = new IngestService(clusterService, tp, null, null,
|
||||
null, Collections.singletonList(SKINNY_INGEST_PLUGIN), client);
|
||||
}
|
||||
|
||||
|
||||
public void testInferenceIngestStatsByPipelineId() throws IOException {
|
||||
List<NodeStats> nodeStatsList = Arrays.asList(
|
||||
buildNodeStats(
|
||||
new IngestStats.Stats(2, 2, 3, 4),
|
||||
Arrays.asList(
|
||||
new IngestStats.PipelineStat(
|
||||
"pipeline1",
|
||||
new IngestStats.Stats(0, 0, 3, 1)),
|
||||
new IngestStats.PipelineStat(
|
||||
"pipeline2",
|
||||
new IngestStats.Stats(1, 1, 0, 1)),
|
||||
new IngestStats.PipelineStat(
|
||||
"pipeline3",
|
||||
new IngestStats.Stats(2, 1, 1, 1))
|
||||
),
|
||||
Arrays.asList(
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)),
|
||||
new IngestStats.ProcessorStat(
|
||||
InferenceProcessor.TYPE,
|
||||
InferenceProcessor.TYPE,
|
||||
new IngestStats.Stats(100, 10, 0, 1))
|
||||
),
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
|
||||
),
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
|
||||
)
|
||||
)),
|
||||
buildNodeStats(
|
||||
new IngestStats.Stats(15, 5, 3, 4),
|
||||
Arrays.asList(
|
||||
new IngestStats.PipelineStat(
|
||||
"pipeline1",
|
||||
new IngestStats.Stats(10, 1, 3, 1)),
|
||||
new IngestStats.PipelineStat(
|
||||
"pipeline2",
|
||||
new IngestStats.Stats(1, 1, 0, 1)),
|
||||
new IngestStats.PipelineStat(
|
||||
"pipeline3",
|
||||
new IngestStats.Stats(2, 1, 1, 1))
|
||||
),
|
||||
Arrays.asList(
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(0, 0, 0, 0)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(0, 0, 0, 0)),
|
||||
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0))
|
||||
),
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
|
||||
),
|
||||
Arrays.asList(
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
|
||||
)
|
||||
))
|
||||
);
|
||||
|
||||
NodesStatsResponse response = new NodesStatsResponse(new ClusterName("_name"), nodeStatsList, Collections.emptyList());
|
||||
|
||||
Map<String, Set<String>> pipelineIdsByModelIds = new HashMap<String, Set<String>>(){{
|
||||
put("trained_model_1", Collections.singleton("pipeline1"));
|
||||
put("trained_model_2", new HashSet<>(Arrays.asList("pipeline1", "pipeline2")));
|
||||
}};
|
||||
Map<String, IngestStats> ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByPipelineId(response,
|
||||
pipelineIdsByModelIds);
|
||||
|
||||
assertThat(ingestStatsMap.keySet(), equalTo(new HashSet<>(Arrays.asList("trained_model_1", "trained_model_2"))));
|
||||
|
||||
IngestStats expectedStatsModel1 = new IngestStats(
|
||||
new IngestStats.Stats(10, 1, 6, 2),
|
||||
Collections.singletonList(new IngestStats.PipelineStat("pipeline1", new IngestStats.Stats(10, 1, 6, 2))),
|
||||
Collections.singletonMap("pipeline1", Arrays.asList(
|
||||
new IngestStats.ProcessorStat("inference", "inference", new IngestStats.Stats(120, 12, 0, 1)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))))
|
||||
);
|
||||
|
||||
IngestStats expectedStatsModel2 = new IngestStats(
|
||||
new IngestStats.Stats(12, 3, 6, 4),
|
||||
Arrays.asList(
|
||||
new IngestStats.PipelineStat("pipeline1", new IngestStats.Stats(10, 1, 6, 2)),
|
||||
new IngestStats.PipelineStat("pipeline2", new IngestStats.Stats(2, 2, 0, 2))),
|
||||
new HashMap<String, List<IngestStats.ProcessorStat>>() {{
|
||||
put("pipeline2", Arrays.asList(
|
||||
new IngestStats.ProcessorStat("inference", "inference", new IngestStats.Stats(10, 2, 0, 0)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(20, 2, 0, 0))));
|
||||
put("pipeline1", Arrays.asList(
|
||||
new IngestStats.ProcessorStat("inference", "inference", new IngestStats.Stats(120, 12, 0, 1)),
|
||||
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))));
|
||||
}}
|
||||
);
|
||||
|
||||
assertThat(ingestStatsMap, hasEntry("trained_model_1", expectedStatsModel1));
|
||||
assertThat(ingestStatsMap, hasEntry("trained_model_2", expectedStatsModel2));
|
||||
}
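    // Worked check for the expectations above (derived from the mocked node stats, no new test logic):
    // "trained_model_1" is referenced only by pipeline1, so its stats are the element-wise sum of the two nodes'
    // pipeline1 stats: docs 0 + 10 = 10, time_ms 0 + 1 = 1, current 3 + 3 = 6, failures 1 + 1 = 2,
    // i.e. new IngestStats.Stats(10, 1, 6, 2) in expectedStatsModel1. Its "inference" processor stat sums the
    // four inference entries for pipeline1: (10 + 100 + 0 + 10, 1 + 10 + 0 + 1, 0, 0 + 1) = (120, 12, 0, 1).
    // "trained_model_2" additionally picks up pipeline2, whose two node entries sum to Stats(2, 2, 0, 2),
    // giving the overall Stats(12, 3, 6, 4) in expectedStatsModel2.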
|
||||
|
||||
public void testPipelineIdsByModelIds() throws IOException {
|
||||
String modelId1 = "trained_model_1";
|
||||
String modelId2 = "trained_model_2";
|
||||
String modelId3 = "trained_model_3";
|
||||
Set<String> modelIds = new HashSet<>(Arrays.asList(modelId1, modelId2, modelId3));
|
||||
|
||||
ClusterState clusterState = buildClusterStateWithModelReferences(modelId1, modelId2, modelId3);
|
||||
|
||||
Map<String, Set<String>> pipelineIdsByModelIds =
|
||||
TransportGetTrainedModelsStatsAction.pipelineIdsByModelIds(clusterState, ingestService, modelIds);
|
||||
|
||||
assertThat(pipelineIdsByModelIds.keySet(), equalTo(modelIds));
|
||||
assertThat(pipelineIdsByModelIds,
|
||||
hasEntry(modelId1, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId1 + 0, "pipeline_with_model_" + modelId1 + 1))));
|
||||
assertThat(pipelineIdsByModelIds,
|
||||
hasEntry(modelId2, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId2 + 0, "pipeline_with_model_" + modelId2 + 1))));
|
||||
assertThat(pipelineIdsByModelIds,
|
||||
hasEntry(modelId3, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId3 + 0, "pipeline_with_model_" + modelId3 + 1))));
|
||||
|
||||
}
|
||||
|
||||
private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException {
|
||||
Map<String, PipelineConfiguration> configurations = new HashMap<>(modelId.length);
|
||||
for (String id : modelId) {
|
||||
configurations.put("pipeline_with_model_" + id + 0, newConfigurationWithInferenceProcessor(id, 0));
|
||||
configurations.put("pipeline_with_model_" + id + 1, newConfigurationWithInferenceProcessor(id, 1));
|
||||
}
|
||||
for (int i = 0; i < 3; i++) {
|
||||
configurations.put("pipeline_without_model_" + i, newConfigurationWithOutInferenceProcessor(i));
|
||||
}
|
||||
IngestMetadata ingestMetadata = new IngestMetadata(configurations);
|
||||
|
||||
return ClusterState.builder(new ClusterName("_name"))
|
||||
.metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId, int num) throws IOException {
|
||||
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors",
|
||||
Collections.singletonList(
|
||||
Collections.singletonMap(InferenceProcessor.TYPE,
|
||||
new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.MODEL_ID, modelId);
|
||||
put("inference_config", Collections.singletonMap("regression", Collections.emptyMap()));
|
||||
put("field_mappings", Collections.emptyMap());
|
||||
put("target_field", randomAlphaOfLength(10));
|
||||
}}))))) {
|
||||
return new PipelineConfiguration("pipeline_with_model_" + modelId + num,
|
||||
BytesReference.bytes(xContentBuilder),
|
||||
XContentType.JSON);
|
||||
}
|
||||
}
|
||||
|
||||
private static PipelineConfiguration newConfigurationWithOutInferenceProcessor(int i) throws IOException {
|
||||
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors",
|
||||
Collections.singletonList(Collections.singletonMap("not_inference", Collections.emptyMap()))))) {
|
||||
return new PipelineConfiguration("pipeline_without_model_" + i, BytesReference.bytes(xContentBuilder), XContentType.JSON);
|
||||
}
|
||||
}
|
||||
|
||||
private static NodeStats buildNodeStats(IngestStats.Stats overallStats,
|
||||
List<IngestStats.PipelineStat> pipelineNames,
|
||||
List<List<IngestStats.ProcessorStat>> processorStats) {
|
||||
List<String> pipelineids = pipelineNames.stream().map(IngestStats.PipelineStat::getPipelineId).collect(Collectors.toList());
|
||||
IngestStats ingestStats = new IngestStats(
|
||||
overallStats,
|
||||
pipelineNames,
|
||||
IntStream.range(0, pipelineids.size()).boxed().collect(Collectors.toMap(pipelineids::get, processorStats::get)));
|
||||
return new NodeStats(mock(DiscoveryNode.class),
|
||||
Instant.now().toEpochMilli(), null, null, null, null, null, null, null, null,
|
||||
null, null, null, ingestStats, null);
|
||||
|
||||
}
|
||||
|
||||
}
|
|

@@ -145,6 +145,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
        assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
        assertThat(storedModel.getDefinition(), equalTo(inferenceModel.build()));
        assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames));
        assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
        assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
        Map<String, Object> metadata = storedModel.getMetadata();
        assertThat(metadata.size(), equalTo(1));
        assertThat(metadata, hasKey("analytics_config"));

@@ -40,7 +40,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
            progressPercent = randomIntBetween(0, 100);
        }
        if (randomBoolean()) {
            inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(null);
            inferenceModel = TrainedModelDefinitionTests.createRandomBuilder("model");
        }
        return new AnalyticsResult(rowResults, progressPercent, inferenceModel);
    }
@@ -0,0 +1,281 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.inference.ingest;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.ClusterName;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.metadata.MetaData;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNode;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.settings.ClusterSettings;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.transport.TransportAddress;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
import org.elasticsearch.ingest.IngestMetadata;
|
||||
import org.elasticsearch.ingest.IngestService;
|
||||
import org.elasticsearch.ingest.PipelineConfiguration;
|
||||
import org.elasticsearch.ingest.Processor;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.plugins.IngestPlugin;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.InetAddress;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class InferenceProcessorFactoryTests extends ESTestCase {
|
||||
|
||||
private static final IngestPlugin SKINNY_PLUGIN = new IngestPlugin() {
|
||||
@Override
|
||||
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
|
||||
XPackLicenseState licenseState = mock(XPackLicenseState.class);
|
||||
when(licenseState.isMachineLearningAllowed()).thenReturn(true);
|
||||
return Collections.singletonMap(InferenceProcessor.TYPE,
|
||||
new InferenceProcessor.Factory(parameters.client,
|
||||
parameters.ingestService.getClusterService(),
|
||||
Settings.EMPTY,
|
||||
parameters.ingestService,
|
||||
licenseState));
|
||||
}
|
||||
};
|
||||
private Client client;
|
||||
private XPackLicenseState licenseState;
|
||||
private ClusterService clusterService;
|
||||
private IngestService ingestService;
|
||||
|
||||
@Before
|
||||
public void setUpVariables() {
|
||||
ThreadPool tp = mock(ThreadPool.class);
|
||||
client = mock(Client.class);
|
||||
clusterService = mock(ClusterService.class);
|
||||
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY,
|
||||
Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS));
|
||||
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
|
||||
ingestService = new IngestService(clusterService, tp, null, null,
|
||||
null, Collections.singletonList(SKINNY_PLUGIN), client);
|
||||
licenseState = mock(XPackLicenseState.class);
|
||||
when(licenseState.isMachineLearningAllowed()).thenReturn(true);
|
||||
}
|
||||
|
||||
public void testNumInferenceProcessors() throws Exception {
|
||||
MetaData metaData = null;
|
||||
|
||||
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
|
||||
clusterService,
|
||||
Settings.EMPTY,
|
||||
ingestService,
|
||||
licenseState);
|
||||
processorFactory.accept(buildClusterState(metaData));
|
||||
|
||||
assertThat(processorFactory.numInferenceProcessors(), equalTo(0));
|
||||
metaData = MetaData.builder().build();
|
||||
|
||||
processorFactory.accept(buildClusterState(metaData));
|
||||
assertThat(processorFactory.numInferenceProcessors(), equalTo(0));
|
||||
|
||||
processorFactory.accept(buildClusterStateWithModelReferences("model1", "model2", "model3"));
|
||||
assertThat(processorFactory.numInferenceProcessors(), equalTo(3));
|
||||
}
|
||||
|
||||
public void testCreateProcessorWithTooManyExisting() throws Exception {
|
||||
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
|
||||
clusterService,
|
||||
Settings.builder().put(InferenceProcessor.MAX_INFERENCE_PROCESSORS.getKey(), 1).build(),
|
||||
ingestService,
|
||||
licenseState);
|
||||
|
||||
processorFactory.accept(buildClusterStateWithModelReferences("model1"));
|
||||
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", Collections.emptyMap()));
|
||||
|
||||
assertThat(ex.getMessage(), equalTo("Max number of inference processors reached, total inference processors [1]. " +
|
||||
"Adjust the setting [xpack.ml.max_inference_processors]: [1] if a greater number is desired."));
|
||||
}
|
||||
|
||||
public void testCreateProcessorWithInvalidInferenceConfig() {
|
||||
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
|
||||
clusterService,
|
||||
Settings.EMPTY,
|
||||
ingestService,
|
||||
licenseState);
|
||||
|
||||
Map<String, Object> config = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap()));
|
||||
}};
|
||||
|
||||
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config));
|
||||
assertThat(ex.getMessage(),
|
||||
equalTo("unrecognized inference configuration type [unknown_type]. Supported types [classification, regression]"));
|
||||
|
||||
Map<String, Object> config2 = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom"));
|
||||
}};
|
||||
ex = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config2));
|
||||
assertThat(ex.getMessage(),
|
||||
equalTo("inference_config must be an object with one inference type mapped to an object."));
|
||||
|
||||
Map<String, Object> config3 = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap());
|
||||
}};
|
||||
ex = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config3));
|
||||
assertThat(ex.getMessage(),
|
||||
equalTo("inference_config must be an object with one inference type mapped to an object."));
|
||||
}
|
||||
|
||||
public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException {
|
||||
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
|
||||
clusterService,
|
||||
Settings.EMPTY,
|
||||
ingestService,
|
||||
licenseState);
|
||||
processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1"));
|
||||
|
||||
Map<String, Object> regression = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
|
||||
}};
|
||||
|
||||
try {
|
||||
processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression);
|
||||
fail("Should not have successfully created");
|
||||
} catch (ElasticsearchException ex) {
|
||||
assertThat(ex.getMessage(),
|
||||
equalTo("Configuration [regression] requires minimum node version [7.6.0] (current minimum node version [7.5.0]"));
|
||||
} catch (Exception ex) {
|
||||
fail(ex.getMessage());
|
||||
}
|
||||
|
||||
Map<String, Object> classification = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME,
|
||||
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
|
||||
}};
|
||||
|
||||
try {
|
||||
processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification);
|
||||
fail("Should not have successfully created");
|
||||
} catch (ElasticsearchException ex) {
|
||||
assertThat(ex.getMessage(),
|
||||
equalTo("Configuration [classification] requires minimum node version [7.6.0] (current minimum node version [7.5.0]"));
|
||||
} catch (Exception ex) {
|
||||
fail(ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
    public void testCreateProcessor() {
        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
            clusterService,
            Settings.EMPTY,
            ingestService,
            licenseState);

        Map<String, Object> regression = new HashMap<String, Object>() {{
            put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
            put(InferenceProcessor.MODEL_ID, "my_model");
            put(InferenceProcessor.TARGET_FIELD, "result");
            put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
        }};

        try {
            processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression);
        } catch (Exception ex) {
            fail(ex.getMessage());
        }

        Map<String, Object> classification = new HashMap<String, Object>() {{
            put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
            put(InferenceProcessor.MODEL_ID, "my_model");
            put(InferenceProcessor.TARGET_FIELD, "result");
            put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME,
                Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
        }};

        try {
            processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification);
        } catch (Exception ex) {
            fail(ex.getMessage());
        }
    }

    private static ClusterState buildClusterState(MetaData metaData) {
        return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build();
    }

    private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException {
        return builderClusterStateWithModelReferences(Version.CURRENT, modelId);
    }

    private static ClusterState builderClusterStateWithModelReferences(Version minNodeVersion, String... modelId) throws IOException {
        Map<String, PipelineConfiguration> configurations = new HashMap<>(modelId.length);
        for (String id : modelId) {
            configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id));
        }
        IngestMetadata ingestMetadata = new IngestMetadata(configurations);

        return ClusterState.builder(new ClusterName("_name"))
            .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata))
            .nodes(DiscoveryNodes.builder()
                .add(new DiscoveryNode("min_node",
                    new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
                    minNodeVersion))
                .add(new DiscoveryNode("current_node",
                    new TransportAddress(InetAddress.getLoopbackAddress(), 9302),
                    Version.CURRENT))
                .localNodeId("_node_id")
                .masterNodeId("_node_id"))
            .build();
    }

    private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException {
        try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors",
            Collections.singletonList(
                Collections.singletonMap(InferenceProcessor.TYPE,
                    new HashMap<String, Object>() {{
                        put(InferenceProcessor.MODEL_ID, modelId);
                        put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
                        put(InferenceProcessor.TARGET_FIELD, "new_field");
                        put(InferenceProcessor.FIELD_MAPPINGS, Collections.singletonMap("source", "dest"));
                    }}))))) {
            return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON);
        }
    }

}

@@ -0,0 +1,230 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.inference.ingest;

import org.elasticsearch.client.Client;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.junit.Before;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.Is.is;
import static org.mockito.Mockito.mock;

public class InferenceProcessorTests extends ESTestCase {

    private Client client;

    @Before
    public void setUpVariables() {
        client = mock(Client.class);
    }

    public void testMutateDocumentWithClassification() {
        String targetField = "classification_value";
        InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
            "my_processor",
            targetField,
            "classification_model",
            new ClassificationConfig(0),
            Collections.emptyMap(),
            "ml.my_processor",
            true);

        Map<String, Object> source = new HashMap<>();
        Map<String, Object> ingestMetadata = new HashMap<>();
        IngestDocument document = new IngestDocument(source, ingestMetadata);

        InferModelAction.Response response = new InferModelAction.Response(
            Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", null)));
        inferenceProcessor.mutateDocument(response, document);

        assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
        assertThat(document.getFieldValue("ml", Map.class),
            equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
    }

    @SuppressWarnings("unchecked")
    public void testMutateDocumentClassificationTopNClasses() {
        String targetField = "classification_value_probabilities";
        InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
            "my_processor",
            targetField,
            "classification_model",
            new ClassificationConfig(2),
            Collections.emptyMap(),
            "ml.my_processor",
            true);

        Map<String, Object> source = new HashMap<>();
        Map<String, Object> ingestMetadata = new HashMap<>();
        IngestDocument document = new IngestDocument(source, ingestMetadata);

        List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
        classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
        classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));

        InferModelAction.Response response = new InferModelAction.Response(
            Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes)));
        inferenceProcessor.mutateDocument(response, document);

        assertThat((List<Map<?,?>>)document.getFieldValue(targetField, List.class),
            contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
        assertThat(document.getFieldValue("ml", Map.class),
            equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
    }

    public void testMutateDocumentRegression() {
        String targetField = "regression_value";
        InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
            "my_processor",
            targetField,
            "regression_model",
            new RegressionConfig(),
            Collections.emptyMap(),
            "ml.my_processor",
            true);

        Map<String, Object> source = new HashMap<>();
        Map<String, Object> ingestMetadata = new HashMap<>();
        IngestDocument document = new IngestDocument(source, ingestMetadata);

        InferModelAction.Response response = new InferModelAction.Response(
            Collections.singletonList(new RegressionInferenceResults(0.7)));
        inferenceProcessor.mutateDocument(response, document);

        assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
        assertThat(document.getFieldValue("ml", Map.class),
            equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "regression_model"))));
    }

    public void testMutateDocumentNoModelMetaData() {
        String targetField = "regression_value";
        InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
            "my_processor",
            targetField,
            "regression_model",
            new RegressionConfig(),
            Collections.emptyMap(),
            "ml.my_processor",
            false);

        Map<String, Object> source = new HashMap<>();
        Map<String, Object> ingestMetadata = new HashMap<>();
        IngestDocument document = new IngestDocument(source, ingestMetadata);

        InferModelAction.Response response = new InferModelAction.Response(
            Collections.singletonList(new RegressionInferenceResults(0.7)));
        inferenceProcessor.mutateDocument(response, document);

        assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
        assertThat(document.hasField("ml"), is(false));
    }

    public void testMutateDocumentModelMetaDataExistingField() {
        String targetField = "regression_value";
        InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
            "my_processor",
            targetField,
            "regression_model",
            new RegressionConfig(),
            Collections.emptyMap(),
            "ml.my_processor",
            true);

        // cannot use singleton map as attempting to mutate later
        Map<String, Object> ml = new HashMap<String, Object>(){{
            put("regression_prediction", 0.55);
        }};
        Map<String, Object> source = new HashMap<String, Object>(){{
            put("ml", ml);
        }};
        Map<String, Object> ingestMetadata = new HashMap<>();
        IngestDocument document = new IngestDocument(source, ingestMetadata);

        InferModelAction.Response response = new InferModelAction.Response(
            Collections.singletonList(new RegressionInferenceResults(0.7)));
        inferenceProcessor.mutateDocument(response, document);

        assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
        assertThat(document.getFieldValue("ml", Map.class),
            equalTo(new HashMap<String, Object>(){{
                put("my_processor", Collections.singletonMap("model_id", "regression_model"));
                put("regression_prediction", 0.55);
            }}));
    }

    public void testGenerateRequestWithEmptyMapping() {
        String modelId = "model";
        Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);

        InferenceProcessor processor = new InferenceProcessor(client,
            "my_processor",
            "my_field",
            modelId,
            new ClassificationConfig(topNClasses),
            Collections.emptyMap(),
            "ml.my_processor",
            false);

        Map<String, Object> source = new HashMap<String, Object>(){{
            put("value1", 1);
            put("value2", 4);
            put("categorical", "foo");
        }};
        Map<String, Object> ingestMetadata = new HashMap<>();
        IngestDocument document = new IngestDocument(source, ingestMetadata);

        assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(source));
    }

    public void testGenerateWithMapping() {
        String modelId = "model";
        Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);

        Map<String, String> fieldMapping = new HashMap<String, String>(3) {{
            put("value1", "new_value1");
            put("value2", "new_value2");
            put("categorical", "new_categorical");
        }};

        InferenceProcessor processor = new InferenceProcessor(client,
            "my_processor",
            "my_field",
            modelId,
            new ClassificationConfig(topNClasses),
            fieldMapping,
            "ml.my_processor",
            false);

        Map<String, Object> source = new HashMap<String, Object>(3){{
            put("value1", 1);
            put("categorical", "foo");
            put("un_touched", "bar");
        }};
        Map<String, Object> ingestMetadata = new HashMap<>();
        IngestDocument document = new IngestDocument(source, ingestMetadata);

        Map<String, Object> expectedMap = new HashMap<String, Object>(2) {{
            put("new_value1", 1);
            put("new_categorical", "foo");
            put("un_touched", "bar");
        }};
        assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
    }
}

@@ -0,0 +1,212 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.inference.loadingservice;

import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;

public class LocalModelTests extends ESTestCase {

    public void testClassificationInfer() throws Exception {
        String modelId = "classification_model";
        TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
            .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
            .setTrainedModel(buildClassification(false))
            .build();

        Model model = new LocalModel(modelId, definition);
        Map<String, Object> fields = new HashMap<String, Object>() {{
            put("foo", 1.0);
            put("bar", 0.5);
            put("categorical", "dog");
        }};

        SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0));
        assertThat(result.value(), equalTo(0.0));
        assertThat(result.valueAsString(), is("0.0"));

        ClassificationInferenceResults classificationResult =
            (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));
        assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
        assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0"));

        // Test with labels
        definition = new TrainedModelDefinition.Builder()
            .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
            .setTrainedModel(buildClassification(true))
            .build();
        model = new LocalModel(modelId, definition);
        result = getSingleValue(model, fields, new ClassificationConfig(0));
        assertThat(result.value(), equalTo(0.0));
        assertThat(result.valueAsString(), equalTo("not_to_be"));

        classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));
        assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
        assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));

        classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(2));
        assertThat(classificationResult.getTopClasses(), hasSize(2));

        classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(-1));
        assertThat(classificationResult.getTopClasses(), hasSize(2));
    }

    public void testRegression() throws Exception {
        TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
            .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
            .setTrainedModel(buildRegression())
            .build();
        Model model = new LocalModel("regression_model", trainedModelDefinition);

        Map<String, Object> fields = new HashMap<String, Object>() {{
            put("foo", 1.0);
            put("bar", 0.5);
            put("categorical", "dog");
        }};

        SingleValueInferenceResults results = getSingleValue(model, fields, new RegressionConfig());
        assertThat(results.value(), equalTo(1.3));

        PlainActionFuture<InferenceResults> failedFuture = new PlainActionFuture<>();
        model.infer(fields, new ClassificationConfig(2), failedFuture);
        ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get);
        assertThat(ex.getCause().getMessage(),
            equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]"));
    }

    private static SingleValueInferenceResults getSingleValue(Model model,
                                                              Map<String, Object> fields,
                                                              InferenceConfig config) throws Exception {
        PlainActionFuture<InferenceResults> future = new PlainActionFuture<>();
        model.infer(fields, config, future);
        return (SingleValueInferenceResults)future.get();
    }

    private static Map<String, String> oneHotMap() {
        Map<String, String> oneHotEncoding = new HashMap<>();
        oneHotEncoding.put("cat", "animal_cat");
        oneHotEncoding.put("dog", "animal_dog");
        return oneHotEncoding;
    }

    public static TrainedModel buildClassification(boolean includeLabels) {
        List<String> featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog");
        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2)
                .setThreshold(0.8)
                .setSplitFeature(1)
                .setLeftChild(3)
                .setRightChild(4))
            .addNode(TreeNode.builder(3).setLeafValue(0.0))
            .addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(3)
                .setThreshold(1.0))
            .addNode(TreeNode.builder(1).setLeafValue(0.0))
            .addNode(TreeNode.builder(2).setLeafValue(1.0))
            .build();
        Tree tree3 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(1.0))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2).setLeafValue(0.0))
            .build();
        return Ensemble.builder()
            .setClassificationLabels(includeLabels ? Arrays.asList("not_to_be", "to_be") : null)
            .setTargetType(TargetType.CLASSIFICATION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
            .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
            .build();
    }

    public static TrainedModel buildRegression() {
        List<String> featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog");
        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(0.3))
            .addNode(TreeNode.builder(2)
                .setThreshold(0.0)
                .setSplitFeature(3)
                .setLeftChild(3)
                .setRightChild(4))
            .addNode(TreeNode.builder(3).setLeafValue(0.1))
            .addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(2)
                .setThreshold(1.0))
            .addNode(TreeNode.builder(1).setLeafValue(1.5))
            .addNode(TreeNode.builder(2).setLeafValue(0.9))
            .build();
        Tree tree3 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(1)
                .setThreshold(0.2))
            .addNode(TreeNode.builder(1).setLeafValue(1.5))
            .addNode(TreeNode.builder(2).setLeafValue(0.9))
            .build();
        return Ensemble.builder()
            .setTargetType(TargetType.REGRESSION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
            .setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5, 0.5}))
            .build();
    }

}

@@ -0,0 +1,363 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.inference.loadingservice;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.junit.After;
import org.junit.Before;
import org.mockito.Mockito;

import java.io.IOException;
import java.net.InetAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class ModelLoadingServiceTests extends ESTestCase {

    private TrainedModelProvider trainedModelProvider;
    private ThreadPool threadPool;
    private ClusterService clusterService;
    private InferenceAuditor auditor;

    @Before
    public void setUpComponents() {
        threadPool = new TestThreadPool("ModelLoadingServiceTests", new ScalingExecutorBuilder(UTILITY_THREAD_POOL_NAME,
            1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.utility_thread_pool"));
        trainedModelProvider = mock(TrainedModelProvider.class);
        clusterService = mock(ClusterService.class);
        auditor = mock(InferenceAuditor.class);
        doAnswer(a -> null).when(auditor).error(any(String.class), any(String.class));
        doAnswer(a -> null).when(auditor).info(any(String.class), any(String.class));
        doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class));
        doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class));
        when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build());
    }

    @After
    public void terminateThreadPool() {
        terminate(threadPool);
    }

    public void testGetCachedModels() throws Exception {
        String model1 = "test-load-model-1";
        String model2 = "test-load-model-2";
        String model3 = "test-load-model-3";
        withTrainedModel(model1, 1L);
        withTrainedModel(model2, 1L);
        withTrainedModel(model3, 1L);

        ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
            auditor,
            threadPool,
            clusterService,
            Settings.EMPTY);

        modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));

        String[] modelIds = new String[]{model1, model2, model3};
        for (int i = 0; i < 10; i++) {
            String model = modelIds[i % 3];
            PlainActionFuture<Model> future = new PlainActionFuture<>();
            modelLoadingService.getModel(model, future);
            assertThat(future.get(), is(not(nullValue())));
        }

        verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
        verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
        verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());

        // Test invalidate cache for model3
        modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
        for (int i = 0; i < 10; i++) {
            String model = modelIds[i % 3];
            PlainActionFuture<Model> future = new PlainActionFuture<>();
            modelLoadingService.getModel(model, future);
            assertThat(future.get(), is(not(nullValue())));
        }

        verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
        verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
        // model3 is no longer referenced, so it is loaded eagerly on every request
        verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), eq(true), any());
    }

    public void testMaxCachedLimitReached() throws Exception {
        String model1 = "test-cached-limit-load-model-1";
        String model2 = "test-cached-limit-load-model-2";
        String model3 = "test-cached-limit-load-model-3";
        withTrainedModel(model1, 10L);
        withTrainedModel(model2, 5L);
        withTrainedModel(model3, 15L);

        ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
            auditor,
            threadPool,
            clusterService,
            Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build());

        modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));

        // Should have been loaded from the cluster change event.
        // Verify that we have at least loaded all three so that evictions occur in the following loop.
        assertBusy(() -> {
            verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
            verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
            verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
        });

        String[] modelIds = new String[]{model1, model2, model3};
        for (int i = 0; i < 10; i++) {
            // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
            String model = modelIds[i % 2];
            PlainActionFuture<Model> future = new PlainActionFuture<>();
            modelLoadingService.getModel(model, future);
            assertThat(future.get(), is(not(nullValue())));
        }

        verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model1), eq(true), any());
        verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any());
        // Only requested once, on the initial load from the change event
        verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());

        // Load model 3, should invalidate 1
        for (int i = 0; i < 10; i++) {
            PlainActionFuture<Model> future3 = new PlainActionFuture<>();
            modelLoadingService.getModel(model3, future3);
            assertThat(future3.get(), is(not(nullValue())));
        }
        verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model3), eq(true), any());

        // Load model 1, should invalidate 2
        for (int i = 0; i < 10; i++) {
            PlainActionFuture<Model> future1 = new PlainActionFuture<>();
            modelLoadingService.getModel(model1, future1);
            assertThat(future1.get(), is(not(nullValue())));
        }
        verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());

        // Load model 2, should invalidate 3
        for (int i = 0; i < 10; i++) {
            PlainActionFuture<Model> future2 = new PlainActionFuture<>();
            modelLoadingService.getModel(model2, future2);
            assertThat(future2.get(), is(not(nullValue())));
        }
        verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());

        // Test invalidate cache for model3
        // Now both model 1 and 2 should fit in cache without issues
        modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
        for (int i = 0; i < 10; i++) {
            String model = modelIds[i % 3];
            PlainActionFuture<Model> future = new PlainActionFuture<>();
            modelLoadingService.getModel(model, future);
            assertThat(future.get(), is(not(nullValue())));
        }

        verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
        verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
        verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any());
        verify(trainedModelProvider, Mockito.atMost(5)).getTrainedModel(eq(model3), eq(true), any());
    }

    public void testWhenCacheEnabledButNotIngestNode() throws Exception {
        String model1 = "test-uncached-not-ingest-model-1";
        withTrainedModel(model1, 1L);

        ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
            auditor,
            threadPool,
            clusterService,
            Settings.EMPTY);

        modelLoadingService.clusterChanged(ingestChangedEvent(false, model1));

        for (int i = 0; i < 10; i++) {
            PlainActionFuture<Model> future = new PlainActionFuture<>();
            modelLoadingService.getModel(model1, future);
            assertThat(future.get(), is(not(nullValue())));
        }

        verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), eq(true), any());
    }

    public void testGetCachedMissingModel() throws Exception {
        String model = "test-load-cached-missing-model";
        withMissingModel(model);

        ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
            auditor,
            threadPool,
            clusterService,
            Settings.EMPTY);
        modelLoadingService.clusterChanged(ingestChangedEvent(model));

        PlainActionFuture<Model> future = new PlainActionFuture<>();
        modelLoadingService.getModel(model, future);

        try {
            future.get();
            fail("Should not have succeeded in loading the model");
        } catch (Exception ex) {
            assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
        }

        verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(true), any());
    }

    public void testGetMissingModel() {
        String model = "test-load-missing-model";
        withMissingModel(model);

        ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
            auditor,
            threadPool,
            clusterService,
            Settings.EMPTY);

        PlainActionFuture<Model> future = new PlainActionFuture<>();
        modelLoadingService.getModel(model, future);
        try {
            future.get();
            fail("Should not have succeeded");
        } catch (Exception ex) {
            assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
        }
    }

    public void testGetModelEagerly() throws Exception {
        String model = "test-get-model-eagerly";
        withTrainedModel(model, 1L);

        ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
            auditor,
            threadPool,
            clusterService,
            Settings.EMPTY);

        for (int i = 0; i < 3; i++) {
            PlainActionFuture<Model> future = new PlainActionFuture<>();
            modelLoadingService.getModel(model, future);
            assertThat(future.get(), is(not(nullValue())));
        }

        verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(true), any());
    }

    @SuppressWarnings("unchecked")
    private void withTrainedModel(String modelId, long size) {
        TrainedModelDefinition definition = mock(TrainedModelDefinition.class);
        when(definition.ramBytesUsed()).thenReturn(size);
        TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
        when(trainedModelConfig.getDefinition()).thenReturn(definition);
        doAnswer(invocationOnMock -> {
            @SuppressWarnings("rawtypes")
            ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
            listener.onResponse(trainedModelConfig);
            return null;
        }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(true), any());
    }

    private void withMissingModel(String modelId) {
        doAnswer(invocationOnMock -> {
            @SuppressWarnings("rawtypes")
            ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
            listener.onFailure(new ResourceNotFoundException(
                Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
            return null;
        }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(true), any());
    }

    private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws IOException {
        return ingestChangedEvent(true, modelId);
    }

    private static ClusterChangedEvent ingestChangedEvent(boolean isIngestNode, String... modelId) throws IOException {
        ClusterChangedEvent event = mock(ClusterChangedEvent.class);
        when(event.changedCustomMetaDataSet()).thenReturn(Collections.singleton(IngestMetadata.TYPE));
        when(event.state()).thenReturn(buildClusterStateWithModelReferences(isIngestNode, modelId));
        return event;
    }

    private static ClusterState buildClusterStateWithModelReferences(boolean isIngestNode, String... modelId) throws IOException {
        Map<String, PipelineConfiguration> configurations = new HashMap<>(modelId.length);
        for (String id : modelId) {
            configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id));
        }
        IngestMetadata ingestMetadata = new IngestMetadata(configurations);

        return ClusterState.builder(new ClusterName("_name"))
            .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata))
            .nodes(DiscoveryNodes.builder().add(
                new DiscoveryNode("node_name",
                    "node_id",
                    new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
                    Collections.emptyMap(),
                    isIngestNode ? Collections.singleton(DiscoveryNodeRole.INGEST_ROLE) : Collections.emptySet(),
                    Version.CURRENT))
                .localNodeId("node_id")
                .build())
            .build();
    }

    private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException {
        try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors",
            Collections.singletonList(
                Collections.singletonMap(InferenceProcessor.TYPE,
                    Collections.singletonMap(InferenceProcessor.MODEL_ID,
                        modelId)))))) {
            return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON);
        }
    }
}

@@ -0,0 +1,196 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.junit.Before;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildClassification;
import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildRegression;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.nullValue;

public class ModelInferenceActionIT extends MlSingleNodeTestCase {

    private TrainedModelProvider trainedModelProvider;

    @Before
    public void createComponents() throws Exception {
        trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry());
        waitForMlTemplates();
    }

    public void testInferModels() throws Exception {
        String modelId1 = "test-load-models-regression";
        String modelId2 = "test-load-models-classification";
        Map<String, String> oneHotEncoding = new HashMap<>();
        oneHotEncoding.put("cat", "animal_cat");
        oneHotEncoding.put("dog", "animal_dog");
        TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2)
            .setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
            .setDefinition(new TrainedModelDefinition.Builder()
                .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
                .setTrainedModel(buildClassification(true))
                .setModelId(modelId1))
            .setVersion(Version.CURRENT)
            .setCreateTime(Instant.now())
            .setEstimatedOperations(0)
            .setEstimatedHeapMemory(0)
            .build();
        TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1)
            .setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
            .setDefinition(new TrainedModelDefinition.Builder()
                .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
                .setTrainedModel(buildRegression())
                .setModelId(modelId2))
            .setVersion(Version.CURRENT)
            .setEstimatedOperations(0)
            .setEstimatedHeapMemory(0)
            .setCreateTime(Instant.now())
            .build();
        AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

        blockingCall(listener -> trainedModelProvider.storeTrainedModel(config1, listener), putConfigHolder, exceptionHolder);
        assertThat(putConfigHolder.get(), is(true));
        assertThat(exceptionHolder.get(), is(nullValue()));
        blockingCall(listener -> trainedModelProvider.storeTrainedModel(config2, listener), putConfigHolder, exceptionHolder);
        assertThat(putConfigHolder.get(), is(true));
        assertThat(exceptionHolder.get(), is(nullValue()));

        List<Map<String, Object>> toInfer = new ArrayList<>();
        toInfer.add(new HashMap<String, Object>() {{
            put("foo", 1.0);
            put("bar", 0.5);
            put("categorical", "dog");
        }});
        toInfer.add(new HashMap<String, Object>() {{
            put("foo", 0.9);
            put("bar", 1.5);
            put("categorical", "cat");
        }});

        List<Map<String, Object>> toInfer2 = new ArrayList<>();
        toInfer2.add(new HashMap<String, Object>() {{
            put("foo", 0.0);
            put("bar", 0.01);
            put("categorical", "dog");
        }});
        toInfer2.add(new HashMap<String, Object>() {{
            put("foo", 1.0);
            put("bar", 0.0);
            put("categorical", "cat");
        }});

        // Test regression
        InferModelAction.Request request = new InferModelAction.Request(modelId1, toInfer, new RegressionConfig());
        InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet();
        assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
            contains(1.3, 1.25));

        request = new InferModelAction.Request(modelId1, toInfer2, new RegressionConfig());
        response = client().execute(InferModelAction.INSTANCE, request).actionGet();
        assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
            contains(1.65, 1.55));

        // Test classification
        request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(0));
        response = client().execute(InferModelAction.INSTANCE, request).actionGet();
        assertThat(response.getInferenceResults()
                .stream()
                .map(i -> ((SingleValueInferenceResults)i).valueAsString())
                .collect(Collectors.toList()),
            contains("not_to_be", "to_be"));

        // Get top classes
        request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2));
        response = client().execute(InferModelAction.INSTANCE, request).actionGet();

        ClassificationInferenceResults classificationInferenceResults =
            (ClassificationInferenceResults)response.getInferenceResults().get(0);

        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("to_be"));
        assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
            greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));

        classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("not_to_be"));
        // they should always be in order of most probable to least
        assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
            greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));

        // Test that top classes restrict the number returned
        request = new InferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1));
        response = client().execute(InferModelAction.INSTANCE, request).actionGet();

        classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
        assertThat(classificationInferenceResults.getTopClasses(), hasSize(1));
        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
    }

    public void testInferMissingModel() {
        String model = "test-infer-missing-model";
        InferModelAction.Request request = new InferModelAction.Request(model, Collections.emptyList(), new RegressionConfig());
        try {
            client().execute(InferModelAction.INSTANCE, request).actionGet();
        } catch (ElasticsearchException ex) {
            assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
        }
    }

    private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
        return TrainedModelConfig.builder()
            .setCreatedBy("ml_test")
            .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
            .setDescription("trained model config for test")
            .setModelId(modelId);
    }

    @Override
    public NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
        return new NamedXContentRegistry(namedXContent);
    }

}

@@ -144,6 +144,8 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
            .setDescription("trained model config for test")
            .setModelId(modelId)
            .setVersion(Version.CURRENT)
            .setEstimatedHeapMemory(0)
            .setEstimatedOperations(0)
            .setInput(TrainedModelInputTests.createRandomInput());
    }

@@ -0,0 +1,24 @@
{
  "ml.delete_trained_model":{
    "documentation":{
      "url":"TODO"
    },
    "stability":"experimental",
    "url":{
      "paths":[
        {
          "path":"/_ml/inference/{model_id}",
          "methods":[
            "DELETE"
          ],
          "parts":{
            "model_id":{
              "type":"string",
              "description":"The ID of the trained model to delete"
            }
          }
        }
      ]
    }
  }
}

@@ -0,0 +1,54 @@
{
  "ml.get_trained_models":{
    "documentation":{
      "url":"TODO"
    },
    "stability":"experimental",
    "url":{
      "paths":[
        {
          "path":"/_ml/inference/{model_id}",
          "methods":[
            "GET"
          ],
          "parts":{
            "model_id":{
              "type":"string",
              "description":"The ID of the trained models to fetch"
            }
          }
        },
        {
          "path":"/_ml/inference",
          "methods":[
            "GET"
          ]
        }
      ]
    },
    "params":{
      "allow_no_match":{
        "type":"boolean",
        "required":false,
        "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
        "default":true
      },
      "include_model_definition":{
        "type":"boolean",
        "required":false,
        "description":"Should the full model definition be included in the results. These definitions can be large",
        "default":false
      },
      "from":{
        "type":"int",
        "description":"skips a number of trained models",
        "default":0
      },
      "size":{
        "type":"int",
        "description":"specifies a max number of trained models to get",
        "default":100
      }
    }
  }
}

@@ -0,0 +1,48 @@
{
  "ml.get_trained_models_stats":{
    "documentation":{
      "url":"TODO"
    },
    "stability":"experimental",
    "url":{
      "paths":[
        {
          "path":"/_ml/inference/{model_id}/_stats",
          "methods":[
            "GET"
          ],
          "parts":{
            "model_id":{
              "type":"string",
              "description":"The ID of the trained models stats to fetch"
            }
          }
        },
        {
          "path":"/_ml/inference/_stats",
          "methods":[
            "GET"
          ]
        }
      ]
    },
    "params":{
      "allow_no_match":{
        "type":"boolean",
        "required":false,
        "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
        "default":true
      },
      "from":{
        "type":"int",
        "description":"skips a number of trained models",
        "default":0
      },
      "size":{
        "type":"int",
        "description":"specifies a max number of trained models to get",
        "default":100
      }
    }
  }
}

@@ -0,0 +1,113 @@
---
"Test get-all given no trained models exist":

  - do:
      ml.get_trained_models:
        model_id: "_all"
  - match: { count: 0 }
  - match: { trained_model_configs: [] }

  - do:
      ml.get_trained_models:
        model_id: "*"
  - match: { count: 0 }
  - match: { trained_model_configs: [] }

---
"Test get given missing trained model":

  - do:
      catch: missing
      ml.get_trained_models:
        model_id: "missing-trained-model"
---
"Test get given expression without matches and allow_no_match is false":

  - do:
      catch: missing
      ml.get_trained_models:
        model_id: "missing-trained-model*"
        allow_no_match: false

---
"Test get given expression without matches and allow_no_match is true":

  - do:
      ml.get_trained_models:
        model_id: "missing-trained-model*"
        allow_no_match: true
  - match: { count: 0 }
  - match: { trained_model_configs: [] }
---
"Test delete given unused trained model":

  - do:
      index:
        id: trained_model_config-unused-regression-model-0
        index: .ml-inference-000001
        body: >
          {
            "model_id": "unused-regression-model",
            "created_by": "ml_tests",
            "version": "8.0.0",
            "description": "empty model for tests",
            "create_time": 0,
            "model_version": 0,
            "model_type": "local"
          }
  - do:
      indices.refresh: {}

  - do:
      ml.delete_trained_model:
        model_id: "unused-regression-model"
  - match: { acknowledged: true }

---
"Test delete with missing model":
  - do:
      catch: missing
      ml.delete_trained_model:
        model_id: "missing-trained-model"

---
"Test delete given used trained model":
  - do:
      index:
        id: trained_model_config-used-regression-model-0
        index: .ml-inference-000001
        body: >
          {
            "model_id": "used-regression-model",
            "created_by": "ml_tests",
            "version": "8.0.0",
            "description": "empty model for tests",
            "create_time": 0,
            "model_version": 0,
            "model_type": "local"
          }
  - do:
      indices.refresh: {}

  - do:
      ingest.put_pipeline:
        id: "regression-model-pipeline"
        body: >
          {
            "processors": [
              {
                "inference" : {
                  "model_id" : "used-regression-model",
                  "inference_config": {"regression": {}},
                  "target_field": "regression_field",
                  "field_mappings": {}
                }
              }
            ]
          }
  - match: { acknowledged: true }

  - do:
      catch: conflict
      ml.delete_trained_model:
        model_id: "used-regression-model"

@ -0,0 +1,230 @@
|
|||
setup:
|
||||
- skip:
|
||||
features: headers
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
index:
|
||||
id: trained_model_config-unused-regression-model1-0
|
||||
index: .ml-inference-000001
|
||||
body: >
|
||||
{
|
||||
"model_id": "unused-regression-model1",
|
||||
"created_by": "ml_tests",
|
||||
"version": "8.0.0",
|
||||
"description": "empty model for tests",
|
||||
"create_time": 0,
|
||||
"model_version": 0,
|
||||
"model_type": "local",
|
||||
"doc_type": "trained_model_config"
|
||||
}
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
index:
|
||||
id: trained_model_config-unused-regression-model-0
|
||||
index: .ml-inference-000001
|
||||
body: >
|
||||
{
|
||||
"model_id": "unused-regression-model",
|
||||
"created_by": "ml_tests",
|
||||
"version": "8.0.0",
|
||||
"description": "empty model for tests",
|
||||
"create_time": 0,
|
||||
"model_version": 0,
|
||||
"model_type": "local",
|
||||
"doc_type": "trained_model_config"
|
||||
}
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
index:
|
||||
id: trained_model_config-used-regression-model-0
|
||||
index: .ml-inference-000001
|
||||
body: >
|
||||
{
|
||||
"model_id": "used-regression-model",
|
||||
"created_by": "ml_tests",
|
||||
"version": "8.0.0",
|
||||
"description": "empty model for tests",
|
||||
"create_time": 0,
|
||||
"model_version": 0,
|
||||
"model_type": "local",
|
||||
"doc_type": "trained_model_config"
|
||||
}
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
indices.refresh: {}
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
ingest.put_pipeline:
|
||||
id: "regression-model-pipeline"
|
||||
body: >
|
||||
{
|
||||
"processors": [
|
||||
{
|
||||
"inference" : {
|
||||
"model_id" : "used-regression-model",
|
||||
"inference_config": {"regression": {}},
|
||||
"target_field": "regression_field",
|
||||
"field_mappings": {}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
- do:
|
||||
headers:
|
||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||
ingest.put_pipeline:
|
||||
id: "regression-model-pipeline-1"
|
||||
body: >
|
||||
{
|
||||
"processors": [
|
||||
{
|
||||
"inference" : {
|
||||
"model_id" : "used-regression-model",
|
||||
"inference_config": {"regression": {}},
|
||||
"target_field": "regression_field",
|
||||
"field_mappings": {}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
---
"Test get stats given missing trained model":

  - do:
      catch: missing
      ml.get_trained_models_stats:
        model_id: "missing-trained-model"
---
"Test get stats given expression without matches and allow_no_match is false":

  - do:
      catch: missing
      ml.get_trained_models_stats:
        model_id: "missing-trained-model*"
        allow_no_match: false

---
"Test get stats given expression without matches and allow_no_match is true":

  - do:
      ml.get_trained_models_stats:
        model_id: "missing-trained-model*"
        allow_no_match: true
  - match: { count: 0 }
  - match: { trained_model_stats: [] }
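  # Note: with allow_no_match set to false, a wildcard expression that matches no
  # models is treated as a 404 (the `catch: missing` case above); with it set to
  # true the call succeeds and returns an empty stats list.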
---
"Test get stats given trained models":

  - do:
      ml.get_trained_models_stats:
        model_id: "unused-regression-model"

  - match: { count: 1 }

  - do:
      ml.get_trained_models_stats:
        model_id: "_all"
  - match: { count: 3 }
  - match: { trained_model_stats.0.model_id: unused-regression-model }
  - match: { trained_model_stats.0.pipeline_count: 0 }
  - is_false: trained_model_stats.0.ingest
  - match: { trained_model_stats.1.model_id: unused-regression-model1 }
  - match: { trained_model_stats.1.pipeline_count: 0 }
  - is_false: trained_model_stats.1.ingest
  - match: { trained_model_stats.2.pipeline_count: 2 }
  - is_true: trained_model_stats.2.ingest

  - do:
      ml.get_trained_models_stats:
        model_id: "*"
  - match: { count: 3 }
  - match: { trained_model_stats.0.model_id: unused-regression-model }
  - match: { trained_model_stats.0.pipeline_count: 0 }
  - is_false: trained_model_stats.0.ingest
  - match: { trained_model_stats.1.model_id: unused-regression-model1 }
  - match: { trained_model_stats.1.pipeline_count: 0 }
  - is_false: trained_model_stats.1.ingest
  - match: { trained_model_stats.2.pipeline_count: 2 }
  - is_true: trained_model_stats.2.ingest

  - do:
      ml.get_trained_models_stats:
        model_id: "unused-regression-model*"
  - match: { count: 2 }
  - match: { trained_model_stats.0.model_id: unused-regression-model }
  - match: { trained_model_stats.0.pipeline_count: 0 }
  - is_false: trained_model_stats.0.ingest
  - match: { trained_model_stats.1.model_id: unused-regression-model1 }
  - match: { trained_model_stats.1.pipeline_count: 0 }
  - is_false: trained_model_stats.1.ingest

  - do:
      ml.get_trained_models_stats:
        model_id: "unused-regression-model*"
        size: 1
  - match: { count: 2 }
  - match: { trained_model_stats.0.model_id: unused-regression-model }
  - match: { trained_model_stats.0.pipeline_count: 0 }
  - is_false: trained_model_stats.0.ingest

  - do:
      ml.get_trained_models_stats:
        model_id: "unused-regression-model*"
        from: 1
        size: 1
  - match: { count: 2 }
  - match: { trained_model_stats.0.model_id: unused-regression-model1 }
  - match: { trained_model_stats.0.pipeline_count: 0 }
  - is_false: trained_model_stats.0.ingest

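  # The two requests above page through the wildcard matches with from/size:
  # count stays at the total number of matching models (2), while
  # trained_model_stats holds only the requested page.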
  - do:
      ml.get_trained_models_stats:
        model_id: "used-regression-model"

  - match: { count: 1 }
  - match: { trained_model_stats.0.model_id: used-regression-model }
  - match: { trained_model_stats.0.pipeline_count: 2 }
  - match:
      trained_model_stats.0.ingest.total:
        count: 0
        time_in_millis: 0
        current: 0
        failed: 0

  - match:
      trained_model_stats.0.ingest.pipelines.regression-model-pipeline:
        count: 0
        time_in_millis: 0
        current: 0
        failed: 0
        processors:
          - inference:
              type: inference
              stats:
                count: 0
                time_in_millis: 0
                current: 0
                failed: 0

  - match:
      trained_model_stats.0.ingest.pipelines.regression-model-pipeline-1:
        count: 0
        time_in_millis: 0
        current: 0
        failed: 0
        processors:
          - inference:
              type: inference
              stats:
                count: 0
                time_in_millis: 0
                current: 0
                failed: 0