From eefe7688ced859f891de9a1ce3eae624fcf22dd0 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 18 Nov 2019 13:19:17 -0500 Subject: [PATCH] [7.x][ML] ML Model Inference Ingest Processor (#49052) (#49257) * [ML] ML Model Inference Ingest Processor (#49052) * [ML][Inference] adds lazy model loader and inference (#47410) This adds a couple of things: - A model loader service that is accessible via transport calls. This service will load in models and cache them. They will stay loaded until a processor no longer references them - A Model class and its first sub-class LocalModel. Used to cache model information and run inference. - Transport action and handler for requests to infer against a local model Related Feature PRs: * [ML][Inference] Adjust inference configuration option API (#47812) * [ML][Inference] adds logistic_regression output aggregator (#48075) * [ML][Inference] Adding read/del trained models (#47882) * [ML][Inference] Adding inference ingest processor (#47859) * [ML][Inference] fixing classification inference for ensemble (#48463) * [ML][Inference] Adding model memory estimations (#48323) * [ML][Inference] adding more options to inference processor (#48545) * [ML][Inference] handle string values better in feature extraction (#48584) * [ML][Inference] Adding _stats endpoint for inference (#48492) * [ML][Inference] add inference processors and trained models to usage (#47869) * [ML][Inference] add new flag for optionally including model definition (#48718) * [ML][Inference] adding license checks (#49056) * [ML][Inference] Adding memory and compute estimates to inference (#48955) * fixing version of indexed docs for model inference --- .../ml/inference/TrainedModelConfig.java | 51 +- .../ml/inference/TrainedModelConfigTests.java | 5 +- .../org/elasticsearch/ingest/IngestStats.java | 62 ++ .../xpack/core/XPackClientPlugin.java | 31 + .../ml/MachineLearningFeatureSetUsage.java | 17 +- .../ml/action/DeleteTrainedModelAction.java | 81 +++ 
.../ml/action/GetTrainedModelsAction.java | 128 +++++ .../action/GetTrainedModelsStatsAction.java | 194 +++++++ .../core/ml/action/InferModelAction.java | 144 +++++ .../MlInferenceNamedXContentProvider.java | 18 + .../core/ml/inference/TrainedModelConfig.java | 70 ++- .../ml/inference/TrainedModelDefinition.java | 45 +- .../preprocessing/FrequencyEncoding.java | 17 + .../preprocessing/OneHotEncoding.java | 15 + .../inference/preprocessing/PreProcessor.java | 3 +- .../preprocessing/TargetMeanEncoding.java | 16 + .../ClassificationInferenceResults.java | 175 ++++++ .../inference/results/InferenceResults.java | 16 + .../results/RawInferenceResults.java | 65 +++ .../results/RegressionInferenceResults.java | 68 +++ .../results/SingleValueInferenceResults.java | 51 ++ .../trainedmodel/ClassificationConfig.java | 100 ++++ .../trainedmodel/InferenceConfig.java | 21 + .../trainedmodel/InferenceHelpers.java | 89 +++ .../trainedmodel/NullInferenceConfig.java | 51 ++ .../trainedmodel/RegressionConfig.java | 80 +++ .../inference/trainedmodel/TrainedModel.java | 39 +- .../trainedmodel/ensemble/Ensemble.java | 109 +++- .../ensemble/LogisticRegression.java | 38 +- .../ensemble/OutputAggregator.java | 6 +- .../trainedmodel/ensemble/WeightedMode.java | 41 +- .../trainedmodel/ensemble/WeightedSum.java | 42 +- .../ml/inference/trainedmodel/tree/Tree.java | 116 ++-- .../inference/trainedmodel/tree/TreeNode.java | 97 ++-- .../core/ml/inference/utils/Statistics.java | 18 +- .../xpack/core/ml/job/messages/Messages.java | 6 + .../notifications/InferenceAuditMessage.java | 37 ++ .../xpack/core/ml/utils/ExceptionsHelper.java | 8 + .../DeleteTrainedModelsRequestTests.java | 23 + .../action/GetTrainedModelsRequestTests.java | 26 + ...TrainedModelsStatsActionResponseTests.java | 60 ++ .../action/InferModelActionRequestTests.java | 61 ++ .../action/InferModelActionResponseTests.java | 58 ++ .../ml/inference/TrainedModelConfigTests.java | 19 +- .../TrainedModelDefinitionTests.java | 57 +- 
.../ClassificationInferenceResultsTests.java | 83 +++ .../results/RawInferenceResultsTests.java | 26 + .../RegressionInferenceResultsTests.java | 41 ++ .../ClassificationConfigTests.java | 47 ++ .../trainedmodel/InferenceHelpersTests.java | 55 ++ .../trainedmodel/RegressionConfigTests.java | 43 ++ .../trainedmodel/ensemble/EnsembleTests.java | 112 +++- .../ensemble/LogisticRegressionTests.java | 14 +- .../ensemble/WeightedModeTests.java | 15 +- .../ensemble/WeightedSumTests.java | 15 +- .../trainedmodel/tree/TreeTests.java | 74 ++- .../AnomalyDetectionAuditMessageTests.java | 16 +- .../ml/notifications/AuditMessageTests.java | 27 + .../DataFrameAnalyticsAuditMessageTests.java | 18 +- .../InferenceAuditMessageTests.java | 35 ++ .../ml/qa/ml-with-security/build.gradle | 7 + .../ml/integration/InferenceIngestIT.java | 544 ++++++++++++++++++ .../xpack/ml/integration/TrainedModelIT.java | 221 +++++++ .../xpack/ml/MachineLearning.java | 82 ++- .../xpack/ml/MachineLearningFeatureSet.java | 227 ++++++-- .../TransportDeleteTrainedModelAction.java | 132 +++++ .../TransportGetTrainedModelsAction.java | 76 +++ .../TransportGetTrainedModelsStatsAction.java | 249 ++++++++ .../ml/action/TransportInferModelAction.java | 75 +++ .../process/AnalyticsResultProcessor.java | 2 + .../process/results/AnalyticsResult.java | 7 +- .../inference/ingest/InferenceProcessor.java | 297 ++++++++++ .../inference/loadingservice/LocalModel.java | 60 ++ .../ml/inference/loadingservice/Model.java | 21 + .../loadingservice/ModelLoadingService.java | 370 ++++++++++++ .../persistence/InferenceInternalIndex.java | 7 + .../persistence/TrainedModelProvider.java | 193 ++++++- .../ml/notifications/InferenceAuditor.java | 20 + .../RestDeleteTrainedModelAction.java | 39 ++ .../inference/RestGetTrainedModelsAction.java | 56 ++ .../RestGetTrainedModelsStatsAction.java | 52 ++ .../MachineLearningLicensingTests.java | 226 ++++++++ .../ml/MachineLearningFeatureSetTests.java | 113 +++- 
...sportGetTrainedModelsStatsActionTests.java | 289 ++++++++++ .../AnalyticsResultProcessorTests.java | 2 + .../process/results/AnalyticsResultTests.java | 2 +- .../InferenceProcessorFactoryTests.java | 281 +++++++++ .../ingest/InferenceProcessorTests.java | 230 ++++++++ .../loadingservice/LocalModelTests.java | 212 +++++++ .../ModelLoadingServiceTests.java | 363 ++++++++++++ .../integration/ModelInferenceActionIT.java | 196 +++++++ .../integration/TrainedModelProviderIT.java | 2 + .../api/ml.delete_trained_model.json | 24 + .../api/ml.get_trained_models.json | 54 ++ .../api/ml.get_trained_models_stats.json | 48 ++ .../rest-api-spec/test/ml/inference_crud.yml | 113 ++++ .../test/ml/inference_stats_crud.yml | 230 ++++++++ 97 files changed, 7855 insertions(+), 362 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java create mode 100644 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java create mode 100644 
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpersTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java create mode 100644 
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_stats.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index 273aa6b0213..384bfe53e4b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -22,6 +22,7 @@ import org.elasticsearch.Version; 
import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -47,6 +48,8 @@ public class TrainedModelConfig implements ToXContentObject { public static final ParseField TAGS = new ParseField("tags"); public static final ParseField METADATA = new ParseField("metadata"); public static final ParseField INPUT = new ParseField("input"); + public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes"); + public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations"); public static final ObjectParser PARSER = new ObjectParser<>(NAME, true, @@ -66,6 +69,8 @@ public class TrainedModelConfig implements ToXContentObject { PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT); + PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES); + PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS); } public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException { @@ -81,6 +86,8 @@ public class TrainedModelConfig implements ToXContentObject { private final List tags; private final Map metadata; private final TrainedModelInput input; + private final Long estimatedHeapMemory; + private final Long estimatedOperations; TrainedModelConfig(String modelId, String createdBy, @@ -90,7 +97,9 @@ public class TrainedModelConfig implements ToXContentObject { TrainedModelDefinition 
definition, List tags, Map metadata, - TrainedModelInput input) { + TrainedModelInput input, + Long estimatedHeapMemory, + Long estimatedOperations) { this.modelId = modelId; this.createdBy = createdBy; this.version = version; @@ -100,6 +109,8 @@ public class TrainedModelConfig implements ToXContentObject { this.tags = tags == null ? null : Collections.unmodifiableList(tags); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); this.input = input; + this.estimatedHeapMemory = estimatedHeapMemory; + this.estimatedOperations = estimatedOperations; } public String getModelId() { @@ -138,6 +149,18 @@ public class TrainedModelConfig implements ToXContentObject { return input; } + public ByteSizeValue getEstimatedHeapMemory() { + return estimatedHeapMemory == null ? null : new ByteSizeValue(estimatedHeapMemory); + } + + public Long getEstimatedHeapMemoryBytes() { + return estimatedHeapMemory; + } + + public Long getEstimatedOperations() { + return estimatedOperations; + } + public static Builder builder() { return new Builder(); } @@ -172,6 +195,12 @@ public class TrainedModelConfig implements ToXContentObject { if (input != null) { builder.field(INPUT.getPreferredName(), input); } + if (estimatedHeapMemory != null) { + builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory); + } + if (estimatedOperations != null) { + builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); + } builder.endObject(); return builder; } @@ -194,6 +223,8 @@ public class TrainedModelConfig implements ToXContentObject { Objects.equals(definition, that.definition) && Objects.equals(tags, that.tags) && Objects.equals(input, that.input) && + Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) && + Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(metadata, that.metadata); } @@ -206,6 +237,8 @@ public class TrainedModelConfig implements ToXContentObject { definition, description, 
tags, + estimatedHeapMemory, + estimatedOperations, metadata, input); } @@ -222,6 +255,8 @@ public class TrainedModelConfig implements ToXContentObject { private List tags; private TrainedModelDefinition definition; private TrainedModelInput input; + private Long estimatedHeapMemory; + private Long estimatedOperations; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -277,6 +312,16 @@ public class TrainedModelConfig implements ToXContentObject { return this; } + public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) { + this.estimatedHeapMemory = estimatedHeapMemory; + return this; + } + + public Builder setEstimatedOperations(Long estimatedOperations) { + this.estimatedOperations = estimatedOperations; + return this; + } + public TrainedModelConfig build() { return new TrainedModelConfig( modelId, @@ -287,7 +332,9 @@ public class TrainedModelConfig implements ToXContentObject { definition, tags, metadata, - input); + input, + estimatedHeapMemory, + estimatedOperations); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index 0dc672202d8..7e85a08ee17 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -64,7 +64,10 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - randomBoolean() ? null : TrainedModelInputTests.createRandomInput()); + randomBoolean() ? null : TrainedModelInputTests.createRandomInput(), + randomBoolean() ? null : randomNonNegativeLong(), + randomBoolean() ? 
null : randomNonNegativeLong()); + } @Override diff --git a/server/src/main/java/org/elasticsearch/ingest/IngestStats.java b/server/src/main/java/org/elasticsearch/ingest/IngestStats.java index 5471f60d8bc..17f53e9904c 100644 --- a/server/src/main/java/org/elasticsearch/ingest/IngestStats.java +++ b/server/src/main/java/org/elasticsearch/ingest/IngestStats.java @@ -33,6 +33,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.TimeUnit; public class IngestStats implements Writeable, ToXContentFragment { @@ -150,6 +151,21 @@ public class IngestStats implements Writeable, ToXContentFragment { return processorStats; } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IngestStats that = (IngestStats) o; + return Objects.equals(totalStats, that.totalStats) + && Objects.equals(pipelineStats, that.pipelineStats) + && Objects.equals(processorStats, that.processorStats); + } + + @Override + public int hashCode() { + return Objects.hash(totalStats, pipelineStats, processorStats); + } + public static class Stats implements Writeable, ToXContentFragment { private final long ingestCount; @@ -218,6 +234,22 @@ public class IngestStats implements Writeable, ToXContentFragment { builder.field("failed", ingestFailedCount); return builder; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IngestStats.Stats that = (IngestStats.Stats) o; + return Objects.equals(ingestCount, that.ingestCount) + && Objects.equals(ingestTimeInMillis, that.ingestTimeInMillis) + && Objects.equals(ingestFailedCount, that.ingestFailedCount) + && Objects.equals(ingestCurrent, that.ingestCurrent); + } + + @Override + public int hashCode() { + return Objects.hash(ingestCount, ingestTimeInMillis, ingestFailedCount, 
ingestCurrent); + } } /** @@ -270,6 +302,20 @@ public class IngestStats implements Writeable, ToXContentFragment { public Stats getStats() { return stats; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IngestStats.PipelineStat that = (IngestStats.PipelineStat) o; + return Objects.equals(pipelineId, that.pipelineId) + && Objects.equals(stats, that.stats); + } + + @Override + public int hashCode() { + return Objects.hash(pipelineId, stats); + } } /** @@ -297,5 +343,21 @@ public class IngestStats implements Writeable, ToXContentFragment { public Stats getStats() { return stats; } + + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IngestStats.ProcessorStat that = (IngestStats.ProcessorStat) o; + return Objects.equals(name, that.name) + && Objects.equals(type, that.type) + && Objects.equals(stats, that.stats); + } + + @Override + public int hashCode() { + return Objects.hash(name, type, stats); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index e48ead592db..5f34bc44631 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -88,6 +88,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction; import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction; import 
org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; @@ -109,6 +110,9 @@ import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction; import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; import org.elasticsearch.xpack.core.ml.action.MlInfoAction; @@ -153,6 +157,19 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.P import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression; +import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; @@ -371,6 +388,10 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl StopDataFrameAnalyticsAction.INSTANCE, EvaluateDataFrameAction.INSTANCE, EstimateMemoryUsageAction.INSTANCE, + InferModelAction.INSTANCE, + GetTrainedModelsAction.INSTANCE, + DeleteTrainedModelAction.INSTANCE, + GetTrainedModelsStatsAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, @@ -519,6 +540,16 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl new NamedWriteableRegistry.Entry(OutputAggregator.class, LogisticRegression.NAME.getPreferredName(), LogisticRegression::new), + // ML - Inference Results + new NamedWriteableRegistry.Entry(InferenceResults.class, + ClassificationInferenceResults.NAME, + ClassificationInferenceResults::new), + new NamedWriteableRegistry.Entry(InferenceResults.class, + RegressionInferenceResults.NAME, + RegressionInferenceResults::new), + // ML - Inference Configuration + new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new), + new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new), // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java index ca898f9cbf6..883828d1d3a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java @@ -29,10 +29,12 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage { public static final String CREATED_BY = "created_by"; public static final String NODE_COUNT = "node_count"; public static final String DATA_FRAME_ANALYTICS_JOBS_FIELD = "data_frame_analytics_jobs"; + public static final String INFERENCE_FIELD = "inference"; private final Map jobsUsage; private final Map datafeedsUsage; private final Map analyticsUsage; + private final Map inferenceUsage; private final int nodeCount; public MachineLearningFeatureSetUsage(boolean available, @@ -40,11 +42,13 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage { Map jobsUsage, Map datafeedsUsage, Map analyticsUsage, + Map inferenceUsage, int nodeCount) { super(XPackField.MACHINE_LEARNING, available, enabled); this.jobsUsage = Objects.requireNonNull(jobsUsage); this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage); this.analyticsUsage = Objects.requireNonNull(analyticsUsage); + this.inferenceUsage = Objects.requireNonNull(inferenceUsage); this.nodeCount = nodeCount; } @@ -57,12 +61,17 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage { } else { this.analyticsUsage = Collections.emptyMap(); } + if (in.getVersion().onOrAfter(Version.V_7_6_0)) { + this.inferenceUsage = in.readMap(); + } else { + this.inferenceUsage = Collections.emptyMap(); + } if (in.getVersion().onOrAfter(Version.V_6_5_0)) { this.nodeCount = in.readInt(); } else { this.nodeCount = -1; } - } + } @Override public void 
writeTo(StreamOutput out) throws IOException { @@ -72,10 +81,13 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage { if (out.getVersion().onOrAfter(Version.V_7_4_0)) { out.writeMap(analyticsUsage); } + if (out.getVersion().onOrAfter(Version.V_7_6_0)) { + out.writeMap(inferenceUsage); + } if (out.getVersion().onOrAfter(Version.V_6_5_0)) { out.writeInt(nodeCount); } - } + } @Override protected void innerXContent(XContentBuilder builder, Params params) throws IOException { @@ -83,6 +95,7 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage { builder.field(JOBS_FIELD, jobsUsage); builder.field(DATAFEEDS_FIELD, datafeedsUsage); builder.field(DATA_FRAME_ANALYTICS_JOBS_FIELD, analyticsUsage); + builder.field(INFERENCE_FIELD, inferenceUsage); if (nodeCount >= 0) { builder.field(NODE_COUNT, nodeCount); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java new file mode 100644 index 00000000000..521070a959d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class DeleteTrainedModelAction extends ActionType { + + public static final DeleteTrainedModelAction INSTANCE = new DeleteTrainedModelAction(); + public static final String NAME = "cluster:admin/xpack/ml/inference/delete"; + + private DeleteTrainedModelAction() { + super(NAME, AcknowledgedResponse::new); + } + + public static class Request extends AcknowledgedRequest implements ToXContentFragment { + + private String id; + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + } + + public Request() {} + + public Request(String id) { + this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID); + } + + public String getId() { + return id; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), id); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DeleteTrainedModelAction.Request request = (DeleteTrainedModelAction.Request) o; + return Objects.equals(id, 
request.id); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java new file mode 100644 index 00000000000..b86cfced552 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + + +public class GetTrainedModelsAction extends ActionType { + + public static final GetTrainedModelsAction INSTANCE = new GetTrainedModelsAction(); + public static final String NAME = "cluster:monitor/xpack/ml/inference/get"; + + private GetTrainedModelsAction() { + super(NAME, Response::new); + } + + public static class Request extends AbstractGetResourcesRequest { + + public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition"); + 
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private final boolean includeModelDefinition; + + public Request(String id, boolean includeModelDefinition) { + setResourceId(id); + setAllowNoResources(true); + this.includeModelDefinition = includeModelDefinition; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.includeModelDefinition = in.readBoolean(); + } + + @Override + public String getResourceIdField() { + return TrainedModelConfig.MODEL_ID.getPreferredName(); + } + + public boolean isIncludeModelDefinition() { + return includeModelDefinition; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(includeModelDefinition); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), includeModelDefinition); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + Request other = (Request) obj; + return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition; + } + } + + public static class Response extends AbstractGetResourcesResponse { + + public static final ParseField RESULTS_FIELD = new ParseField("trained_model_configs"); + + public Response(StreamInput in) throws IOException { + super(in); + } + + public Response(QueryPage trainedModels) { + super(trainedModels); + } + + @Override + protected Reader getReader() { + return TrainedModelConfig::new; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private long totalCount; + private List configs = Collections.emptyList(); + + private Builder() { + } + + public Builder setTotalCount(long totalCount) { + this.totalCount = totalCount; + return this; + } + + public Builder setModels(List configs) { + this.configs = configs; + return this; + } + + public Response 
build() { + return new Response(new QueryPage<>(configs, totalCount, RESULTS_FIELD)); + } + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java new file mode 100644 index 00000000000..f3cb43e8ef7 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java @@ -0,0 +1,194 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.ingest.IngestStats; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class GetTrainedModelsStatsAction extends ActionType { + + public static final GetTrainedModelsStatsAction INSTANCE = new 
GetTrainedModelsStatsAction(); + public static final String NAME = "cluster:monitor/xpack/ml/inference/stats/get"; + + public static final ParseField MODEL_ID = new ParseField("model_id"); + public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count"); + + private GetTrainedModelsStatsAction() { + super(NAME, GetTrainedModelsStatsAction.Response::new); + } + + public static class Request extends AbstractGetResourcesRequest { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + public Request() { + setAllowNoResources(true); + } + + public Request(String id) { + setResourceId(id); + setAllowNoResources(true); + } + + public Request(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getResourceIdField() { + return TrainedModelConfig.MODEL_ID.getPreferredName(); + } + + } + + public static class RequestBuilder extends ActionRequestBuilder { + + public RequestBuilder(ElasticsearchClient client, GetTrainedModelsStatsAction action) { + super(client, action, new Request()); + } + } + + public static class Response extends AbstractGetResourcesResponse { + + public static class TrainedModelStats implements ToXContentObject, Writeable { + private final String modelId; + private final IngestStats ingestStats; + private final int pipelineCount; + + private static final IngestStats EMPTY_INGEST_STATS = new IngestStats(new IngestStats.Stats(0, 0, 0, 0), + Collections.emptyList(), + Collections.emptyMap()); + + public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount) { + this.modelId = Objects.requireNonNull(modelId); + this.ingestStats = ingestStats == null ? 
EMPTY_INGEST_STATS : ingestStats; + if (pipelineCount < 0) { + throw new ElasticsearchException("[{}] must be a greater than or equal to 0", PIPELINE_COUNT.getPreferredName()); + } + this.pipelineCount = pipelineCount; + } + + public TrainedModelStats(StreamInput in) throws IOException { + modelId = in.readString(); + ingestStats = new IngestStats(in); + pipelineCount = in.readVInt(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(PIPELINE_COUNT.getPreferredName(), pipelineCount); + if (pipelineCount > 0) { + // Ingest stats is a fragment + ingestStats.toXContent(builder, params); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + ingestStats.writeTo(out); + out.writeVInt(pipelineCount); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, ingestStats, pipelineCount); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + TrainedModelStats other = (TrainedModelStats) obj; + return Objects.equals(this.modelId, other.modelId) + && Objects.equals(this.ingestStats, other.ingestStats) + && Objects.equals(this.pipelineCount, other.pipelineCount); + } + } + + public static final ParseField RESULTS_FIELD = new ParseField("trained_model_stats"); + + public Response(StreamInput in) throws IOException { + super(in); + } + + public Response(QueryPage trainedModels) { + super(trainedModels); + } + + @Override + protected Reader getReader() { + return Response.TrainedModelStats::new; + } + + public static class Builder { + + private long totalModelCount; + private Set expandedIds; + private Map ingestStatsMap; + + public Builder setTotalModelCount(long totalModelCount) { + this.totalModelCount = 
totalModelCount; + return this; + } + + public Builder setExpandedIds(Set expandedIds) { + this.expandedIds = expandedIds; + return this; + } + + public Set getExpandedIds() { + return this.expandedIds; + } + + public Builder setIngestStatsByModelId(Map ingestStatsByModelId) { + this.ingestStatsMap = ingestStatsByModelId; + return this; + } + + public Response build() { + List trainedModelStats = new ArrayList<>(expandedIds.size()); + expandedIds.forEach(id -> { + IngestStats ingestStats = ingestStatsMap.get(id); + trainedModelStats.add(new TrainedModelStats(id, ingestStats, ingestStats == null ? + 0 : + ingestStats.getPipelineStats().size())); + }); + return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD)); + } + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java new file mode 100644 index 00000000000..29cab602dab --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -0,0 +1,144 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class InferModelAction extends ActionType { + + public static final InferModelAction INSTANCE = new InferModelAction(); + public static final String NAME = "cluster:admin/xpack/ml/inference/infer"; + + private InferModelAction() { + super(NAME, Response::new); + } + + public static class Request extends ActionRequest { + + private final String modelId; + private final List> objectsToInfer; + private final InferenceConfig config; + + public Request(String modelId) { + this(modelId, Collections.emptyList(), new RegressionConfig()); + } + + public Request(String modelId, List> objectsToInfer, InferenceConfig inferenceConfig) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); + this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer")); + this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config"); + } + + public Request(String modelId, Map objectToInfer, InferenceConfig config) { + this(modelId, + 
Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")), + config); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap)); + this.config = in.readNamedWriteable(InferenceConfig.class); + } + + public String getModelId() { + return modelId; + } + + public List> getObjectsToInfer() { + return objectsToInfer; + } + + public InferenceConfig getConfig() { + return config; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeCollection(objectsToInfer, StreamOutput::writeMap); + out.writeNamedWriteable(config); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferModelAction.Request that = (InferModelAction.Request) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(config, that.config) + && Objects.equals(objectsToInfer, that.objectsToInfer); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, objectsToInfer, config); + } + + } + + public static class Response extends ActionResponse { + + private final List inferenceResults; + + public Response(List inferenceResults) { + super(); + this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults")); + } + + public Response(StreamInput in) throws IOException { + super(in); + this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class)); + } + + public List getInferenceResults() { + return inferenceResults; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteableList(inferenceResults); + } + + @Override + 
public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferModelAction.Response that = (InferModelAction.Response) o; + return Objects.equals(inferenceResults, that.inferenceResults); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceResults); + } + + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 78d2981eb05..ca380dac2bf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -8,7 +8,13 @@ package org.elasticsearch.xpack.core.ml.inference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -110,6 +116,18 @@ public class 
MlInferenceNamedXContentProvider implements NamedXContentProvider { LogisticRegression.NAME.getPreferredName(), LogisticRegression::new)); + // Inference Results + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, + ClassificationInferenceResults.NAME, + ClassificationInferenceResults::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, + RegressionInferenceResults.NAME, + RegressionInferenceResults::new)); + + // Inference Configs + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new)); + return namedWriteables; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 04eece32b5c..5361760e5ca 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -34,6 +35,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final String NAME = "trained_model_config"; + private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; + public static final ParseField MODEL_ID = new 
ParseField("model_id"); public static final ParseField CREATED_BY = new ParseField("created_by"); public static final ParseField VERSION = new ParseField("version"); @@ -43,6 +46,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final ParseField TAGS = new ParseField("tags"); public static final ParseField METADATA = new ParseField("metadata"); public static final ParseField INPUT = new ParseField("input"); + public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes"); + public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations"); // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -66,6 +71,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { parser.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields), INPUT); + parser.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES); + parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS); return parser; } @@ -81,6 +88,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { private final List tags; private final Map metadata; private final TrainedModelInput input; + private final long estimatedHeapMemory; + private final long estimatedOperations; private final TrainedModelDefinition definition; @@ -92,7 +101,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { TrainedModelDefinition definition, List tags, Map metadata, - TrainedModelInput input) { + TrainedModelInput input, + Long estimatedHeapMemory, + Long estimatedOperations) { this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.createdBy = 
ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.version = ExceptionsHelper.requireNonNull(version, VERSION); @@ -102,6 +113,15 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS)); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); this.input = ExceptionsHelper.requireNonNull(input, INPUT); + if (ExceptionsHelper.requireNonNull(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES) < 0) { + throw new IllegalArgumentException( + "[" + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName() + "] must be greater than or equal to 0"); + } + this.estimatedHeapMemory = estimatedHeapMemory; + if (ExceptionsHelper.requireNonNull(estimatedOperations, ESTIMATED_OPERATIONS) < 0) { + throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0"); + } + this.estimatedOperations = estimatedOperations; } public TrainedModelConfig(StreamInput in) throws IOException { @@ -114,6 +134,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { tags = Collections.unmodifiableList(in.readList(StreamInput::readString)); metadata = in.readMap(); input = new TrainedModelInput(in); + estimatedHeapMemory = in.readVLong(); + estimatedOperations = in.readVLong(); } public String getModelId() { @@ -157,6 +179,14 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return new Builder(); } + public long getEstimatedHeapMemory() { + return estimatedHeapMemory; + } + + public long getEstimatedOperations() { + return estimatedOperations; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); @@ -168,6 +198,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { out.writeCollection(tags, StreamOutput::writeString); out.writeMap(metadata); input.writeTo(out); + 
out.writeVLong(estimatedHeapMemory); + out.writeVLong(estimatedOperations); } @Override @@ -180,7 +212,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { builder.field(DESCRIPTION.getPreferredName(), description); } builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); - // We don't store the definition in the same document as the configuration if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) { builder.field(DEFINITION.getPreferredName(), definition); @@ -193,6 +224,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); } builder.field(INPUT.getPreferredName(), input); + builder.humanReadableField( + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), + ESTIMATED_HEAP_MEMORY_USAGE_HUMAN, + new ByteSizeValue(estimatedHeapMemory)); + builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); builder.endObject(); return builder; } @@ -215,6 +251,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { Objects.equals(definition, that.definition) && Objects.equals(tags, that.tags) && Objects.equals(input, that.input) && + Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) && + Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(metadata, that.metadata); } @@ -228,6 +266,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { description, tags, metadata, + estimatedHeapMemory, + estimatedOperations, input); } @@ -242,6 +282,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { private Map metadata; private TrainedModelInput input; private TrainedModelDefinition definition; + private Long estimatedHeapMemory; + private Long estimatedOperations; public Builder setModelId(String modelId) { this.modelId = 
modelId; @@ -297,6 +339,16 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return this; } + public Builder setEstimatedHeapMemory(long estimatedHeapMemory) { + this.estimatedHeapMemory = estimatedHeapMemory; + return this; + } + + public Builder setEstimatedOperations(long estimatedOperations) { + this.estimatedOperations = estimatedOperations; + return this; + } + // TODO move to REST level instead of here in the builder public void validate() { // We require a definition to be available here even though it will be stored in a different doc @@ -327,6 +379,16 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", CREATE_TIME.getPreferredName()); } + + if (estimatedHeapMemory != null) { + throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName()); + } + + if (estimatedOperations != null) { + throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", + ESTIMATED_OPERATIONS.getPreferredName()); + } } public TrainedModelConfig build() { @@ -339,7 +401,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { definition, tags, metadata, - input); + input, + estimatedHeapMemory, + estimatedOperations); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index 63a5b1fd1d6..23981d07688 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -5,12 +5,16 @@ */ package org.elasticsearch.xpack.core.ml.inference; +import org.apache.lucene.util.Accountable; 
+import org.apache.lucene.util.Accountables; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -19,6 +23,8 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; @@ -27,13 +33,18 @@ import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Objects; -public class TrainedModelDefinition implements ToXContentObject, Writeable { +public class TrainedModelDefinition implements ToXContentObject, Writeable, Accountable { + private static final long SHALLOW_SIZE = 
RamUsageEstimator.shallowSizeOfInstance(TrainedModelDefinition.class); public static final String NAME = "trained_model_definition"; + public static final String HEAP_MEMORY_ESTIMATION = "heap_memory_estimation"; public static final ParseField TRAINED_MODEL = new ParseField("trained_model"); public static final ParseField PREPROCESSORS = new ParseField("preprocessors"); @@ -106,6 +117,11 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { true, PREPROCESSORS.getPreferredName(), preProcessors); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) { + builder.humanReadableField(HEAP_MEMORY_ESTIMATION + "_bytes", + HEAP_MEMORY_ESTIMATION, + new ByteSizeValue(ramBytesUsed())); + } if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); assert modelId != null; @@ -123,6 +139,15 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { return preProcessors; } + private void preProcess(Map fields) { + preProcessors.forEach(preProcessor -> preProcessor.process(fields)); + } + + public InferenceResults infer(Map fields, InferenceConfig config) { + preProcess(fields); + return trainedModel.infer(fields, config); + } + @Override public String toString() { return Strings.toString(this); @@ -143,6 +168,24 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { return Objects.hash(trainedModel, preProcessors, modelId); } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(trainedModel); + size += RamUsageEstimator.sizeOfCollection(preProcessors); + return size; + } + + @Override + public Collection getChildResources() { + List accountables = new ArrayList<>(preProcessors.size() + 2); + accountables.add(Accountables.namedAccountable("trained_model", trainedModel)); + for(PreProcessor preProcessor : preProcessors) { + 
accountables.add(Accountables.namedAccountable("pre_processor_" + preProcessor.getName(), preProcessor)); + } + return accountables; + } + public static class Builder { private List preProcessors; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java index 351c0f05960..cea99d3edc8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java @@ -5,7 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.preprocessing; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -25,6 +27,8 @@ import java.util.Objects; */ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class); + public static final ParseField NAME = new ParseField("frequency_encoding"); public static final ParseField FIELD = new ParseField("field"); public static final ParseField FEATURE_NAME = new ParseField("feature_name"); @@ -143,4 +147,17 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP return Objects.hash(field, featureName, frequencyMap); } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(field); + size += RamUsageEstimator.sizeOf(featureName); + size += RamUsageEstimator.sizeOfMap(frequencyMap); + return size; + } + + @Override + 
public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index 106cb1e26c1..9784ed8cbe7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -5,7 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.preprocessing; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -23,6 +25,7 @@ import java.util.Objects; */ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class); public static final ParseField NAME = new ParseField("one_hot_encoding"); public static final ParseField FIELD = new ParseField("field"); public static final ParseField HOT_MAP = new ParseField("hot_map"); @@ -127,4 +130,16 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars return Objects.hash(field, hotMap); } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(field); + size += RamUsageEstimator.sizeOfMap(hotMap); + return size; + } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java index 79e1ce16ad8..f5c2ff73980 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.preprocessing; +import org.apache.lucene.util.Accountable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; @@ -14,7 +15,7 @@ import java.util.Map; * Describes a pre-processor for a defined machine learning model * This processor should take a set of fields and return the modified set of fields. */ -public interface PreProcessor extends NamedXContentObject, NamedWriteable { +public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable { /** * Process the given fields and their values and return the modified map. 
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index d8f413b3b17..914b43f98e9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -5,7 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.preprocessing; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -25,6 +27,7 @@ import java.util.Objects; */ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class); public static final ParseField NAME = new ParseField("target_mean_encoding"); public static final ParseField FIELD = new ParseField("field"); public static final ParseField FEATURE_NAME = new ParseField("feature_name"); @@ -158,4 +161,17 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly return Objects.hash(field, featureName, meanMap, defaultValue); } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(field); + size += RamUsageEstimator.sizeOf(featureName); + size += RamUsageEstimator.sizeOfMap(meanMap); + return size; + } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java
new file mode 100644
index 00000000000..662585bedf5
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java
@@ -0,0 +1,186 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.ingest.IngestDocument;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+/**
+ * Result of a classification inference: the inferred value, an optional label for it,
+ * and a (possibly empty, never null) list of top classes.
+ */
+public class ClassificationInferenceResults extends SingleValueInferenceResults {
+
+    public static final String NAME = "classification";
+    public static final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label");
+    public static final ParseField TOP_CLASSES = new ParseField("top_classes");
+
+    private final String classificationLabel;
+    private final List topClasses;
+
+    public ClassificationInferenceResults(double value, String classificationLabel, List topClasses) {
+        super(value);
+        this.classificationLabel = classificationLabel;
+        this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
+    }
+
+    public ClassificationInferenceResults(StreamInput in) throws IOException {
+        super(in);
+        this.classificationLabel = in.readOptionalString();
+        this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
+    }
+
+    public String getClassificationLabel() {
+        return classificationLabel;
+    }
+
+    public List getTopClasses() {
+        return topClasses;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeOptionalString(classificationLabel);
+        out.writeCollection(topClasses);
+    }
+
+    @Override
+    XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
+        if (classificationLabel != null) {
+            builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel);
+        }
+        if (topClasses.isEmpty() == false) {
+            builder.field(TOP_CLASSES.getPreferredName(), topClasses);
+        }
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (object == this) { return true; }
+        if (object == null || getClass() != object.getClass()) { return false; }
+        ClassificationInferenceResults that = (ClassificationInferenceResults) object;
+        return Objects.equals(value(), that.value()) &&
+            Objects.equals(classificationLabel, that.classificationLabel) &&
+            Objects.equals(topClasses, that.topClasses);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(value(), classificationLabel, topClasses);
+    }
+
+    @Override
+    public String valueAsString() {
+        return classificationLabel == null ? super.valueAsString() : classificationLabel;
+    }
+
+    /**
+     * Sets {@code resultField} to the list of top-class value maps when there are any, otherwise to {@link #valueAsString()}.
+     */
+    @Override
+    public void writeResult(IngestDocument document, String resultField) {
+        ExceptionsHelper.requireNonNull(document, "document");
+        ExceptionsHelper.requireNonNull(resultField, "resultField");
+        if (topClasses.isEmpty()) {
+            document.setFieldValue(resultField, valueAsString());
+        } else {
+            document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
+        }
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    /**
+     * A class name paired with its probability.
+     */
+    public static class TopClassEntry implements ToXContentObject, Writeable {
+
+        // Field names are the same for every entry: static, so an entry does not allocate its own ParseFields.
+        public static final ParseField CLASSIFICATION = new ParseField("classification");
+        public static final ParseField PROBABILITY = new ParseField("probability");
+
+        private final String classification;
+        private final double probability;
+
+        public TopClassEntry(String classification, Double probability) {
+            this.classification = ExceptionsHelper.requireNonNull(classification, CLASSIFICATION);
+            this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY);
+        }
+
+        public TopClassEntry(StreamInput in) throws IOException {
+            this.classification = in.readString();
+            this.probability = in.readDouble();
+        }
+
+        public String getClassification() {
+            return classification;
+        }
+
+        public double getProbability() {
+            return probability;
+        }
+
+        public Map asValueMap() {
+            Map map = new HashMap<>(2);
+            map.put(CLASSIFICATION.getPreferredName(), classification);
+            map.put(PROBABILITY.getPreferredName(), probability);
+            return map;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(classification);
+            out.writeDouble(probability);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASSIFICATION.getPreferredName(), classification);
+            builder.field(PROBABILITY.getPreferredName(), probability);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object object) {
+            if (object == this) { return true; }
+            if (object == null || getClass() != object.getClass()) { return false; }
+            TopClassEntry that = (TopClassEntry) object;
+            return Objects.equals(classification, that.classification) &&
+                Objects.equals(probability, that.probability);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(classification, probability);
+        }
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java
new file mode 100644
index 00000000000..00744f6982f
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.ingest.IngestDocument;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+public interface InferenceResults extends NamedXContentObject, NamedWriteable {
+
+    void writeResult(IngestDocument document, String resultField);
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java
new file mode 100644
index 00000000000..884d66032b5
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.ingest.IngestDocument;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Single-value result that does not support being written to an ingest document.
+ */
+public class RawInferenceResults extends SingleValueInferenceResults {
+
+    public static final String NAME = "raw";
+
+    public RawInferenceResults(double value) {
+        super(value);
+    }
+
+    public RawInferenceResults(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+    }
+
+    @Override
+    XContentBuilder innerToXContent(XContentBuilder builder, Params params) {
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (object == this) { return true; }
+        if (object == null || getClass() != object.getClass()) { return false; }
+        RawInferenceResults that = (RawInferenceResults) object;
+        return Objects.equals(value(), that.value());
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(value());
+    }
+
+    @Override
+    public void writeResult(IngestDocument document, String resultField) {
+        throw new UnsupportedOperationException("[raw] does not support writing inference results");
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java
new file mode 100644
index 00000000000..e186489b91d
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.ingest.IngestDocument;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Single-value result whose value is written unchanged to the result field of an ingest document.
+ */
+public class RegressionInferenceResults extends SingleValueInferenceResults {
+
+    public static final String NAME = "regression";
+
+    public RegressionInferenceResults(double value) {
+        super(value);
+    }
+
+    public RegressionInferenceResults(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+    }
+
+    @Override
+    XContentBuilder innerToXContent(XContentBuilder builder, Params params) {
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (object == this) { return true; }
+        if (object == null || getClass() != object.getClass()) { return false; }
+        RegressionInferenceResults that = (RegressionInferenceResults) object;
+        return Objects.equals(value(), that.value());
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(value());
+    }
+
+    @Override
+    public void writeResult(IngestDocument document, String resultField) {
+        ExceptionsHelper.requireNonNull(document, "document");
+        ExceptionsHelper.requireNonNull(resultField, "resultField");
+        document.setFieldValue(resultField, value());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+}
diff --git
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java
new file mode 100644
index 00000000000..2905a667958
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+
+/**
+ * Base class for inference results that carry a single double value.
+ */
+public abstract class SingleValueInferenceResults implements InferenceResults {
+
+    public static final ParseField VALUE = new ParseField("value");
+
+    private final double value;
+
+    SingleValueInferenceResults(StreamInput in) throws IOException {
+        value = in.readDouble();
+    }
+
+    SingleValueInferenceResults(double value) {
+        this.value = value;
+    }
+
+    public Double value() {
+        return value;
+    }
+
+    public String valueAsString() {
+        return String.valueOf(value);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeDouble(value);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(VALUE.getPreferredName(), value);
+        innerToXContent(builder, params);
+        builder.endObject();
+        return builder;
+    }
+
+    /**
+     * Adds subclass-specific fields to the object opened by {@link #toXContent}.
+     */
+    abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException;
+}
diff --git
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java
new file mode 100644
index 00000000000..f7da41eda7b
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Inference options for classification models: how many top classes to report.
+ */
+public class ClassificationConfig implements InferenceConfig {
+
+    public static final String NAME = "classification";
+
+    public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
+    private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
+
+    public static final ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0);
+
+    private final int numTopClasses;
+
+    /**
+     * Builds a config from a map of options; the given map is not modified.
+     * Fails with a bad request when {@code num_top_classes} is not an integer or when unknown keys are present.
+     */
+    public static ClassificationConfig fromMap(Map map) {
+        Map options = new HashMap<>(map);
+        Object numTopClasses = options.remove(NUM_TOP_CLASSES.getPreferredName());
+        if (numTopClasses != null && (numTopClasses instanceof Integer) == false) {
+            throw ExceptionsHelper.badRequestException("[{}] must be an integer.", NUM_TOP_CLASSES.getPreferredName());
+        }
+        if (options.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
+        }
+        return new ClassificationConfig((Integer) numTopClasses);
+    }
+
+    public ClassificationConfig(Integer numTopClasses) {
+        this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
+    }
+
+    public ClassificationConfig(StreamInput in) throws IOException {
+        this.numTopClasses = in.readInt();
+    }
+
+    public int getNumTopClasses() {
+        return numTopClasses;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeInt(numTopClasses);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ClassificationConfig that = (ClassificationConfig) o;
+        return numTopClasses == that.numTopClasses;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(numTopClasses);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (numTopClasses != 0) {
+            builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public boolean isTargetTypeSupported(TargetType targetType) {
+        return TargetType.CLASSIFICATION.equals(targetType);
+    }
+
+    @Override
+    public Version getMinimalSupportedVersion() {
+        return MIN_SUPPORTED_VERSION;
+    }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java
new file mode 100644
index 00000000000..5d1dc7983ff
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java
@@ -0,0 +1,21 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements.
Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+
+public interface InferenceConfig extends NamedXContentObject, NamedWriteable {
+
+    boolean isTargetTypeSupported(TargetType targetType);
+
+    /**
+     * All nodes in the cluster must be at least this version
+     */
+    Version getMinimalSupportedVersion();
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
new file mode 100644
index 00000000000..86bf076cd6b
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
@@ -0,0 +1,102 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public final class InferenceHelpers {
+
+    private InferenceHelpers() { }
+
+    /**
+     * Pairs each probability with its label and returns the entries ordered by descending probability.
+     * When {@code classificationLabels} is null, each probability's index is used as its label.
+     *
+     * @param numToInclude 0 yields an empty list; a negative value yields every class
+     */
+    public static List topClasses(List probabilities,
+                                  List classificationLabels,
+                                  int numToInclude) {
+        if (numToInclude == 0) {
+            return Collections.emptyList();
+        }
+        // Check the sizes first so mismatched input fails before any sorting work is done.
+        if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
+            throw ExceptionsHelper
+                .serverError(
+                    "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
+                    null,
+                    probabilities.size(),
+                    classificationLabels.size());
+        }
+
+        int[] sortedIndices = IntStream.range(0, probabilities.size())
+            .boxed()
+            .sorted(Comparator.comparing(probabilities::get).reversed())
+            .mapToInt(i -> i)
+            .toArray();
+
+        List labels = classificationLabels == null ?
+            // If we don't have the labels we should return the top classification values anyway, they will just be numeric
+            IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) :
+            classificationLabels;
+
+        int count = numToInclude < 0 ? probabilities.size() : Math.min(numToInclude, probabilities.size());
+        List topClassEntries = new ArrayList<>(count);
+        for(int i = 0; i < count; i++) {
+            int idx = sortedIndices[i];
+            topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx)));
+        }
+
+        return topClassEntries;
+    }
+
+    /**
+     * Maps an integral inference value to its label, or to the value's string form when there are no labels.
+     */
+    public static String classificationLabel(double inferenceValue, @Nullable List classificationLabels) {
+        assert inferenceValue == Math.rint(inferenceValue);
+        if (classificationLabels == null) {
+            return String.valueOf(inferenceValue);
+        }
+        int label = Double.valueOf(inferenceValue).intValue();
+        if (label < 0 || label >= classificationLabels.size()) {
+            throw ExceptionsHelper.serverError(
+                "model returned classification value of [{}] which is not a valid index in classification labels [{}]",
+                null,
+                label,
+                classificationLabels);
+        }
+        return classificationLabels.get(label);
+    }
+
+    /**
+     * Converts a {@link Number} or a numeric {@link String} to a Double; returns null for any other input.
+     */
+    public static Double toDouble(Object value) {
+        if (value instanceof Number) {
+            return ((Number)value).doubleValue();
+        }
+        if (value instanceof String) {
+            try {
+                return Double.valueOf((String)value);
+            } catch (NumberFormatException nfe) {
+                assert false : "value is not properly formatted double [" + value + "]";
+                return null;
+            }
+        }
+        return null;
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java
new file mode 100644
index 00000000000..b7c4a71b3e7
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; + +/** + * Used by ensemble to pass into sub-models. + */ +public class NullInferenceConfig implements InferenceConfig { + + public static final NullInferenceConfig INSTANCE = new NullInferenceConfig(); + + private NullInferenceConfig() { } + + @Override + public boolean isTargetTypeSupported(TargetType targetType) { + return true; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.CURRENT; + } + + @Override + public String getWriteableName() { + return "null"; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + } + + @Override + public String getName() { + return "null"; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java new file mode 100644 index 00000000000..6dd03e87747 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class RegressionConfig implements InferenceConfig { + + public static final String NAME = "regression"; + private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0; + + public static RegressionConfig fromMap(Map map) { + if (map.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); + } + return new RegressionConfig(); + } + + public RegressionConfig() { + } + + public RegressionConfig(StreamInput in) { + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + return true; + } + + @Override + public int hashCode() { + return Objects.hash(NAME); + } + + @Override + public boolean isTargetTypeSupported(TargetType targetType) { + return TargetType.REGRESSION.equals(targetType); + } + + @Override + public Version getMinimalSupportedVersion() { + return MIN_SUPPORTED_VERSION; + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java 
index cad5a6c0a8c..e206a709180 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -5,14 +5,16 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.apache.lucene.util.Accountable; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; import java.util.List; import java.util.Map; -public interface TrainedModel extends NamedXContentObject, NamedWriteable { +public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable { /** * @return List of featureNames expected by the model. In the order that they are expected @@ -23,16 +25,11 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { * Infer against the provided fields * * @param fields The fields and their values to infer against + * @param config The configuration options for inference * @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0). * For regression this is continuous. */ - double infer(Map fields); - - /** - * @param fields similar to {@link TrainedModel#infer(Map)}, but fields are already in order and doubles - * @return The predicted value. - */ - double infer(List fields); + InferenceResults infer(Map fields, InferenceConfig config); /** * @return {@link TargetType} for the model. @@ -40,26 +37,7 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { TargetType targetType(); /** - * This gathers the probabilities for each potential classification value. - * - * The probabilities are indexed by classification ordinal label encoding. 
- * The length of this list is equal to the number of classification labels. - * - * This only should return if the implementation model is inferring classification values and not regression - * @param fields The fields and their values to infer against - * @return The probabilities of each classification value - */ - List classificationProbability(Map fields); - - /** - * @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles - * @return The probabilities of each classification value - */ - List classificationProbability(List fields); - - /** - * The ordinal encoded list of the classification labels. - * @return Oridinal encoded list of classification labels. + * @return Ordinal encoded list of classification labels. */ @Nullable List classificationLabels(); @@ -72,4 +50,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { * @throws org.elasticsearch.ElasticsearchException if validations fail */ void validate(); + + /** + * @return The estimated number of operations required at inference time + */ + long estimatedNumOperations(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 5e5199c2405..a59f1a1c245 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -5,6 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.Accountables; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import 
org.elasticsearch.common.io.stream.StreamInput; @@ -12,7 +15,16 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; @@ -20,14 +32,20 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.OptionalDouble; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; + public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { + private 
static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Ensemble.class); // TODO should we have regression/classification sub-classes that accept the builder? public static final ParseField NAME = new ParseField("ensemble"); public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); @@ -106,14 +124,18 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai } @Override - public double infer(Map fields) { - List processedInferences = inferAndProcess(fields); - return outputAggregator.aggregate(processedInferences); - } - - @Override - public double infer(List fields) { - throw new UnsupportedOperationException("Ensemble requires map containing field names and values"); + public InferenceResults infer(Map fields, InferenceConfig config) { + if (config.isTargetTypeSupported(targetType) == false) { + throw ExceptionsHelper.badRequestException( + "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); + } + List inferenceResults = this.models.stream().map(model -> { + InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE); + assert results instanceof SingleValueInferenceResults; + return ((SingleValueInferenceResults)results).value(); + }).collect(Collectors.toList()); + List processed = outputAggregator.processValues(inferenceResults); + return buildResults(processed, config); } @Override @@ -121,18 +143,27 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai return targetType; } - @Override - public List classificationProbability(Map fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); + private InferenceResults buildResults(List processedInferences, InferenceConfig config) { + // Indicates that the config is useless and the caller 
just wants the raw value + if (config instanceof NullInferenceConfig) { + return new RawInferenceResults(outputAggregator.aggregate(processedInferences)); + } + switch(targetType) { + case REGRESSION: + return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences)); + case CLASSIFICATION: + ClassificationConfig classificationConfig = (ClassificationConfig) config; + List topClasses = InferenceHelpers.topClasses( + processedInferences, + classificationLabels, + classificationConfig.getNumTopClasses()); + double value = outputAggregator.aggregate(processedInferences); + return new ClassificationInferenceResults(value, + classificationLabel(value, classificationLabels), + topClasses); + default: + throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model"); } - return inferAndProcess(fields); - } - - @Override - public List classificationProbability(List fields) { - throw new UnsupportedOperationException("Ensemble requires map containing field names and values"); } @Override @@ -140,11 +171,6 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai return classificationLabels; } - private List inferAndProcess(Map fields) { - List modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList()); - return outputAggregator.processValues(modelInferences); - } - @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -204,6 +230,13 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai @Override public void validate() { + if (outputAggregator.compatibleWith(targetType) == false) { + throw ExceptionsHelper.badRequestException( + "aggregate_output [{}] is not compatible with target_type [{}]", + outputAggregator.getName(), + this.targetType + ); + } if (outputAggregator.expectedValueSize() != null && outputAggregator.expectedValueSize() != 
models.size()) { throw ExceptionsHelper.badRequestException( @@ -219,10 +252,38 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai this.models.forEach(TrainedModel::validate); } + @Override + public long estimatedNumOperations() { + OptionalDouble avg = models.stream().mapToLong(TrainedModel::estimatedNumOperations).average(); + assert avg.isPresent() : "unexpected null when calculating number of operations"; + // Average operations for each model and the operations required for processing and aggregating with the outputAggregator + return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1); + } + public static Builder builder() { return new Builder(); } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOfCollection(featureNames); + size += RamUsageEstimator.sizeOfCollection(classificationLabels); + size += RamUsageEstimator.sizeOfCollection(models); + size += outputAggregator.ramBytesUsed(); + return size; + } + + @Override + public Collection getChildResources() { + List accountables = new ArrayList<>(models.size() + 1); + for (TrainedModel model : models) { + accountables.add(Accountables.namedAccountable(model.getName(), model)); + } + accountables.add(Accountables.namedAccountable(outputAggregator.getName(), outputAggregator)); + return Collections.unmodifiableCollection(accountables); + } + public static class Builder { private List featureNames; private List trainedModels; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java index 36ca2ba79bc..2dba9691639 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java @@ -6,16 +6,17 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.stream.IntStream; @@ -24,6 +25,7 @@ import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.sigmoid public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class); public static final ParseField NAME = new ParseField("logistic_regression"); public static final ParseField WEIGHTS = new ParseField("weights"); @@ -48,19 +50,23 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie return LENIENT_PARSER.apply(parser, null); } - private final List weights; + private final double[] weights; LogisticRegression() { this((List) null); } - public LogisticRegression(List weights) { - this.weights = weights == null ? null : Collections.unmodifiableList(weights); + private LogisticRegression(List weights) { + this(weights == null ? 
null : weights.stream().mapToDouble(Double::valueOf).toArray()); + } + + public LogisticRegression(double[] weights) { + this.weights = weights; } public LogisticRegression(StreamInput in) throws IOException { if (in.readBoolean()) { - this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + this.weights = in.readDoubleArray(); } else { this.weights = null; } @@ -68,18 +74,18 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie @Override public Integer expectedValueSize() { - return this.weights == null ? null : this.weights.size(); + return this.weights == null ? null : this.weights.length; } @Override public List processValues(List values) { Objects.requireNonNull(values, "values must not be null"); - if (weights != null && values.size() != weights.size()) { + if (weights != null && values.size() != weights.length) { throw new IllegalArgumentException("values must be the same length as weights."); } double summation = weights == null ? 
values.stream().mapToDouble(Double::valueOf).sum() : - IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).sum(); + IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).sum(); double probOfClassOne = sigmoid(summation); assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0; return Arrays.asList(1.0 - probOfClassOne, probOfClassOne); @@ -108,6 +114,11 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie return NAME.getPreferredName(); } + @Override + public boolean compatibleWith(TargetType targetType) { + return true; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -117,7 +128,7 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(weights != null); if (weights != null) { - out.writeCollection(weights, StreamOutput::writeDouble); + out.writeDoubleArray(weights); } } @@ -136,12 +147,17 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; LogisticRegression that = (LogisticRegression) o; - return Objects.equals(weights, that.weights); + return Arrays.equals(weights, that.weights); } @Override public int hashCode() { - return Objects.hash(weights); + return Arrays.hashCode(weights); } + @Override + public long ramBytesUsed() { + long weightSize = weights == null ? 
0L : RamUsageEstimator.sizeOf(weights); + return SHALLOW_SIZE + weightSize; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java index 1f882b724ee..16b1fd7c405 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java @@ -5,12 +5,14 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.apache.lucene.util.Accountable; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; import java.util.List; -public interface OutputAggregator extends NamedXContentObject, NamedWriteable { +public interface OutputAggregator extends NamedXContentObject, NamedWriteable, Accountable { /** * @return The expected size of the values array when aggregating. `null` implies there is no expected size. 
@@ -44,4 +46,6 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable { * @return The name of the output aggregator */ String getName(); + + boolean compatibleWith(TargetType targetType); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 739a4e13d86..73689d16b1c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -6,15 +6,18 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Objects; @@ -23,6 +26,7 @@ import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedMode.class); public static final ParseField NAME = new ParseField("weighted_mode"); public static final ParseField WEIGHTS = new ParseField("weights"); @@ -47,19 +51,23 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, 
LenientlyPa return LENIENT_PARSER.apply(parser, null); } - private final List weights; + private final double[] weights; WeightedMode() { - this.weights = null; + this((List) null); } - public WeightedMode(List weights) { - this.weights = weights == null ? null : Collections.unmodifiableList(weights); + private WeightedMode(List weights) { + this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray()); + } + + public WeightedMode(double[] weights) { + this.weights = weights; } public WeightedMode(StreamInput in) throws IOException { if (in.readBoolean()) { - this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + this.weights = in.readDoubleArray(); } else { this.weights = null; } @@ -67,13 +75,13 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa @Override public Integer expectedValueSize() { - return this.weights == null ? null : this.weights.size(); + return this.weights == null ? null : this.weights.length; } @Override public List processValues(List values) { Objects.requireNonNull(values, "values must not be null"); - if (weights != null && values.size() != weights.size()) { + if (weights != null && values.size() != weights.length) { throw new IllegalArgumentException("values must be the same length as weights."); } List freqArray = new ArrayList<>(); @@ -93,7 +101,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa } List frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY)); for (int i = 0; i < freqArray.size(); i++) { - Double weight = weights == null ? 1.0 : weights.get(i); + Double weight = weights == null ? 1.0 : weights[i]; Integer value = freqArray.get(i); Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? 
weight : frequencies.get(value) + weight; frequencies.set(value, frequency); @@ -123,6 +131,11 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa return NAME.getPreferredName(); } + @Override + public boolean compatibleWith(TargetType targetType) { + return true; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -132,7 +145,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(weights != null); if (weights != null) { - out.writeCollection(weights, StreamOutput::writeDouble); + out.writeDoubleArray(weights); } } @@ -151,11 +164,17 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; WeightedMode that = (WeightedMode) o; - return Objects.equals(weights, that.weights); + return Arrays.equals(weights, that.weights); } @Override public int hashCode() { - return Objects.hash(weights); + return Arrays.hashCode(weights); + } + + @Override + public long ramBytesUsed() { + long weightSize = weights == null ? 
0L : RamUsageEstimator.sizeOf(weights); + return SHALLOW_SIZE + weightSize; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index f5812dabf88..ed1c13cf102 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -6,15 +6,17 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; -import java.util.Collections; +import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -23,6 +25,7 @@ import java.util.stream.IntStream; public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedSum.class); public static final ParseField NAME = new ParseField("weighted_sum"); public static final ParseField WEIGHTS = new ParseField("weights"); @@ -47,19 +50,23 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar return LENIENT_PARSER.apply(parser, null); } - private final List weights; + private final double[] weights; WeightedSum() { - this.weights = null; + this((List) null); } - public 
WeightedSum(List weights) { - this.weights = weights == null ? null : Collections.unmodifiableList(weights); + private WeightedSum(List weights) { + this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray()); + } + + public WeightedSum(double[] weights) { + this.weights = weights; } public WeightedSum(StreamInput in) throws IOException { if (in.readBoolean()) { - this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + this.weights = in.readDoubleArray(); } else { this.weights = null; } @@ -71,10 +78,10 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar if (weights == null) { return values; } - if (values.size() != weights.size()) { + if (values.size() != weights.length) { throw new IllegalArgumentException("values must be the same length as weights."); } - return IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).boxed().collect(Collectors.toList()); + return IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).boxed().collect(Collectors.toList()); } @Override @@ -104,7 +111,7 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(weights != null); if (weights != null) { - out.writeCollection(weights, StreamOutput::writeDouble); + out.writeDoubleArray(weights); } } @@ -123,16 +130,27 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; WeightedSum that = (WeightedSum) o; - return Objects.equals(weights, that.weights); + return Arrays.equals(weights, that.weights); } @Override public int hashCode() { - return Objects.hash(weights); + return Arrays.hashCode(weights); } @Override public Integer expectedValueSize() { - return weights == null ? null : this.weights.size(); + return weights == null ? 
null : this.weights.length; + } + + @Override + public boolean compatibleWith(TargetType targetType) { + return TargetType.REGRESSION.equals(targetType); + } + + @Override + public long ramBytesUsed() { + long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights); + return SHALLOW_SIZE + weightSize; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 3a91ec0cd86..1408b17a069 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -5,6 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.Accountables; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -13,7 +16,15 @@ import org.elasticsearch.common.util.CachedSupplier; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -22,6 +33,7 @@ import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -31,8 +43,11 @@ import java.util.Queue; import java.util.Set; import java.util.stream.Collectors; -public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; +public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel, Accountable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Tree.class); // TODO should we have regression/classification sub-classes that accept the builder? 
public static final ParseField NAME = new ParseField("tree"); @@ -72,7 +87,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM Tree(List featureNames, List nodes, TargetType targetType, List classificationLabels) { this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); - this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE)); + if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) { + throw new IllegalArgumentException("[tree_structure] must not be empty"); + } + this.nodes = Collections.unmodifiableList(nodes); this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue()); @@ -105,20 +123,42 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM } @Override - public double infer(Map fields) { - List features = featureNames.stream().map(f -> - fields.get(f) instanceof Number ? 
((Number)fields.get(f)).doubleValue() : null - ).collect(Collectors.toList()); - return infer(features); + public InferenceResults infer(Map fields, InferenceConfig config) { + if (config.isTargetTypeSupported(targetType) == false) { + throw ExceptionsHelper.badRequestException( + "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); + } + + List features = featureNames.stream().map(f -> InferenceHelpers.toDouble(fields.get(f))).collect(Collectors.toList()); + return infer(features, config); } - @Override - public double infer(List features) { + private InferenceResults infer(List features, InferenceConfig config) { TreeNode node = nodes.get(0); while(node.isLeaf() == false) { node = nodes.get(node.compare(features)); } - return node.getLeafValue(); + return buildResult(node.getLeafValue(), config); + } + + private InferenceResults buildResult(Double value, InferenceConfig config) { + // Indicates that the config is useless and the caller just wants the raw value + if (config instanceof NullInferenceConfig) { + return new RawInferenceResults(value); + } + switch (targetType) { + case CLASSIFICATION: + ClassificationConfig classificationConfig = (ClassificationConfig) config; + List topClasses = InferenceHelpers.topClasses( + classificationProbability(value), + classificationLabels, + classificationConfig.getNumTopClasses()); + return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses); + case REGRESSION: + return new RegressionInferenceResults(value); + default: + throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model"); + } } /** @@ -142,34 +182,15 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM return targetType; } - @Override - public List classificationProbability(Map fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new 
UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); - } - List features = featureNames.stream().map(f -> - fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null) - .collect(Collectors.toList()); - - return classificationProbability(features); - } - - @Override - public List classificationProbability(List fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); - } - double label = infer(fields); + private List classificationProbability(double inferenceValue) { // If we are classification, we should assume that the inference return value is whole. - assert label == Math.rint(label); + assert inferenceValue == Math.rint(inferenceValue); double maxCategory = this.highestOrderCategory.get(); // If we are classification, we should assume that the largest leaf value is whole. 
assert maxCategory == Math.rint(maxCategory); List list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)); // TODO, eventually have TreeNodes contain confidence levels - list.set(Double.valueOf(label).intValue(), 1.0); + list.set(Double.valueOf(inferenceValue).intValue(), 1.0); return list; } @@ -239,6 +260,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM detectCycle(); } + @Override + public long estimatedNumOperations() { + // Grabbing the features from the doc + the depth of the tree + return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size(); + } + private void checkTargetType() { if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) { throw ExceptionsHelper.badRequestException( @@ -247,9 +274,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM } private void detectCycle() { - if (nodes.isEmpty()) { - return; - } Set visited = new HashSet<>(nodes.size()); Queue toVisit = new ArrayDeque<>(nodes.size()); toVisit.add(0); @@ -270,10 +294,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM } private void detectMissingNodes() { - if (nodes.isEmpty()) { - return; - } - List missingNodes = new ArrayList<>(); for (int i = 0; i < nodes.size(); i++) { TreeNode currentNode = nodes.get(i); @@ -302,6 +322,24 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM null; } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOfCollection(classificationLabels); + size += RamUsageEstimator.sizeOfCollection(featureNames); + size += RamUsageEstimator.sizeOfCollection(nodes); + return size; + } + + @Override + public Collection getChildResources() { + List accountables = new ArrayList<>(nodes.size()); + for (TreeNode node : nodes) { + accountables.add(Accountables.namedAccountable("tree_node_" + node.getNodeIndex(), 
node)); + } + return Collections.unmodifiableCollection(accountables); + } + public static class Builder { private List featureNames; private ArrayList nodes; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java index 9beda88e2c5..9d58c280905 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java @@ -5,6 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.Numbers; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -15,14 +18,15 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.job.config.Operator; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.List; import java.util.Objects; -public class TreeNode implements ToXContentObject, Writeable { +public class TreeNode implements ToXContentObject, Writeable, Accountable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TreeNode.class); public static final String NAME = "tree_node"; public static final ParseField DECISION_TYPE = new ParseField("decision_type"); @@ -63,31 +67,31 @@ public class TreeNode implements ToXContentObject, Writeable { } private final Operator operator; - private final Double threshold; - private final Integer splitFeature; + private final double 
threshold; + private final int splitFeature; private final int nodeIndex; - private final Double splitGain; - private final Double leafValue; + private final double splitGain; + private final double leafValue; private final boolean defaultLeft; private final int leftChild; private final int rightChild; - TreeNode(Operator operator, - Double threshold, - Integer splitFeature, - Integer nodeIndex, - Double splitGain, - Double leafValue, - Boolean defaultLeft, - Integer leftChild, - Integer rightChild) { + private TreeNode(Operator operator, + Double threshold, + Integer splitFeature, + int nodeIndex, + Double splitGain, + Double leafValue, + Boolean defaultLeft, + Integer leftChild, + Integer rightChild) { this.operator = operator == null ? Operator.LTE : operator; - this.threshold = threshold; - this.splitFeature = splitFeature; - this.nodeIndex = ExceptionsHelper.requireNonNull(nodeIndex, NODE_INDEX.getPreferredName()); - this.splitGain = splitGain; - this.leafValue = leafValue; + this.threshold = threshold == null ? Double.NaN : threshold; + this.splitFeature = splitFeature == null ? -1 : splitFeature; + this.nodeIndex = nodeIndex; + this.splitGain = splitGain == null ? Double.NaN : splitGain; + this.leafValue = leafValue == null ? Double.NaN : leafValue; this.defaultLeft = defaultLeft == null ? false : defaultLeft; this.leftChild = leftChild == null ? -1 : leftChild; this.rightChild = rightChild == null ? 
-1 : rightChild; @@ -95,11 +99,11 @@ public class TreeNode implements ToXContentObject, Writeable { public TreeNode(StreamInput in) throws IOException { operator = Operator.readFromStream(in); - threshold = in.readOptionalDouble(); - splitFeature = in.readOptionalInt(); - splitGain = in.readOptionalDouble(); - nodeIndex = in.readInt(); - leafValue = in.readOptionalDouble(); + threshold = in.readDouble(); + splitFeature = in.readInt(); + splitGain = in.readDouble(); + nodeIndex = in.readVInt(); + leafValue = in.readDouble(); defaultLeft = in.readBoolean(); leftChild = in.readInt(); rightChild = in.readInt(); @@ -110,23 +114,23 @@ public class TreeNode implements ToXContentObject, Writeable { return operator; } - public Double getThreshold() { + public double getThreshold() { return threshold; } - public Integer getSplitFeature() { + public int getSplitFeature() { return splitFeature; } - public Integer getNodeIndex() { + public int getNodeIndex() { return nodeIndex; } - public Double getSplitGain() { + public double getSplitGain() { return splitGain; } - public Double getLeafValue() { + public double getLeafValue() { return leafValue; } @@ -164,11 +168,11 @@ public class TreeNode implements ToXContentObject, Writeable { @Override public void writeTo(StreamOutput out) throws IOException { operator.writeTo(out); - out.writeOptionalDouble(threshold); - out.writeOptionalInt(splitFeature); - out.writeOptionalDouble(splitGain); - out.writeInt(nodeIndex); - out.writeOptionalDouble(leafValue); + out.writeDouble(threshold); + out.writeInt(splitFeature); + out.writeDouble(splitGain); + out.writeVInt(nodeIndex); + out.writeDouble(leafValue); out.writeBoolean(defaultLeft); out.writeInt(leftChild); out.writeInt(rightChild); @@ -177,12 +181,14 @@ public class TreeNode implements ToXContentObject, Writeable { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - addOptionalField(builder, DECISION_TYPE, 
operator); - addOptionalField(builder, THRESHOLD, threshold); - addOptionalField(builder, SPLIT_FEATURE, splitFeature); - addOptionalField(builder, SPLIT_GAIN, splitGain); + builder.field(DECISION_TYPE.getPreferredName(), operator); + addOptionalDouble(builder, THRESHOLD, threshold); + if (splitFeature > -1) { + builder.field(SPLIT_FEATURE.getPreferredName(), splitFeature); + } + addOptionalDouble(builder, SPLIT_GAIN, splitGain); builder.field(NODE_INDEX.getPreferredName(), nodeIndex); - addOptionalField(builder, LEAF_VALUE, leafValue); + addOptionalDouble(builder, LEAF_VALUE, leafValue); builder.field(DEFAULT_LEFT.getPreferredName(), defaultLeft); if (leftChild >= 0) { builder.field(LEFT_CHILD.getPreferredName(), leftChild); @@ -194,8 +200,8 @@ public class TreeNode implements ToXContentObject, Writeable { return builder; } - private void addOptionalField(XContentBuilder builder, ParseField field, Object value) throws IOException { - if (value != null) { + private void addOptionalDouble(XContentBuilder builder, ParseField field, double value) throws IOException { + if (Numbers.isValidDouble(value)) { builder.field(field.getPreferredName(), value); } } @@ -237,7 +243,12 @@ public class TreeNode implements ToXContentObject, Writeable { public static Builder builder(int nodeIndex) { return new Builder(nodeIndex); } - + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE; + } + public static class Builder { private Operator operator; private Double threshold; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java index fb282459d46..1cdddcd7af2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java @@ -5,6 +5,8 @@ */ package 
org.elasticsearch.xpack.core.ml.inference.utils; +import org.elasticsearch.common.Numbers; + import java.util.List; import java.util.stream.Collectors; @@ -22,24 +24,24 @@ public final class Statistics { */ public static List softMax(List values) { Double expSum = 0.0; - Double max = values.stream().filter(v -> isInvalid(v) == false).max(Double::compareTo).orElse(null); + Double max = values.stream().filter(Statistics::isValid).max(Double::compareTo).orElse(null); if (max == null) { throw new IllegalArgumentException("no valid values present"); } - List exps = values.stream().map(v -> isInvalid(v) ? Double.NEGATIVE_INFINITY : v - max) + List exps = values.stream().map(v -> isValid(v) ? v - max : Double.NEGATIVE_INFINITY) .collect(Collectors.toList()); for (int i = 0; i < exps.size(); i++) { - if (isInvalid(exps.get(i)) == false) { + if (isValid(exps.get(i))) { Double exp = Math.exp(exps.get(i)); expSum += exp; exps.set(i, exp); } } for (int i = 0; i < exps.size(); i++) { - if (isInvalid(exps.get(i))) { - exps.set(i, 0.0); - } else { + if (isValid(exps.get(i))) { exps.set(i, exps.get(i)/expSum); + } else { + exps.set(i, 0.0); } } return exps; @@ -49,8 +51,8 @@ public final class Statistics { return 1/(1 + Math.exp(-value)); } - public static boolean isInvalid(Double v) { - return v == null || Double.isInfinite(v) || Double.isNaN(v); + private static boolean isValid(Double v) { + return v != null && Numbers.isValidDouble(v); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index fda426462bd..961cbbf1b83 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -83,7 +83,13 @@ public final class Messages { public static final String INFERENCE_TRAINED_MODEL_EXISTS = 
"Trained machine learning model [{0}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; + public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}"; + public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION = + "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]"; + public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]"; + public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED = + "Getting model definition is not supported when getting more than one model"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java new file mode 100644 index 00000000000..7c1b93786bc --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.notifications; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage; +import org.elasticsearch.xpack.core.common.notifications.Level; +import org.elasticsearch.xpack.core.ml.job.config.Job; + +import java.util.Date; + + +public class InferenceAuditMessage extends AbstractAuditMessage { + + //TODO this should be MODEL_ID... + private static final ParseField JOB_ID = Job.ID; + public static final ConstructingObjectParser PARSER = + createParser("ml_inference_audit_message", InferenceAuditMessage::new, JOB_ID); + + public InferenceAuditMessage(String resourceId, String message, Level level, Date timestamp, String nodeName) { + super(resourceId, message, level, timestamp, nodeName); + } + + @Override + public final String getJobType() { + return "inference"; + } + + @Override + protected String getResourceField() { + return JOB_ID.getPreferredName(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java index 517e600ab44..9cc4a5cdfbf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java @@ -43,6 +43,10 @@ public class ExceptionsHelper { return new ResourceAlreadyExistsException("A data frame analytics with id [{}] already exists", id); } + public static ResourceNotFoundException missingTrainedModel(String modelId) { + return new ResourceNotFoundException("No known trained model with model_id [{}]", modelId); + } + public static ElasticsearchException serverError(String msg) { return new ElasticsearchException(msg); } @@ -51,6 +55,10 @@ public class ExceptionsHelper { return new 
ElasticsearchException(msg, cause); } + public static ElasticsearchException serverError(String msg, Throwable cause, Object... args) { + return new ElasticsearchException(msg, cause, args); + } + public static ElasticsearchStatusException conflictStatusException(String msg, Throwable cause, Object... args) { return new ElasticsearchStatusException(msg, RestStatus.CONFLICT, cause, args); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java new file mode 100644 index 00000000000..0797b20d438 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction.Request; + +public class DeleteTrainedModelsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + return new Request(randomAlphaOfLengthBetween(1, 20)); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java new file mode 100644 index 00000000000..85345467df1 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request; + +public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(randomAlphaOfLength(20), randomBoolean()); + request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java new file mode 100644 index 00000000000..b7e79e68b70 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestStats; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response; + +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + protected Response createTestInstance() { + int listSize = randomInt(10); + List trainedModelStats = Stream.generate(() -> randomAlphaOfLength(10)) + .limit(listSize).map(id -> + new Response.TrainedModelStats(id, + randomBoolean() ? randomIngestStats() : null, + randomIntBetween(0, 10)) + ) + .collect(Collectors.toList()); + return new Response(new QueryPage<>(trainedModelStats, randomLongBetween(listSize, 1000), Response.RESULTS_FIELD)); + } + + private IngestStats randomIngestStats() { + List pipelineIds = Stream.generate(()-> randomAlphaOfLength(10)) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList()); + return new IngestStats( + new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()), + pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()), + pipelineIds.stream().collect(Collectors.toMap(Function.identity(), (v) -> randomProcessorStats()))); + } + + private IngestStats.Stats randomStats(){ + return new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()); + } + + private List randomProcessorStats() { + return Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomIntBetween(0, 10)) + .map(name -> new IngestStats.ProcessorStat(name, "inference", 
randomStats())) + .collect(Collectors.toList()); + } + + @Override + protected Writeable.Reader instanceReader() { + return Response::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java new file mode 100644 index 00000000000..051da354c2e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class InferModelActionRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + return randomBoolean() ? 
+ new Request( + randomAlphaOfLength(10), + Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), + randomInferenceConfig()) : + new Request( + randomAlphaOfLength(10), + randomMap(), + randomInferenceConfig()); + } + + private static InferenceConfig randomInferenceConfig() { + return randomFrom(RegressionConfigTests.randomRegressionConfig(), ClassificationConfigTests.randomClassificationConfig()); + } + + private static Map randomMap() { + return Stream.generate(()-> randomAlphaOfLength(10)) + .limit(randomInt(10)) + .collect(Collectors.toMap(Function.identity(), (v) -> randomAlphaOfLength(10))); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java new file mode 100644 index 00000000000..9e72d1c4e68 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class InferModelActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + protected Response createTestInstance() { + String resultType = randomFrom(ClassificationInferenceResults.NAME, RegressionInferenceResults.NAME); + return new Response( + Stream.generate(() -> randomInferenceResult(resultType)) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList())); + } + + private static InferenceResults randomInferenceResult(String resultType) { + if (resultType.equals(ClassificationInferenceResults.NAME)) { + return ClassificationInferenceResultsTests.createRandomResults(); + } else if (resultType.equals(RegressionInferenceResults.NAME)) { + return RegressionInferenceResultsTests.createRandomResults(); + } else { + fail("unexpected result type [" + resultType + "]"); + return null; + } + } + + @Override + protected Writeable.Reader instanceReader() { + return Response::new; + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new 
ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 53b8d5c36c1..a066a29cec1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -33,6 +33,7 @@ import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; @@ -73,7 +74,9 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase { - private boolean lenient; - - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); - } - @Override protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException { - return TrainedModelDefinition.fromXContent(parser, lenient).build(); + return TrainedModelDefinition.fromXContent(parser, true).build(); } @Override protected boolean supportsUnknownFields() { - return lenient; + return true; } @Override @@ -58,6 +58,16 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase !field.isEmpty(); } + @Override + protected ToXContent.Params getToXContentParams() { + return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")); + } + + @Override + protected boolean assertToXContentEquivalence() { + return false; + } + public static TrainedModelDefinition.Builder createRandomBuilder(String modelId) { int 
numberOfProcessors = randomIntBetween(1, 10); return new TrainedModelDefinition.Builder() @@ -69,7 +79,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase TrainedModelDefinition.fromXContent(parser, false)); + + assertThat(exception.getMessage(), containsString("[trained_model_definition] unknown field [doc_type]")); + } + @Override protected TrainedModelDefinition createTestInstance() { - return createRandomBuilder(null).build(); + return createRandomBuilder(randomAlphaOfLength(10)).build(); } @Override @@ -298,4 +326,9 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase { + + public static ClassificationInferenceResults createRandomResults() { + return new ClassificationInferenceResults(randomDouble(), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : + Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList())); + } + + private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() { + return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble()); + } + + public void testWriteResultsWithClassificationLabel() { + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, "foo", Collections.emptyList()); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", String.class), equalTo("foo")); + } + + public void testWriteResultsWithoutClassificationLabel() { + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, Collections.emptyList()); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", String.class), equalTo("1.0")); + } + + 
@SuppressWarnings("unchecked") + public void testWriteResultsWithTopClasses() { + List entries = Arrays.asList( + new ClassificationInferenceResults.TopClassEntry("foo", 0.7), + new ClassificationInferenceResults.TopClassEntry("bar", 0.2), + new ClassificationInferenceResults.TopClassEntry("baz", 0.1)); + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, + "foo", + entries); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + List list = document.getFieldValue("result_field", List.class); + assertThat(list.size(), equalTo(3)); + + for(int i = 0; i < 3; i++) { + Map map = (Map)list.get(i); + assertThat(map, equalTo(entries.get(i).asValueMap())); + } + } + + @Override + protected ClassificationInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationInferenceResults::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java new file mode 100644 index 00000000000..d9d4e9933b2 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +public class RawInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static RawInferenceResults createRandomResults() { + return new RawInferenceResults(randomDouble()); + } + + @Override + protected RawInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return RawInferenceResults::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java new file mode 100644 index 00000000000..4f2d5926c84 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; + +import java.util.HashMap; + +import static org.hamcrest.Matchers.equalTo; + + +public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static RegressionInferenceResults createRandomResults() { + return new RegressionInferenceResults(randomDouble()); + } + + public void testWriteResults() { + RegressionInferenceResults result = new RegressionInferenceResults(0.3); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", Double.class), equalTo(0.3)); + } + + @Override + protected RegressionInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return RegressionInferenceResults::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java new file mode 100644 index 00000000000..808aaf960f4 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; + +public class ClassificationConfigTests extends AbstractWireSerializingTestCase { + + public static ClassificationConfig randomClassificationConfig() { + return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10)); + } + + public void testFromMap() { + ClassificationConfig expected = new ClassificationConfig(0); + assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected)); + + expected = new ClassificationConfig(3); + assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)), + equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + + @Override + protected ClassificationConfig createTestInstance() { + return randomClassificationConfig(); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationConfig::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpersTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpersTests.java new file mode 100644 index 00000000000..ec5093f625b --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpersTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.test.ESTestCase; + +import java.util.HashMap; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + + +public class InferenceHelpersTests extends ESTestCase { + + public void testToDoubleFromNumbers() { + assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5))); + assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5))); + assertThat(5.0, equalTo(InferenceHelpers.toDouble(5L))); + assertThat(5.0, equalTo(InferenceHelpers.toDouble(5))); + assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5f))); + } + + public void testToDoubleFromString() { + assertThat(0.5, equalTo(InferenceHelpers.toDouble("0.5"))); + assertThat(-0.5, equalTo(InferenceHelpers.toDouble("-0.5"))); + assertThat(5.0, equalTo(InferenceHelpers.toDouble("5"))); + assertThat(-5.0, equalTo(InferenceHelpers.toDouble("-5"))); + + // if ae are turned off, then we should get a null value + // otherwise, we should expect an assertion failure telling us that the string is improperly formatted + try { + assertThat(InferenceHelpers.toDouble(""), is(nullValue())); + } catch (AssertionError ae) { + assertThat(ae.getMessage(), equalTo("value is not properly formatted double []")); + } + try { + assertThat(InferenceHelpers.toDouble("notADouble"), is(nullValue())); + } catch (AssertionError ae) { + assertThat(ae.getMessage(), equalTo("value is not properly formatted double [notADouble]")); + } + } + + public void testToDoubleFromNull() { + assertThat(InferenceHelpers.toDouble(null), is(nullValue())); + } + + public void testDoubleFromUnknownObj() { + assertThat(InferenceHelpers.toDouble(new HashMap<>()), is(nullValue())); + } + +} diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java new file mode 100644 index 00000000000..bdb0e6d0320 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; + +public class RegressionConfigTests extends AbstractWireSerializingTestCase { + + public static RegressionConfig randomRegressionConfig() { + return new RegressionConfig(); + } + + public void testFromMap() { + RegressionConfig expected = new RegressionConfig(); + assertThat(RegressionConfig.fromMap(Collections.emptyMap()), equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + + @Override + protected RegressionConfig createTestInstance() { + return randomRegressionConfig(); + } + + @Override + protected Writeable.Reader instanceReader() { + return RegressionConfig::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 7ef08547def..5bff55790f8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -15,13 +15,16 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import org.junit.Before; - import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -68,9 +71,9 @@ public class EnsembleTests extends AbstractSerializingTestCase { List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6)) .limit(numberOfModels) .collect(Collectors.toList()); - List weights = randomBoolean() ? + double[] weights = randomBoolean() ? 
null : - Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); + Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).mapToDouble(Double::valueOf).toArray(); OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights), new LogisticRegression(weights)); @@ -114,9 +117,9 @@ public class EnsembleTests extends AbstractSerializingTestCase { public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() { List featureNames = Arrays.asList("foo", "bar"); int numberOfModels = 5; - List weights = new ArrayList<>(numberOfModels + 2); + double[] weights = new double[numberOfModels + 2]; for (int i = 0; i < numberOfModels + 2; i++) { - weights.add(randomDouble()); + weights[i] = randomDouble(); } OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); @@ -158,6 +161,27 @@ public class EnsembleTests extends AbstractSerializingTestCase { }); } + public void testEnsembleWithAggregatorOutputNotSupportingTargetType() { + List featureNames = Arrays.asList("foo", "bar"); + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList( + Tree.builder() + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .build())) + .setClassificationLabels(Arrays.asList("label1", "label2")) + .setTargetType(TargetType.CLASSIFICATION) + .setOutputAggregator(new WeightedSum()) + .build() + .validate(); + }); + } + public void testEnsembleWithTargetTypeAndLabelsMismatch() { List featureNames = Arrays.asList("foo", "bar"); String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"; @@ -189,6 +213,7 @@ public class EnsembleTests extends AbstractSerializingTestCase { .setFeatureNames(featureNames) .build())) 
.setTargetType(TargetType.CLASSIFICATION) + .setOutputAggregator(new WeightedMode()) .build() .validate(); }); @@ -236,32 +261,35 @@ public class EnsembleTests extends AbstractSerializingTestCase { .setTargetType(TargetType.CLASSIFICATION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0})) .build(); List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - List expected = Arrays.asList(0.231475216, 0.768524783); + List expected = Arrays.asList(0.768524783, 0.231475216); double eps = 0.000001; - List probabilities = ensemble.classificationProbability(featureMap); + List probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - expected = Arrays.asList(0.3100255188, 0.689974481); - probabilities = ensemble.classificationProbability(featureMap); + expected = Arrays.asList(0.689974481, 0.3100255188); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); - expected = Arrays.asList(0.231475216, 0.768524783); - probabilities = ensemble.classificationProbability(featureMap); + expected = Arrays.asList(0.768524783, 0.231475216); + probabilities = + 
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } // This should handle missing values and take the default_left path @@ -270,9 +298,10 @@ public class EnsembleTests extends AbstractSerializingTestCase { put("bar", null); }}; expected = Arrays.asList(0.6899744811, 0.3100255188); - probabilities = ensemble.classificationProbability(featureMap); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } } @@ -292,7 +321,9 @@ public class EnsembleTests extends AbstractSerializingTestCase { .setLeftChild(3) .setRightChild(4)) .addNode(TreeNode.builder(3).setLeafValue(0.0)) - .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + .addNode(TreeNode.builder(4).setLeafValue(1.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) + .build(); Tree tree2 = Tree.builder() .setFeatureNames(featureNames) .setRoot(TreeNode.builder(0) @@ -302,6 +333,7 @@ public class EnsembleTests extends AbstractSerializingTestCase { .setThreshold(0.5)) .addNode(TreeNode.builder(1).setLeafValue(0.0)) .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) .build(); Tree tree3 = Tree.builder() .setFeatureNames(featureNames) @@ -312,31 +344,36 @@ public class EnsembleTests extends AbstractSerializingTestCase { .setThreshold(1.0)) .addNode(TreeNode.builder(1).setLeafValue(1.0)) .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, 
TargetType.REGRESSION)) .build(); Ensemble ensemble = Ensemble.builder() .setTargetType(TargetType.CLASSIFICATION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0})) .build(); List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); featureMap = new HashMap(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(0.0, ensemble.infer(featureMap), 0.00001); + assertThat(0.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); } public void testRegressionInference() { @@ -370,16 +407,18 @@ public class EnsembleTests extends AbstractSerializingTestCase { .setTargetType(TargetType.REGRESSION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2)) - .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5))) + .setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5})) .build(); List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.9, 
ensemble.infer(featureMap), 0.00001); + assertThat(0.9, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.5, ensemble.infer(featureMap), 0.00001); + assertThat(0.5, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // Test with NO aggregator supplied, verifies default behavior of non-weighted sum ensemble = Ensemble.builder() @@ -390,17 +429,32 @@ public class EnsembleTests extends AbstractSerializingTestCase { featureVector = Arrays.asList(0.4, 0.0); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.8, ensemble.infer(featureMap), 0.00001); + assertThat(1.8, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); featureMap = new HashMap(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(1.8, ensemble.infer(featureMap), 0.00001); + assertThat(1.8, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); + } + + public void testOperationsEstimations() { + Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2); + Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5); + Tree tree3 = TreeTests.buildRandomTree(Arrays.asList("foo", "baz"), 3); + Ensemble ensemble = Ensemble.builder().setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(Arrays.asList("foo", "bar", "baz")) + .setOutputAggregator(new LogisticRegression(new 
double[]{0.1, 0.4, 1.0})) + .build(); + assertThat(ensemble.estimatedNumOperations(), equalTo(9L)); } private static Map zipObjMap(List keys, List values) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java index b68a3763390..e630a5874fc 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java @@ -8,20 +8,21 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; public class LogisticRegressionTests extends WeightedAggregatorTests { @Override LogisticRegression createTestInstance(int numberOfWeights) { - List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); + double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray(); return new LogisticRegression(weights); } @@ -41,13 +42,13 @@ public class LogisticRegressionTests extends WeightedAggregatorTests ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0}; List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); LogisticRegression logisticRegression = new 
LogisticRegression(ones); assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0)); - List variedWeights = Arrays.asList(.01, -1.0, .1, 0.0, 0.0); + double[] variedWeights = new double[]{.01, -1.0, .1, 0.0, 0.0}; logisticRegression = new LogisticRegression(variedWeights); assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(0.0)); @@ -56,4 +57,9 @@ public class LogisticRegressionTests extends WeightedAggregatorTests { @Override WeightedMode createTestInstance(int numberOfWeights) { - List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); + double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray(); return new WeightedMode(weights); } @@ -41,13 +42,13 @@ public class WeightedModeTests extends WeightedAggregatorTests { } public void testAggregate() { - List ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0}; List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); WeightedMode weightedMode = new WeightedMode(ones); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); - List variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0); + double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0}; weightedMode = new WeightedMode(variedWeights); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0)); @@ -55,4 +56,10 @@ public class WeightedModeTests extends WeightedAggregatorTests { weightedMode = new WeightedMode(); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); } + + public void testCompatibleWith() { + WeightedMode weightedMode = createTestInstance(); + assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true)); + assertThat(weightedMode.compatibleWith(TargetType.REGRESSION), is(true)); + } } diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java index 89222365c83..8e4a6577dbb 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -8,20 +8,21 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; public class WeightedSumTests extends WeightedAggregatorTests { @Override WeightedSum createTestInstance(int numberOfWeights) { - List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); + double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray(); return new WeightedSum(weights); } @@ -41,13 +42,13 @@ public class WeightedSumTests extends WeightedAggregatorTests { } public void testAggregate() { - List ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0}; List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); WeightedSum weightedSum = new WeightedSum(ones); assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0)); - List variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0); + double[] variedWeights = new 
double[]{1.0, -1.0, .5, 1.0, 5.0}; weightedSum = new WeightedSum(variedWeights); assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(28.0)); @@ -55,4 +56,10 @@ public class WeightedSumTests extends WeightedAggregatorTests { weightedSum = new WeightedSum(); assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0)); } + + public void testCompatibleWith() { + WeightedSum weightedSum = createTestInstance(); + assertThat(weightedSum.compatibleWith(TargetType.CLASSIFICATION), is(false)); + assertThat(weightedSum.compatibleWith(TargetType.REGRESSION), is(true)); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index b98d19b07ff..0fe8ce47fef 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -10,6 +10,10 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.junit.Before; @@ -104,6 +108,19 @@ public class TreeTests extends AbstractSerializingTestCase { return Tree::new; } + public void testInferWithStump() { + Tree.Builder builder = 
Tree.builder().setTargetType(TargetType.REGRESSION); + builder.setRoot(TreeNode.builder(0).setLeafValue(42.0)); + builder.setFeatureNames(Collections.emptyList()); + + Tree tree = builder.build(); + List featureNames = Arrays.asList("foo", "bar"); + List featureVector = Arrays.asList(0.6, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump + assertThat(42.0, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); + } + public void testInfer() { // Build a tree with 2 nodes and 3 leaves using 2 features // The leaves have unique values 0.1, 0.2, 0.3 @@ -120,26 +137,36 @@ public class TreeTests extends AbstractSerializingTestCase { // This feature vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.3, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.3, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.1, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // This should hit the right child of the left child of the root node // i.e. 
it takes the path left, right featureVector = Arrays.asList(0.3, 0.9); featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.2, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.2, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); + + // This should still work if the internal values are strings + List featureVectorStrings = Arrays.asList("0.3", "0.9"); + featureMap = zipObjMap(featureNames, featureVectorStrings); + assertThat(0.2, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // This should handle missing values and take the default_left path featureMap = new HashMap(2) {{ put("foo", 0.3); put("bar", null); }}; - assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.1, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); } public void testTreeClassificationProbability() { @@ -153,31 +180,43 @@ public class TreeTests extends AbstractSerializingTestCase { builder.addLeaf(leftChildNode.getRightChild(), 0.0); List featureNames = Arrays.asList("foo", "bar"); - Tree tree = builder.setFeatureNames(featureNames).build(); + Tree tree = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build(); + double eps = 0.000001; // This feature vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); + List expectedProbs = Arrays.asList(1.0, 0.0); + List expectedFields = Arrays.asList("dog", "cat"); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); + List probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), 
closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); - - // This should hit the right child of the left child of the root node - // i.e. it takes the path left, right - featureVector = Arrays.asList(0.3, 0.9); - featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap)); + probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } // This should handle missing values and take the default_left path featureMap = new HashMap(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(1.0, tree.infer(featureMap), 0.00001); + probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } } public void testTreeWithNullRoot() { @@ -261,7 +300,12 @@ public class TreeTests extends AbstractSerializingTestCase { assertThat(ex.getMessage(), equalTo(msg)); } - private static Map zipObjMap(List keys, List values) { + public void testOperationsEstimations() { + Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5); + assertThat(tree.estimatedNumOperations(), 
equalTo(7L)); + } + + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java index c6a904228b6..f6a319dab7a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java @@ -6,19 +6,16 @@ package org.elasticsearch.xpack.core.ml.notifications; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.xpack.core.common.notifications.Level; import org.elasticsearch.xpack.core.ml.job.config.Job; import java.util.Date; -import static org.hamcrest.Matchers.equalTo; +public class AnomalyDetectionAuditMessageTests extends AuditMessageTests { -public class AnomalyDetectionAuditMessageTests extends AbstractXContentTestCase { - - public void testGetJobType() { - AnomalyDetectionAuditMessage message = createTestInstance(); - assertThat(message.getJobType(), equalTo(Job.ANOMALY_DETECTOR_JOB_TYPE)); + @Override + public String getJobType() { + return Job.ANOMALY_DETECTOR_JOB_TYPE; } @Override @@ -26,11 +23,6 @@ public class AnomalyDetectionAuditMessageTests extends AbstractXContentTestCase< return AnomalyDetectionAuditMessage.PARSER.apply(parser, null); } - @Override - protected boolean supportsUnknownFields() { - return true; - } - @Override protected AnomalyDetectionAuditMessage createTestInstance() { return new AnomalyDetectionAuditMessage( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java new file mode 100644 index 00000000000..2ccb1fbcbf4 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.notifications; + +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage; + + +import static org.hamcrest.Matchers.equalTo; + +public abstract class AuditMessageTests extends AbstractXContentTestCase { + + public abstract String getJobType(); + + public void testGetJobType() { + AbstractAuditMessage message = createTestInstance(); + assertThat(message.getJobType(), equalTo(getJobType())); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java index 139e76160d4..9637af79a94 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java @@ -6,30 +6,22 @@ package org.elasticsearch.xpack.core.ml.notifications; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.xpack.core.common.notifications.Level; import java.util.Date; -import static 
org.hamcrest.Matchers.equalTo; +public class DataFrameAnalyticsAuditMessageTests extends AuditMessageTests { -public class DataFrameAnalyticsAuditMessageTests extends AbstractXContentTestCase { - - public void testGetJobType() { - DataFrameAnalyticsAuditMessage message = createTestInstance(); - assertThat(message.getJobType(), equalTo("data_frame_analytics")); + @Override + public String getJobType() { + return "data_frame_analytics"; } - + @Override protected DataFrameAnalyticsAuditMessage doParseInstance(XContentParser parser) { return DataFrameAnalyticsAuditMessage.PARSER.apply(parser, null); } - @Override - protected boolean supportsUnknownFields() { - return true; - } - @Override protected DataFrameAnalyticsAuditMessage createTestInstance() { return new DataFrameAnalyticsAuditMessage( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java new file mode 100644 index 00000000000..5a9b86578ef --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.notifications; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.common.notifications.Level; + +import java.util.Date; + +public class InferenceAuditMessageTests extends AuditMessageTests { + + @Override + public String getJobType() { + return "inference"; + } + + @Override + protected InferenceAuditMessage doParseInstance(XContentParser parser) { + return InferenceAuditMessage.PARSER.apply(parser, null); + } + + @Override + protected InferenceAuditMessage createTestInstance() { + return new InferenceAuditMessage( + randomBoolean() ? null : randomAlphaOfLength(10), + randomAlphaOfLengthBetween(1, 20), + randomFrom(Level.values()), + new Date(), + randomBoolean() ? null : randomAlphaOfLengthBetween(1, 20) + ); + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 063ad43a921..961dc944ea7 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -126,6 +126,13 @@ integTest.runner { 'ml/filter_crud/Test get all filter given index exists but no mapping for filter_id', 'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id', 'ml/get_datafeeds/Test get datafeed given missing datafeed_id', + 'ml/inference_crud/Test delete given used trained model', + 'ml/inference_crud/Test delete given unused trained model', + 'ml/inference_crud/Test delete with missing model', + 'ml/inference_crud/Test get given missing trained model', + 'ml/inference_crud/Test get given expression without matches and allow_no_match is false', + 'ml/inference_stats_crud/Test get stats given missing trained model', + 'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false', 'ml/jobs_crud/Test cannot create job with existing categorizer state document', 'ml/jobs_crud/Test cannot create job with existing quantiles 
document', 'ml/jobs_crud/Test cannot create job with existing result document', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java new file mode 100644 index 00000000000..c1b260090a1 --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -0,0 +1,544 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; +import org.elasticsearch.action.ingest.SimulateDocumentBaseResult; +import org.elasticsearch.action.ingest.SimulatePipelineResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.junit.Before; + +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class InferenceIngestIT extends 
MlNativeAutodetectIntegTestCase { + + @Before + public void createBothModels() { + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) + .setId("test_classification") + .setSource(CLASSIFICATION_CONFIG, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) + .setId(TrainedModelDefinition.docId("test_classification")) + .setSource(CLASSIFICATION_DEFINITION, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) + .setId("test_regression") + .setSource(REGRESSION_CONFIG, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) + .setId(TrainedModelDefinition.docId("test_regression")) + .setSource(REGRESSION_DEFINITION, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + } + + public void testPipelineCreationAndDeletion() throws Exception { + + for (int i = 0; i < 10; i++) { + assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline", + new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME) + .setSource(new HashMap(){{ + put("col1", randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 
10)); + }}) + .setPipeline("simple_classification_pipeline") + .get(); + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(), + is(true)); + + assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline", + new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME) + .setSource(new HashMap(){{ + put("col1", randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 10)); + }}) + .setPipeline("simple_regression_pipeline") + .get(); + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(), + is(true)); + } + + assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline", + new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline", + new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get().isAcknowledged(), is(true)); + + for (int i = 0; i < 10; i++) { + client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME) + .setSource(generateSourceDoc()) + .setPipeline("simple_classification_pipeline") + .get(); + + client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME) + .setSource(generateSourceDoc()) + .setPipeline("simple_regression_pipeline") + .get(); + } + + assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(), + is(true)); + + 
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(), + is(true)); + + client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get(); + + assertThat(client().search(new SearchRequest().indices("index_for_inference_test") + .source(new SearchSourceBuilder() + .size(0) + .trackTotalHits(true) + .query(QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("regression_value"))))).get().getHits().getTotalHits().value, + equalTo(20L)); + + assertThat(client().search(new SearchRequest().indices("index_for_inference_test") + .source(new SearchSourceBuilder() + .size(0) + .trackTotalHits(true) + .query(QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("result_class"))))).get().getHits().getTotalHits().value, + equalTo(20L)); + + } + + public void testSimulate() { + String source = "{\n" + + " \"pipeline\": {\n" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class\",\n" + + " \"inference_config\": {\"classification\":{}},\n" + + " \"model_id\": \"test_classification\",\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class_prob\",\n" + + " \"inference_config\": {\"classification\": {\"num_top_classes\":2}},\n" + + " \"model_id\": \"test_classification\",\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"regression_value\",\n" + + " \"model_id\": \"test_regression\",\n" + + " \"inference_config\": {\"regression\":{}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": 
\"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " },\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"col1\": \"female\",\n" + + " \"col2\": \"M\",\n" + + " \"col3\": \"none\",\n" + + " \"col4\": 10\n" + + " }}]\n" + + "}"; + + SimulatePipelineResponse response = client().admin().cluster() + .prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get(); + SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0); + assertThat(baseResult.getIngestDocument().getFieldValue("regression_value", Double.class), equalTo(1.0)); + assertThat(baseResult.getIngestDocument().getFieldValue("result_class", String.class), equalTo("second")); + assertThat(baseResult.getIngestDocument().getFieldValue("result_class_prob", List.class).size(), equalTo(2)); + + String sourceWithMissingModel = "{\n" + + " \"pipeline\": {\n" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class\",\n" + + " \"model_id\": \"test_classification_missing\",\n" + + " \"inference_config\": {\"classification\":{}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " },\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"col1\": \"female\",\n" + + " \"col2\": \"M\",\n" + + " \"col3\": \"none\",\n" + + " \"col4\": 10\n" + + " }}]\n" + + "}"; + + response = client().admin().cluster() + .prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get(); + + assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(), + containsString("Could not find trained model [test_classification_missing]")); + } + + private Map generateSourceDoc() { + return new HashMap(){{ + put("col1", 
randomFrom("female", "male")); + put("col2", randomFrom("S", "M", "L", "XL")); + put("col3", randomFrom("true", "false", "none", "other")); + put("col4", randomIntBetween(0, 10)); + }}; + } + + private static final String REGRESSION_DEFINITION = "{" + + " \"preprocessors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"col1\",\n" + + " \"hot_map\": {\n" + + " \"male\": \"col1_male\",\n" + + " \"female\": \"col1_female\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"target_mean_encoding\": {\n" + + " \"field\": \"col2\",\n" + + " \"feature_name\": \"col2_encoded\",\n" + + " \"target_map\": {\n" + + " \"S\": 5.0,\n" + + " \"M\": 10.0,\n" + + " \"L\": 20\n" + + " },\n" + + " \"default_value\": 5.0\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field\": \"col3\",\n" + + " \"feature_name\": \"col3_encoded\",\n" + + " \"frequency_map\": {\n" + + " \"none\": 0.75,\n" + + " \"true\": 0.10,\n" + + " \"false\": 0.15\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"trained_model\": {\n" + + " \"ensemble\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"aggregate_output\": {\n" + + " \"weighted_sum\": {\n" + + " \"weights\": [\n" + + " 0.5,\n" + + " 0.5\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"target_type\": \"regression\",\n" + + " \"trained_models\": [\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " 
{\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"model_id\": \"test_regression\"\n" + + "}"; + + private static final String REGRESSION_CONFIG = "{" + + " \"model_id\": \"test_regression\",\n" + + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + + " \"description\": \"test model for regression\",\n" + + " \"version\": \"7.6.0\",\n" + + " \"created_by\": \"ml_test\",\n" + + " \"estimated_heap_memory_usage_bytes\": 0," + + " \"estimated_operations\": 0," + + " \"created_time\": 0" + + "}"; + + private static final String CLASSIFICATION_DEFINITION = "{" + + " \"preprocessors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"col1\",\n" + + " \"hot_map\": {\n" + + " \"male\": \"col1_male\",\n" + + " \"female\": \"col1_female\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"target_mean_encoding\": {\n" + + " \"field\": \"col2\",\n" + + " \"feature_name\": \"col2_encoded\",\n" + + " \"target_map\": {\n" + + " \"S\": 5.0,\n" + + " \"M\": 10.0,\n" + + " \"L\": 20\n" + + " },\n" + + " \"default_value\": 5.0\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field\": \"col3\",\n" + + " 
\"feature_name\": \"col3_encoded\",\n" + + " \"frequency_map\": {\n" + + " \"none\": 0.75,\n" + + " \"true\": 0.10,\n" + + " \"false\": 0.15\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"trained_model\": {\n" + + " \"ensemble\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"aggregate_output\": {\n" + + " \"weighted_mode\": {\n" + + " \"weights\": [\n" + + " 0.5,\n" + + " 0.5\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"target_type\": \"classification\",\n" + + " \"classification_labels\": [\"first\", \"second\"],\n" + + " \"trained_models\": [\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " 
\"target_type\": \"regression\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"model_id\": \"test_classification\"\n" + + "}"; + + private static final String CLASSIFICATION_CONFIG = "" + + "{\n" + + " \"model_id\": \"test_classification\",\n" + + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + + " \"description\": \"test model for classification\",\n" + + " \"version\": \"7.6.0\",\n" + + " \"created_by\": \"benwtrent\",\n" + + " \"estimated_heap_memory_usage_bytes\": 0," + + " \"estimated_operations\": 0," + + " \"created_time\": 0\n" + + "}"; + + private static final String CLASSIFICATION_PIPELINE = "{" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"result_class\",\n" + + " \"model_id\": \"test_classification\",\n" + + " \"inference_config\": {\"classification\": {}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }]}\n"; + + private static final String REGRESSION_PIPELINE = "{" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"regression_value\",\n" + + " \"model_id\": \"test_regression\",\n" + + " \"inference_config\": {\"regression\": {}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }]}\n"; + +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java new file mode 100644 index 00000000000..153b169ea8f --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -0,0 +1,221 @@ +/* + * Copyright Elasticsearch 
B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.apache.http.util.EntityUtils; +import org.elasticsearch.Version; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests; +import org.junit.After; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; + +import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; +import static org.hamcrest.Matchers.containsString; +import 
static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public class TrainedModelIT extends ESRestTestCase { + + private static final String BASIC_AUTH_VALUE = basicAuthHeaderValue("x_pack_rest_user", + SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING); + + @Override + protected Settings restClientSettings() { + return Settings.builder().put(super.restClientSettings()).put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE).build(); + } + + @Override + protected boolean preserveTemplatesUponCompletion() { + return true; + } + + public void testGetTrainedModels() throws IOException { + String modelId = "test_regression_model"; + String modelId2 = "test_regression_model-2"; + Request model1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId); + model1.setJsonEntity(buildRegressionModel(modelId)); + assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); + + Request modelDefinition1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId)); + modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId)); + assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201)); + + Request model2 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2); + model2.setJsonEntity(buildRegressionModel(modelId2)); + assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201)); + + adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh")); + Response getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/" + modelId)); + + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + String response = EntityUtils.toString(getModel.getEntity()); + + assertThat(response, 
containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"count\":1")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression*")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, not(containsString("\"definition\""))); + assertThat(response, containsString("\"count\":2")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"heap_memory_estimation_bytes\"")); + assertThat(response, containsString("\"heap_memory_estimation\"")); + assertThat(response, containsString("\"definition\"")); + assertThat(response, containsString("\"count\":1")); + + ResponseException responseException = expectThrows(ResponseException.class, () -> + client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true"))); + assertThat(EntityUtils.toString(responseException.getResponse().getEntity()), + containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED)); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, 
containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, containsString("\"count\":2")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/classification*?allow_no_match=true")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"count\":0")); + + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/classification*?allow_no_match=false"))); + assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + + getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=0&size=1")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"count\":2")); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, not(containsString("\"model_id\":\"test_regression_model-2\""))); + + getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=1&size=1")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"count\":2")); + assertThat(response, not(containsString("\"model_id\":\"test_regression_model\""))); + assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + } + + public void testDeleteTrainedModels() throws IOException { + String modelId = "test_delete_regression_model"; + Request model1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId); + 
model1.setJsonEntity(buildRegressionModel(modelId)); + assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); + + Request modelDefinition1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId)); + modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId)); + assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201)); + + adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh")); + + Response delModel = client().performRequest(new Request("DELETE", + MachineLearning.BASE_PATH + "inference/" + modelId)); + String response = EntityUtils.toString(delModel.getEntity()); + assertThat(response, containsString("\"acknowledged\":true")); + + ResponseException responseException = expectThrows(ResponseException.class, + () -> client().performRequest(new Request("DELETE", MachineLearning.BASE_PATH + "inference/" + modelId))); + assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + + responseException = expectThrows(ResponseException.class, + () -> client().performRequest( + new Request("GET", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId)))); + assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + + responseException = expectThrows(ResponseException.class, + () -> client().performRequest( + new Request("GET", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId))); + assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + } + + private static String buildRegressionModel(String modelId) throws IOException { + try(XContentBuilder builder = XContentFactory.jsonBuilder()) { + TrainedModelConfig.builder() + .setModelId(modelId) + .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3"))) + 
.setCreatedBy("ml_test") + .setVersion(Version.CURRENT) + .setCreateTime(Instant.now()) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) + .build() + .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); + } + } + + private static String buildRegressionModelDefinition(String modelId) throws IOException { + try(XContentBuilder builder = XContentFactory.jsonBuilder()) { + new TrainedModelDefinition.Builder() + .setPreProcessors(Collections.emptyList()) + .setTrainedModel(LocalModelTests.buildRegression()) + .setModelId(modelId) + .build() + .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); + } + } + + + @After + public void clearMlState() throws Exception { + new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata(); + ESRestTestCase.waitForPendingTasks(adminClient()); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index ab2c7a00f08..f367ccc4381 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -43,12 +43,14 @@ import org.elasticsearch.env.NodeEnvironment; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.TokenizerFactory; import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider; +import org.elasticsearch.ingest.Processor; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.monitor.os.OsProbe; import org.elasticsearch.monitor.os.OsStats; import 
org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.AnalysisPlugin; +import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestController; @@ -72,6 +74,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction; import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; @@ -93,7 +96,10 @@ import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction; import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; import org.elasticsearch.xpack.core.ml.action.MlInfoAction; import org.elasticsearch.xpack.core.ml.action.OpenJobAction; @@ -139,6 +145,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteFilterAction; import org.elasticsearch.xpack.ml.action.TransportDeleteForecastAction; import org.elasticsearch.xpack.ml.action.TransportDeleteJobAction; import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction; +import 
org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportEstimateMemoryUsageAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; import org.elasticsearch.xpack.ml.action.TransportFinalizeJobExecutionAction; @@ -160,6 +167,9 @@ import org.elasticsearch.xpack.ml.action.TransportGetJobsStatsAction; import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction; import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction; +import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsStatsAction; +import org.elasticsearch.xpack.ml.action.TransportInferModelAction; +import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsAction; import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportKillProcessAction; import org.elasticsearch.xpack.ml.action.TransportMlInfoAction; @@ -199,6 +209,8 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactor import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.job.JobManager; @@ -221,6 +233,7 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory; import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory; import 
org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.NativeControllerHolder; @@ -257,6 +270,9 @@ import org.elasticsearch.xpack.ml.rest.filter.RestDeleteFilterAction; import org.elasticsearch.xpack.ml.rest.filter.RestGetFiltersAction; import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction; import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; +import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; +import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; +import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -297,7 +313,7 @@ import java.util.function.UnaryOperator; import static java.util.Collections.emptyList; import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME; -public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, PersistentTaskPlugin { +public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; @@ -321,6 +337,22 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu }; + @Override + public Map getProcessors(Processor.Parameters parameters) { + if (this.enabled == false) { + return Collections.emptyMap(); + } + + InferenceProcessor.Factory inferenceFactory = new 
InferenceProcessor.Factory(parameters.client, + parameters.ingestService.getClusterService(), + this.settings, + parameters.ingestService, + getLicenseState()); + getLicenseState().addListener(inferenceFactory); + parameters.ingestService.addIngestClusterStateListener(inferenceFactory); + return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory); + } + @Override public Set getRoles() { return Collections.singleton(ML_ROLE); @@ -401,18 +433,21 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu public List> getSettings() { return Collections.unmodifiableList( - Arrays.asList(MachineLearningField.AUTODETECT_PROCESS, - PROCESS_CONNECT_TIMEOUT, - ML_ENABLED, - CONCURRENT_JOB_ALLOCATIONS, - MachineLearningField.MAX_MODEL_MEMORY_LIMIT, - MAX_LAZY_ML_NODES, - MAX_MACHINE_MEMORY_PERCENT, - AutodetectBuilder.DONT_PERSIST_MODEL_STATE_SETTING, - AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC, - MAX_OPEN_JOBS_PER_NODE, - MIN_DISK_SPACE_OFF_HEAP, - MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION)); + Arrays.asList(MachineLearningField.AUTODETECT_PROCESS, + PROCESS_CONNECT_TIMEOUT, + ML_ENABLED, + CONCURRENT_JOB_ALLOCATIONS, + MachineLearningField.MAX_MODEL_MEMORY_LIMIT, + MAX_LAZY_ML_NODES, + MAX_MACHINE_MEMORY_PERCENT, + AutodetectBuilder.DONT_PERSIST_MODEL_STATE_SETTING, + AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC, + MAX_OPEN_JOBS_PER_NODE, + MIN_DISK_SPACE_OFF_HEAP, + MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION, + InferenceProcessor.MAX_INFERENCE_PROCESSORS, + ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE, + ModelLoadingService.INFERENCE_MODEL_CACHE_TTL)); } public Settings additionalSettings() { @@ -483,6 +518,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu AnomalyDetectionAuditor anomalyDetectionAuditor = new AnomalyDetectionAuditor(client, clusterService.getNodeName()); DataFrameAnalyticsAuditor dataFrameAnalyticsAuditor = new 
DataFrameAnalyticsAuditor(client, clusterService.getNodeName()); + InferenceAuditor inferenceAuditor = new InferenceAuditor(client, clusterService.getNodeName()); this.dataFrameAnalyticsAuditor.set(dataFrameAnalyticsAuditor); JobResultsProvider jobResultsProvider = new JobResultsProvider(client, settings); JobResultsPersister jobResultsPersister = new JobResultsPersister(client); @@ -569,7 +605,12 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu this.datafeedManager.set(datafeedManager); // Inference components - TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); + final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); + final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + inferenceAuditor, + threadPool, + clusterService, + settings); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory, @@ -614,12 +655,14 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu datafeedManager, anomalyDetectionAuditor, dataFrameAnalyticsAuditor, + inferenceAuditor, mlAssignmentNotifier, memoryTracker, analyticsProcessManager, memoryEstimationProcessManager, dataFrameAnalyticsConfigProvider, nativeStorageProvider, + modelLoadingService, trainedModelProvider ); } @@ -717,7 +760,10 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu new RestStartDataFrameAnalyticsAction(restController), new RestStopDataFrameAnalyticsAction(restController), new RestEvaluateDataFrameAction(restController), - new RestEstimateMemoryUsageAction(restController) + new RestEstimateMemoryUsageAction(restController), + new RestGetTrainedModelsAction(restController), + new RestDeleteTrainedModelAction(restController), + new RestGetTrainedModelsStatsAction(restController) ); } @@ 
-784,7 +830,11 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu new ActionHandler<>(StartDataFrameAnalyticsAction.INSTANCE, TransportStartDataFrameAnalyticsAction.class), new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class), new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class), - new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class) + new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class), + new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class), + new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class), + new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class), + new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningFeatureSet.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningFeatureSet.java index e06feb4d6aa..088d16c87e3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningFeatureSet.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningFeatureSet.java @@ -10,6 +10,12 @@ import org.apache.lucene.util.Constants; import org.apache.lucene.util.Counter; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.elasticsearch.action.search.SearchRequestBuilder; +import 
org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.MetaData; @@ -18,8 +24,10 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.env.Environment; +import org.elasticsearch.ingest.IngestStats; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.plugins.Platforms; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.XPackFeatureSet; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.XPackPlugin; @@ -31,11 +39,13 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats; import org.elasticsearch.xpack.core.ml.stats.ForecastStats; import org.elasticsearch.xpack.core.ml.stats.StatsAccumulator; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.NativeControllerHolder; @@ -44,11 +54,14 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; import java.util.Map; import 
java.util.Objects; +import java.util.Set; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; @@ -135,10 +148,10 @@ public class MachineLearningFeatureSet implements XPackFeatureSet { @Override public void usage(ActionListener listener) { ClusterState state = clusterService.state(); - new Retriever(client, jobManagerHolder, available(), enabled(), mlNodeCount(state)).execute(listener); + new Retriever(client, jobManagerHolder, available(), enabled(), state).execute(listener); } - private int mlNodeCount(final ClusterState clusterState) { + private static int mlNodeCount(boolean enabled, final ClusterState clusterState) { if (enabled == false) { return 0; } @@ -161,9 +174,11 @@ public class MachineLearningFeatureSet implements XPackFeatureSet { private Map jobsUsage; private Map datafeedsUsage; private Map analyticsUsage; + private Map inferenceUsage; + private final ClusterState state; private int nodeCount; - public Retriever(Client client, JobManagerHolder jobManagerHolder, boolean available, boolean enabled, int nodeCount) { + public Retriever(Client client, JobManagerHolder jobManagerHolder, boolean available, boolean enabled, ClusterState state) { this.client = Objects.requireNonNull(client); this.jobManagerHolder = jobManagerHolder; this.available = available; @@ -171,61 +186,15 @@ public class MachineLearningFeatureSet implements XPackFeatureSet { this.jobsUsage = new LinkedHashMap<>(); this.datafeedsUsage = new LinkedHashMap<>(); this.analyticsUsage = new LinkedHashMap<>(); - this.nodeCount = nodeCount; + this.inferenceUsage = new LinkedHashMap<>(); + this.nodeCount = mlNodeCount(enabled, state); + this.state = state; } - public void execute(ActionListener listener) { - // empty holder means either ML disabled or transport client mode - if (jobManagerHolder.isEmpty()) { - listener.onResponse( - new MachineLearningFeatureSetUsage(available, - enabled, - Collections.emptyMap(), - Collections.emptyMap(), - Collections.emptyMap(), 
- 0)); - return; - } - - // Step 3. Extract usage from data frame analytics and return usage response - ActionListener dataframeAnalyticsListener = ActionListener.wrap( - response -> { - addDataFrameAnalyticsUsage(response, analyticsUsage); - listener.onResponse(new MachineLearningFeatureSetUsage(available, - enabled, - jobsUsage, - datafeedsUsage, - analyticsUsage, - nodeCount)); - }, - listener::onFailure - ); - - // Step 2. Extract usage from datafeeds stats and return usage response - ActionListener datafeedStatsListener = - ActionListener.wrap(response -> { - addDatafeedsUsage(response); - GetDataFrameAnalyticsStatsAction.Request dataframeAnalyticsStatsRequest = - new GetDataFrameAnalyticsStatsAction.Request(GetDatafeedsStatsAction.ALL); - dataframeAnalyticsStatsRequest.setPageParams(new PageParams(0, 10_000)); - client.execute(GetDataFrameAnalyticsStatsAction.INSTANCE, dataframeAnalyticsStatsRequest, dataframeAnalyticsListener); - }, - listener::onFailure); - - // Step 1. Extract usage from jobs stats and then request stats for all datafeeds - GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request(MetaData.ALL); - ActionListener jobStatsListener = ActionListener.wrap( - response -> { - jobManagerHolder.getJobManager().expandJobs(MetaData.ALL, true, ActionListener.wrap(jobs -> { - addJobsUsage(response, jobs.results()); - GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request( - GetDatafeedsStatsAction.ALL); - client.execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest, datafeedStatsListener); - }, listener::onFailure)); - }, listener::onFailure); - - // Step 0. 
Kick off the chain of callbacks by requesting jobs stats - client.execute(GetJobsStatsAction.INSTANCE, jobStatsRequest, jobStatsListener); + private static void initializeStats(Map emptyStatsMap) { + emptyStatsMap.put("sum", 0L); + emptyStatsMap.put("min", 0L); + emptyStatsMap.put("max", 0L); } private void addJobsUsage(GetJobsStatsAction.Response response, List jobs) { @@ -322,7 +291,7 @@ public class MachineLearningFeatureSet implements XPackFeatureSet { } private void addDataFrameAnalyticsUsage(GetDataFrameAnalyticsStatsAction.Response response, - Map dataframeAnalyticsUsage) { + Map dataframeAnalyticsUsage) { Map dataFrameAnalyticsStateCounterMap = new HashMap<>(); for(GetDataFrameAnalyticsStatsAction.Response.Stats stats : response.getResponse().results()) { @@ -334,5 +303,149 @@ public class MachineLearningFeatureSet implements XPackFeatureSet { createCountUsageEntry(dataFrameAnalyticsStateCounterMap.get(state).get())); } } + + private static void updateStats(Map statsMap, Long value) { + statsMap.compute("sum", (k, v) -> v + value); + statsMap.compute("min", (k, v) -> Math.min(v, value)); + statsMap.compute("max", (k, v) -> Math.max(v, value)); + } + + private static String[] ingestNodes(final ClusterState clusterState) { + String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()]; + Iterator nodeIterator = clusterState.nodes().getIngestNodes().keysIt(); + int i = 0; + while (nodeIterator.hasNext()) { + ingestNodes[i++] = nodeIterator.next(); + } + return ingestNodes; + } + + public void execute(ActionListener listener) { + // empty holder means either ML disabled or transport client mode + if (jobManagerHolder.isEmpty()) { + listener.onResponse( + new MachineLearningFeatureSetUsage(available, + enabled, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + 0)); + return; + } + + // Step 5. 
extract trained model config count and then return results + ActionListener trainedModelConfigCountListener = ActionListener.wrap( + response -> { + addTrainedModelStats(response, inferenceUsage); + MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(available, + enabled, jobsUsage, datafeedsUsage, analyticsUsage, inferenceUsage, nodeCount); + listener.onResponse(usage); + }, + listener::onFailure + ); + + // Step 4. Extract usage from ingest statistics and gather trained model config count + ActionListener nodesStatsListener = ActionListener.wrap( + response -> { + addInferenceIngestUsage(response, inferenceUsage); + SearchRequestBuilder requestBuilder = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .setSize(0) + .setTrackTotalHits(true); + ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(), + ClientHelper.ML_ORIGIN, + requestBuilder.request(), + trainedModelConfigCountListener, + client::search); + }, + listener::onFailure + ); + + // Step 3. Extract usage from data frame analytics stats and then request ingest node stats + ActionListener dataframeAnalyticsListener = ActionListener.wrap( + response -> { + addDataFrameAnalyticsUsage(response, analyticsUsage); + String[] ingestNodes = ingestNodes(state); + NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().ingest(true); + client.execute(NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener); + }, + listener::onFailure + ); + + // Step 2. 
Extract usage from datafeeds stats and return usage response + ActionListener datafeedStatsListener = + ActionListener.wrap(response -> { + addDatafeedsUsage(response); + GetDataFrameAnalyticsStatsAction.Request dataframeAnalyticsStatsRequest = + new GetDataFrameAnalyticsStatsAction.Request(GetDatafeedsStatsAction.ALL); + dataframeAnalyticsStatsRequest.setPageParams(new PageParams(0, 10_000)); + client.execute(GetDataFrameAnalyticsStatsAction.INSTANCE, dataframeAnalyticsStatsRequest, dataframeAnalyticsListener); + }, + listener::onFailure); + + // Step 1. Extract usage from jobs stats and then request stats for all datafeeds + GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request(MetaData.ALL); + ActionListener jobStatsListener = ActionListener.wrap( + response -> { + jobManagerHolder.getJobManager().expandJobs(MetaData.ALL, true, ActionListener.wrap(jobs -> { + addJobsUsage(response, jobs.results()); + GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request( + GetDatafeedsStatsAction.ALL); + client.execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest, datafeedStatsListener); + }, listener::onFailure)); + }, listener::onFailure); + + // Step 0. 
Kick off the chain of callbacks by requesting jobs stats + client.execute(GetJobsStatsAction.INSTANCE, jobStatsRequest, jobStatsListener); + } + + //TODO separate out ours and users models possibly regression vs classification + private void addTrainedModelStats(SearchResponse response, Map inferenceUsage) { + inferenceUsage.put("trained_models", + Collections.singletonMap(MachineLearningFeatureSetUsage.ALL, + createCountUsageEntry(response.getHits().getTotalHits().value))); + } + + //TODO separate out ours and users models possibly regression vs classification + private void addInferenceIngestUsage(NodesStatsResponse response, Map inferenceUsage) { + Set pipelines = new HashSet<>(); + Map docCountStats = new HashMap<>(3); + Map timeStats = new HashMap<>(3); + Map failureStats = new HashMap<>(3); + initializeStats(docCountStats); + initializeStats(timeStats); + initializeStats(failureStats); + + response.getNodes() + .stream() + .map(NodeStats::getIngestStats) + .map(IngestStats::getProcessorStats) + .forEach(map -> + map.forEach((pipelineId, processors) -> { + boolean containsInference = false; + for (IngestStats.ProcessorStat stats : processors) { + if (stats.getName().equals(InferenceProcessor.TYPE)) { + containsInference = true; + long ingestCount = stats.getStats().getIngestCount(); + long ingestTime = stats.getStats().getIngestTimeInMillis(); + long failureCount = stats.getStats().getIngestFailedCount(); + updateStats(docCountStats, ingestCount); + updateStats(timeStats, ingestTime); + updateStats(failureStats, failureCount); + } + } + if (containsInference) { + pipelines.add(pipelineId); + } + }) + ); + + Map ingestUsage = new HashMap<>(6); + ingestUsage.put("pipelines", createCountUsageEntry(pipelines.size())); + ingestUsage.put("num_docs_processed", docCountStats); + ingestUsage.put("time_ms", timeStats); + ingestUsage.put("num_failures", failureStats); + inferenceUsage.put("ingest_processors", Collections.singletonMap(MachineLearningFeatureSetUsage.ALL, 
ingestUsage)); + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java new file mode 100644 index 00000000000..47a9fdd32ad --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.Pipeline; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import 
org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + + +/** + * The action is a master node action to ensure it reads an up-to-date cluster + * state in order to determine if there is a processor referencing the trained model + */ +public class TransportDeleteTrainedModelAction + extends TransportMasterNodeAction { + + private static final Logger LOGGER = LogManager.getLogger(TransportDeleteTrainedModelAction.class); + + private final TrainedModelProvider trainedModelProvider; + private final InferenceAuditor auditor; + private final IngestService ingestService; + + @Inject + public TransportDeleteTrainedModelAction(TransportService transportService, ClusterService clusterService, + ThreadPool threadPool, ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + TrainedModelProvider configProvider, InferenceAuditor auditor, + IngestService ingestService) { + super(DeleteTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, + DeleteTrainedModelAction.Request::new, indexNameExpressionResolver); + this.trainedModelProvider = configProvider; + this.ingestService = ingestService; + this.auditor = Objects.requireNonNull(auditor); + } + + @Override + protected String executor() { + return ThreadPool.Names.SAME; + } + + @Override + protected AcknowledgedResponse read(StreamInput in) throws IOException { + return new AcknowledgedResponse(in); + } + + @Override + protected void masterOperation(DeleteTrainedModelAction.Request request, + ClusterState state, + ActionListener listener) { + String id = request.getId(); + IngestMetadata currentIngestMetadata = 
state.metaData().custom(IngestMetadata.TYPE); + Set referencedModels = getReferencedModelKeys(currentIngestMetadata); + + if (referencedModels.contains(id)) { + listener.onFailure(new ElasticsearchStatusException("Cannot delete model [{}] as it is still referenced by ingest processors", + RestStatus.CONFLICT, + id)); + return; + } + + trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap( + r -> { + auditor.info(request.getId(), "trained model deleted"); + listener.onResponse(new AcknowledgedResponse(true)); + }, + listener::onFailure + )); + } + + private Set getReferencedModelKeys(IngestMetadata ingestMetadata) { + Set allReferencedModelKeys = new HashSet<>(); + if (ingestMetadata == null) { + return allReferencedModelKeys; + } + for(Map.Entry entry : ingestMetadata.getPipelines().entrySet()) { + String pipelineId = entry.getKey(); + Map config = entry.getValue().getConfigAsMap(); + try { + Pipeline pipeline = Pipeline.create(pipelineId, + config, + ingestService.getProcessorFactories(), + ingestService.getScriptService()); + pipeline.getProcessors().stream() + .filter(p -> p instanceof InferenceProcessor) + .map(p -> (InferenceProcessor) p) + .map(InferenceProcessor::getModelId) + .forEach(allReferencedModelKeys::add); + } catch (Exception ex) { + LOGGER.warn(new ParameterizedMessage("failed to load pipeline [{}]", pipelineId), ex); + } + } + return allReferencedModelKeys; + } + + + @Override + protected ClusterBlockException checkBlock(DeleteTrainedModelAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java new file mode 100644 index 00000000000..15629579368 --- /dev/null +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; + +import java.util.Collections; +import java.util.Set; + + +public class TransportGetTrainedModelsAction extends HandledTransportAction { + + private final TrainedModelProvider provider; + @Inject + public TransportGetTrainedModelsAction(TransportService transportService, + ActionFilters actionFilters, + TrainedModelProvider trainedModelProvider) { + super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new); + this.provider = trainedModelProvider; + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + + Response.Builder responseBuilder = Response.builder(); + + ActionListener>> idExpansionListener = ActionListener.wrap( + totalAndIds -> { + 
responseBuilder.setTotalCount(totalAndIds.v1()); + + if (totalAndIds.v2().isEmpty()) { + listener.onResponse(responseBuilder.build()); + return; + } + + if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) { + listener.onFailure( + ExceptionsHelper.badRequestException(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED) + ); + return; + } + + if (request.isIncludeModelDefinition()) { + provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap( + config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()), + listener::onFailure + )); + } else { + provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap( + configs -> listener.onResponse(responseBuilder.setModels(configs).build()), + listener::onFailure + )); + } + }, + listener::onFailure + ); + + provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java new file mode 100644 index 00000000000..a15579b62de --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -0,0 +1,249 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.metrics.CounterMetric; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.IngestStats; +import org.elasticsearch.ingest.Pipeline; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + + +public class TransportGetTrainedModelsStatsAction extends HandledTransportAction { + + private final Client client; + private final ClusterService clusterService; + private final IngestService ingestService; + 
private final TrainedModelProvider trainedModelProvider; + + @Inject + public TransportGetTrainedModelsStatsAction(TransportService transportService, + ActionFilters actionFilters, + ClusterService clusterService, + IngestService ingestService, + TrainedModelProvider trainedModelProvider, + Client client) { + super(GetTrainedModelsStatsAction.NAME, transportService, actionFilters, GetTrainedModelsStatsAction.Request::new); + this.client = client; + this.clusterService = clusterService; + this.ingestService = ingestService; + this.trainedModelProvider = trainedModelProvider; + } + + @Override + protected void doExecute(Task task, + GetTrainedModelsStatsAction.Request request, + ActionListener listener) { + + GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder(); + + ActionListener nodesStatsListener = ActionListener.wrap( + nodesStatsResponse -> { + Map modelIdIngestStats = inferenceIngestStatsByPipelineId(nodesStatsResponse, + pipelineIdsByModelIds(clusterService.state(), + ingestService, + responseBuilder.getExpandedIds())); + listener.onResponse(responseBuilder.setIngestStatsByModelId(modelIdIngestStats).build()); + }, + listener::onFailure + ); + + ActionListener>> idsListener = ActionListener.wrap( + tuple -> { + responseBuilder.setExpandedIds(tuple.v2()) + .setTotalModelCount(tuple.v1()); + String[] ingestNodes = ingestNodes(clusterService.state()); + NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().ingest(true); + executeAsyncWithOrigin(client, ML_ORIGIN, NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener); + }, + listener::onFailure + ); + + trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener); + } + + static Map inferenceIngestStatsByPipelineId(NodesStatsResponse response, + Map> modelIdToPipelineId) { + + Map ingestStatsMap = new HashMap<>(); + + 
modelIdToPipelineId.forEach((modelId, pipelineIds) -> { + List collectedStats = response.getNodes() + .stream() + .map(nodeStats -> ingestStatsForPipelineIds(nodeStats, pipelineIds)) + .collect(Collectors.toList()); + ingestStatsMap.put(modelId, mergeStats(collectedStats)); + }); + + return ingestStatsMap; + } + + static String[] ingestNodes(final ClusterState clusterState) { + String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()]; + Iterator nodeIterator = clusterState.nodes().getIngestNodes().keysIt(); + int i = 0; + while(nodeIterator.hasNext()) { + ingestNodes[i++] = nodeIterator.next(); + } + return ingestNodes; + } + + static Map> pipelineIdsByModelIds(ClusterState state, IngestService ingestService, Set modelIds) { + IngestMetadata ingestMetadata = state.metaData().custom(IngestMetadata.TYPE); + Map> pipelineIdsByModelIds = new HashMap<>(); + if (ingestMetadata == null) { + return pipelineIdsByModelIds; + } + + ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> { + try { + Pipeline pipeline = Pipeline.create(pipelineId, + pipelineConfiguration.getConfigAsMap(), + ingestService.getProcessorFactories(), + ingestService.getScriptService()); + pipeline.getProcessors().forEach(processor -> { + if (processor instanceof InferenceProcessor) { + InferenceProcessor inferenceProcessor = (InferenceProcessor) processor; + if (modelIds.contains(inferenceProcessor.getModelId())) { + pipelineIdsByModelIds.computeIfAbsent(inferenceProcessor.getModelId(), + m -> new LinkedHashSet<>()).add(pipelineId); + } + } + }); + } catch (Exception ex) { + throw new ElasticsearchException("unexpected failure gathering pipeline information", ex); + } + }); + + return pipelineIdsByModelIds; + } + + static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set pipelineIds) { + IngestStats fullNodeStats = nodeStats.getIngestStats(); + Map> filteredProcessorStats = new HashMap<>(fullNodeStats.getProcessorStats()); + 
filteredProcessorStats.keySet().retainAll(pipelineIds); + List filteredPipelineStats = fullNodeStats.getPipelineStats() + .stream() + .filter(pipelineStat -> pipelineIds.contains(pipelineStat.getPipelineId())) + .collect(Collectors.toList()); + CounterMetric ingestCount = new CounterMetric(); + CounterMetric ingestTimeInMillis = new CounterMetric(); + CounterMetric ingestCurrent = new CounterMetric(); + CounterMetric ingestFailedCount = new CounterMetric(); + + filteredPipelineStats.forEach(pipelineStat -> { + IngestStats.Stats stats = pipelineStat.getStats(); + ingestCount.inc(stats.getIngestCount()); + ingestTimeInMillis.inc(stats.getIngestTimeInMillis()); + ingestCurrent.inc(stats.getIngestCurrent()); + ingestFailedCount.inc(stats.getIngestFailedCount()); + }); + + return new IngestStats( + new IngestStats.Stats(ingestCount.count(), ingestTimeInMillis.count(), ingestCurrent.count(), ingestFailedCount.count()), + filteredPipelineStats, + filteredProcessorStats); + } + + private static IngestStats mergeStats(List ingestStatsList) { + + Map pipelineStatsAcc = new LinkedHashMap<>(ingestStatsList.size()); + Map> processorStatsAcc = new LinkedHashMap<>(ingestStatsList.size()); + IngestStatsAccumulator totalStats = new IngestStatsAccumulator(); + ingestStatsList.forEach(ingestStats -> { + + ingestStats.getPipelineStats() + .forEach(pipelineStat -> + pipelineStatsAcc.computeIfAbsent(pipelineStat.getPipelineId(), + p -> new IngestStatsAccumulator()).inc(pipelineStat.getStats())); + + ingestStats.getProcessorStats() + .forEach((pipelineId, processorStat) -> { + Map processorAcc = processorStatsAcc.computeIfAbsent(pipelineId, + k -> new LinkedHashMap<>()); + processorStat.forEach(p -> + processorAcc.computeIfAbsent(p.getName(), + k -> new IngestStatsAccumulator(p.getType())).inc(p.getStats())); + }); + + totalStats.inc(ingestStats.getTotalStats()); + }); + + List pipelineStatList = new ArrayList<>(pipelineStatsAcc.size()); + pipelineStatsAcc.forEach((pipelineId, 
accumulator) -> + pipelineStatList.add(new IngestStats.PipelineStat(pipelineId, accumulator.build()))); + + Map> processorStatList = new LinkedHashMap<>(processorStatsAcc.size()); + processorStatsAcc.forEach((pipelineId, accumulatorMap) -> { + List processorStats = new ArrayList<>(accumulatorMap.size()); + accumulatorMap.forEach((processorName, acc) -> + processorStats.add(new IngestStats.ProcessorStat(processorName, acc.type, acc.build()))); + processorStatList.put(pipelineId, processorStats); + }); + + return new IngestStats(totalStats.build(), pipelineStatList, processorStatList); + } + + private static class IngestStatsAccumulator { + CounterMetric ingestCount = new CounterMetric(); + CounterMetric ingestTimeInMillis = new CounterMetric(); + CounterMetric ingestCurrent = new CounterMetric(); + CounterMetric ingestFailedCount = new CounterMetric(); + + String type; + + IngestStatsAccumulator() {} + + IngestStatsAccumulator(String type) { + this.type = type; + } + + IngestStatsAccumulator inc(IngestStats.Stats s) { + ingestCount.inc(s.getIngestCount()); + ingestTimeInMillis.inc(s.getIngestTimeInMillis()); + ingestCurrent.inc(s.getIngestCurrent()); + ingestFailedCount.inc(s.getIngestFailedCount()); + return this; + } + + IngestStats.Stats build() { + return new IngestStats.Stats(ingestCount.count(), ingestTimeInMillis.count(), ingestCurrent.count(), ingestFailedCount.count()); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java new file mode 100644 index 00000000000..4edd214094f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.ml.inference.loadingservice.Model; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; + + +public class TransportInferModelAction extends HandledTransportAction { + + private final ModelLoadingService modelLoadingService; + private final Client client; + private final XPackLicenseState licenseState; + + @Inject + public TransportInferModelAction(TransportService transportService, + ActionFilters actionFilters, + ModelLoadingService modelLoadingService, + Client client, + XPackLicenseState licenseState) { + super(InferModelAction.NAME, transportService, actionFilters, InferModelAction.Request::new); + this.modelLoadingService = modelLoadingService; + this.client = client; + this.licenseState = licenseState; + } + + @Override + protected void doExecute(Task task, InferModelAction.Request request, ActionListener listener) { + + if (licenseState.isMachineLearningAllowed() == false) { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + return; + } + + ActionListener 
getModelListener = ActionListener.wrap( + model -> { + TypedChainTaskExecutor typedChainTaskExecutor = + new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), + // run through all tasks + r -> true, + // Always fail immediately and return an error + ex -> true); + request.getObjectsToInfer().forEach(stringObjectMap -> + typedChainTaskExecutor.add(chainedTask -> + model.infer(stringObjectMap, request.getConfig(), chainedTask))); + + typedChainTaskExecutor.execute(ActionListener.wrap( + inferenceResultsInterfaces -> + listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces)), + listener::onFailure + )); + }, + listener::onFailure + ); + + this.modelLoadingService.getModel(request.getModelId(), getModelListener); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index eb47c17137a..3abc3b5e43c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -150,6 +150,8 @@ public class AnalyticsResultProcessor { .setMetadata(Collections.singletonMap("analytics_config", XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) .setDefinition(definition) + .setEstimatedHeapMemory(definition.ramBytesUsed()) + .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations()) .setInput(new TrainedModelInput(fieldNames)) .build(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index 0127ff26f3c..b1c7bf6599a 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -7,14 +7,17 @@ package org.elasticsearch.xpack.ml.dataframe.process.results; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import java.io.IOException; +import java.util.Collections; import java.util.Objects; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; public class AnalyticsResult implements ToXContentObject { @@ -67,7 +70,9 @@ public class AnalyticsResult implements ToXContentObject { builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent); } if (inferenceModel != null) { - builder.field(INFERENCE_MODEL.getPreferredName(), inferenceModel); + builder.field(INFERENCE_MODEL.getPreferredName(), + inferenceModel, + new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"))); } builder.endObject(); return builder; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java new file mode 100644 index 00000000000..fe9f942dff4 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -0,0 +1,297 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.ingest; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.ingest.AbstractProcessor; +import org.elasticsearch.ingest.ConfigurationUtils; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.Pipeline; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.ingest.Processor; +import org.elasticsearch.license.LicenseStateListener; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import 
java.util.Map;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+
+/**
+ * Ingest processor that sends each document's fields to {@link InferModelAction} for the configured model
+ * and writes the first inference result back into the document under the target field.
+ */
+public class InferenceProcessor extends AbstractProcessor {
+
+    // How many total inference processors are allowed to be used in the cluster.
+    public static final Setting<Integer> MAX_INFERENCE_PROCESSORS = Setting.intSetting("xpack.ml.max_inference_processors",
+        50,
+        1,
+        Setting.Property.Dynamic,
+        Setting.Property.NodeScope);
+
+    public static final String TYPE = "inference";
+    public static final String MODEL_ID = "model_id";
+    public static final String INFERENCE_CONFIG = "inference_config";
+    public static final String TARGET_FIELD = "target_field";
+    public static final String FIELD_MAPPINGS = "field_mappings";
+    public static final String MODEL_INFO_FIELD = "model_info_field";
+    public static final String INCLUDE_MODEL_METADATA = "include_model_metadata";
+
+    private final Client client;
+    private final String modelId;
+
+    private final String targetField;
+    private final String modelInfoField;
+    private final InferenceConfig inferenceConfig;
+    // document field name -> field name passed to the model
+    private final Map<String, String> fieldMapping;
+    private final boolean includeModelMetadata;
+
+    /**
+     * Every reference argument is required; a null value is rejected by {@code ExceptionsHelper.requireNonNull}.
+     *
+     * @param client used to execute the inference action
+     * @param tag the processor tag, passed to the parent class
+     * @param targetField the document field the inference result is written to
+     * @param modelId the id of the model to infer against
+     * @param inferenceConfig the inference settings sent with every request
+     * @param fieldMapping renames document fields before they are sent to the model (may be empty)
+     * @param modelInfoField the document field under which the model id is recorded
+     * @param includeModelMetadata whether the model id is written to the document at all
+     */
+    public InferenceProcessor(Client client,
+                              String tag,
+                              String targetField,
+                              String modelId,
+                              InferenceConfig inferenceConfig,
+                              Map<String, String> fieldMapping,
+                              String modelInfoField,
+                              boolean includeModelMetadata) {
+        super(tag);
+        this.client = ExceptionsHelper.requireNonNull(client, "client");
+        this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
+        this.modelInfoField = ExceptionsHelper.requireNonNull(modelInfoField, MODEL_INFO_FIELD);
+        this.includeModelMetadata = includeModelMetadata;
+        this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
+        this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
+        this.fieldMapping = ExceptionsHelper.requireNonNull(fieldMapping, FIELD_MAPPINGS);
+    }
+
+    public String getModelId() {
+        return modelId;
+    }
+
+    /**
+     * Runs inference asynchronously under the ML origin and reports the outcome through {@code handler}:
+     * the document with a null exception on success, or the document with the failure otherwise.
+     */
+    @Override
+    public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
+        executeAsyncWithOrigin(client,
+            ML_ORIGIN,
+            InferModelAction.INSTANCE,
+            this.buildRequest(ingestDocument),
+            ActionListener.wrap(
+                r -> {
+                    try {
+                        mutateDocument(r, ingestDocument);
+                        handler.accept(ingestDocument, null);
+                    } catch(ElasticsearchException ex) {
+                        handler.accept(ingestDocument, ex);
+                    }
+                },
+                e -> handler.accept(ingestDocument, e)
+            ));
+    }
+
+    /**
+     * Builds the inference request from a copy of the document's source and metadata.
+     * For every mapping entry whose source field holds a non-null value, the value is moved to the mapped name.
+     */
+    InferModelAction.Request buildRequest(IngestDocument ingestDocument) {
+        Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
+        // fieldMapping is guaranteed non-null by the constructor
+        fieldMapping.forEach((src, dest) -> {
+            Object srcValue = fields.remove(src);
+            if (srcValue != null) {
+                fields.put(dest, srcValue);
+            }
+        });
+        return new InferModelAction.Request(modelId, fields, inferenceConfig);
+    }
+
+    /**
+     * Writes the first inference result to the target field and, if enabled, the model id under the model info field.
+     *
+     * @throws ElasticsearchStatusException if the response contains no inference results
+     */
+    void mutateDocument(InferModelAction.Response response, IngestDocument ingestDocument) {
+        if (response.getInferenceResults().isEmpty()) {
+            throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
+        }
+        response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
+        if (includeModelMetadata) {
+            ingestDocument.setFieldValue(modelInfoField + "." + MODEL_ID, modelId);
+        }
+    }
+
+    /**
+     * The synchronous variant is not supported; only the asynchronous {@code execute} overload is implemented.
+     */
+    @Override
+    public IngestDocument execute(IngestDocument ingestDocument) {
+        throw new UnsupportedOperationException("should never be called");
+    }
+
+    @Override
+    public String getType() {
+        return TYPE;
+    }
+
+    /**
+     * Creates {@link InferenceProcessor} instances from pipeline configuration.
+     *
+     * It also consumes cluster states to track the minimum node version and how many inference processors the
+     * stored pipelines contain, and refreshes its license flag when the license state changes.
+     */
+    public static final class Factory implements Processor.Factory, Consumer<ClusterState>, LicenseStateListener {
+
+        private static final Logger logger = LogManager.getLogger(Factory.class);
+
+        private final Client client;
+        private final IngestService ingestService;
+        private final XPackLicenseState licenseState;
+        private volatile int currentInferenceProcessors;
+        private volatile int maxIngestProcessors;
+        private volatile Version minNodeVersion = Version.CURRENT;
+        private volatile boolean inferenceAllowed;
+
+        public Factory(Client client,
+                       ClusterService clusterService,
+                       Settings settings,
+                       IngestService ingestService,
+                       XPackLicenseState licenseState) {
+            this.client = client;
+            this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings);
+            this.ingestService = ingestService;
+            this.licenseState = licenseState;
+            this.inferenceAllowed = licenseState.isMachineLearningAllowed();
+            // keep the limit in sync with dynamic updates of the setting
+            clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors);
+        }
+
+        /**
+         * Records the minimum node version and recounts the inference processors across all stored pipelines.
+         * A pipeline whose configuration fails to parse is logged and contributes nothing to the count.
+         */
+        @Override
+        public void accept(ClusterState state) {
+            minNodeVersion = state.nodes().getMinNodeVersion();
+            MetaData metaData = state.getMetaData();
+            if (metaData == null) {
+                currentInferenceProcessors = 0;
+                return;
+            }
+            IngestMetadata ingestMetadata = metaData.custom(IngestMetadata.TYPE);
+            if (ingestMetadata == null) {
+                currentInferenceProcessors = 0;
+                return;
+            }
+
+            int count = 0;
+            for (PipelineConfiguration configuration : ingestMetadata.getPipelines().values()) {
+                try {
+                    Pipeline pipeline = Pipeline.create(configuration.getId(),
+                        configuration.getConfigAsMap(),
+                        ingestService.getProcessorFactories(),
+                        ingestService.getScriptService());
+                    count += pipeline.getProcessors().stream().filter(processor -> processor instanceof InferenceProcessor).count();
+                } catch (Exception ex) {
+                    logger.warn(new ParameterizedMessage("failure parsing pipeline config [{}]", configuration.getId()), ex);
+                }
+            }
+            currentInferenceProcessors = count;
+        }
+
+        // Used for testing
+        int numInferenceProcessors() {
+            return currentInferenceProcessors;
+        }
+
+        /**
+         * Builds a processor from its configuration map.
+         *
+         * Fails if machine learning is not allowed by the license or if the number of inference processors
+         * already counted has reached {@link #MAX_INFERENCE_PROCESSORS}.
+         */
+        @Override
+        public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config)
+            throws Exception {
+
+            if (inferenceAllowed == false) {
+                throw LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING);
+            }
+
+            if (this.maxIngestProcessors <= currentInferenceProcessors) {
+                throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " +
+                    "Adjust the setting [{}]: [{}] if a greater number is desired.",
+                    RestStatus.CONFLICT,
+                    currentInferenceProcessors,
+                    MAX_INFERENCE_PROCESSORS.getKey(),
+                    maxIngestProcessors);
+            }
+
+            boolean includeModelMetadata = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, INCLUDE_MODEL_METADATA, true);
+            String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID);
+            String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD);
+            Map<String, String> fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
+            if (fieldMapping == null) {
+                // The option is read as optional, but the processor constructor rejects null:
+                // treat an absent mapping as "rename nothing"
+                fieldMapping = new HashMap<>();
+            }
+            InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));
+            String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "ml");
+            // If multiple inference processors are in the same pipeline, it is wise to tag them
+            // The tag will keep metadata entries from stepping on each other
+            if (tag != null) {
+                modelInfoField += "." + tag;
+            }
+            return new InferenceProcessor(client,
+                tag,
+                targetField,
+                modelId,
+                inferenceConfig,
+                fieldMapping,
+                modelInfoField,
+                includeModelMetadata);
+        }
+
+        // Package private for testing
+        void setMaxIngestProcessors(int maxIngestProcessors) {
+            logger.trace("updating setting maxIngestProcessors from [{}] to [{}]", this.maxIngestProcessors, maxIngestProcessors);
+            this.maxIngestProcessors = maxIngestProcessors;
+        }
+
+        /**
+         * Parses the {@code inference_config} object, which must hold exactly one entry: the inference type
+         * (classification or regression) mapped to an object of settings.
+         */
+        InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) throws IOException {
+            ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
+
+            if (inferenceConfig.size() != 1) {
+                throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
+                    INFERENCE_CONFIG);
+            }
+            Object value = inferenceConfig.values().iterator().next();
+
+            if ((value instanceof Map<?, ?>) == false) {
+                throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
+                    INFERENCE_CONFIG);
+            }
+            @SuppressWarnings("unchecked")
+            Map<String, Object> valueMap = (Map<String, Object>)value;
+
+            if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
+                checkSupportedVersion(new ClassificationConfig(0));
+                return ClassificationConfig.fromMap(valueMap);
+            } else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
+                checkSupportedVersion(new RegressionConfig());
+                return RegressionConfig.fromMap(valueMap);
+            } else {
+                throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
+                    inferenceConfig.keySet(),
+                    Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME));
+            }
+        }
+
+        /**
+         * Rejects a config type whose minimal supported version is after the minimum node version last seen.
+         */
+        void checkSupportedVersion(InferenceConfig config) {
+            if (config.getMinimalSupportedVersion().after(minNodeVersion)) {
+                throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION,
+                    config.getName(),
+                    config.getMinimalSupportedVersion(),
+                    minNodeVersion));
+            }
+        }
+
+        @Override
+        public void licenseStateChanged() {
+            this.inferenceAllowed = licenseState.isMachineLearningAllowed();
+        }
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java
new file mode 100644
index 00000000000..403f10dd7d8
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.inference.loadingservice;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+
+import java.util.Map;
+
+/**
+ * A {@link Model} that runs inference directly against a {@link TrainedModelDefinition} it holds.
+ */
+public class LocalModel implements Model {
+
+    private final TrainedModelDefinition trainedModelDefinition;
+    private final String modelId;
+
+    public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) {
+        this.trainedModelDefinition = trainedModelDefinition;
+        this.modelId = modelId;
+    }
+
+    /**
+     * @return the size reported by the wrapped model definition
+     */
+    long ramBytesUsed() {
+        return trainedModelDefinition.ramBytesUsed();
+    }
+
+    @Override
+    public String getModelId() {
+        return modelId;
+    }
+
+    /**
+     * Maps the trained model's target type to the name of the matching results type.
+     * Any target type other than classification or regression is rejected as a bad request.
+     */
+    @Override
+    public String getResultsType() {
+        switch (trainedModelDefinition.getTrainedModel().targetType()) {
+            case CLASSIFICATION:
+                return ClassificationInferenceResults.NAME;
+            case REGRESSION:
+                return RegressionInferenceResults.NAME;
+            default:
+                throw ExceptionsHelper.badRequestException("Model [{}] has unsupported target type [{}]",
+                    modelId,
+                    trainedModelDefinition.getTrainedModel().targetType());
+        }
+    }
+
+    /**
+     * Runs inference on the calling thread and hands the result to the listener.
+     * Any exception thrown while inferring (or by the listener's response handling) is passed to
+     * {@code listener.onFailure} instead of propagating to the caller.
+     */
+    @Override
+    public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
+        try {
+            listener.onResponse(trainedModelDefinition.infer(fields, config));
+        } catch (Exception e) {
+            listener.onFailure(e);
+        }
+    }
+
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java
new file mode 100644
index 00000000000..fb32ce7f646
---
/dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java
@@ -0,0 +1,21 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.inference.loadingservice;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+
+import java.util.Map;
+
+/**
+ * A model that can run inference on a map of fields and report the result asynchronously.
+ */
+public interface Model {
+
+    /**
+     * @return the name of the results type this model produces
+     */
+    String getResultsType();
+
+    /**
+     * Runs inference and notifies the listener with the result or the failure.
+     *
+     * @param fields the field values to infer against, keyed by field name
+     * @param inferenceConfig the settings for this inference call
+     * @param listener notified when inference completes
+     */
+    void infer(Map<String, Object> fields, InferenceConfig inferenceConfig, ActionListener<InferenceResults> listener);
+
+    /**
+     * @return the id of this model
+     */
+    String getModelId();
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java
new file mode 100644
index 00000000000..cee7fbb1807
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java
@@ -0,0 +1,370 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.inference.loadingservice;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.ClusterChangedEvent;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ClusterStateListener;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.cache.Cache;
+import org.elasticsearch.common.cache.CacheBuilder;
+import org.elasticsearch.common.cache.RemovalNotification;
+import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeUnit;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.ingest.IngestMetadata;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * This is a thread safe model loading service.
+ *
+ * It will cache local models that are referenced by processors in memory (as long as it is instantiated on an ingest node).
+ *
+ * If more than one processor references the same model, that model will only be cached once.
+ */
+public class ModelLoadingService implements ClusterStateListener {
+
+    /**
+     * The maximum size of the local model cache here in the loading service
+     *
+     * Once the limit is reached, LRU models are evicted in favor of new models
+     */
+    public static final Setting<ByteSizeValue> INFERENCE_MODEL_CACHE_SIZE =
+        Setting.byteSizeSetting("xpack.ml.inference_model.cache_size",
+            new ByteSizeValue(1, ByteSizeUnit.GB),
+            Setting.Property.NodeScope);
+
+    /**
+     * How long should a model stay in the cache since its last access
+     *
+     * If nothing references a model via getModel for this configured timeValue, it will be evicted.
+     *
+     * Specifically, in the ingest scenario, a processor will call getModel whenever it needs to run inference. So, if a processor is not
+     * executed for an extended period of time, the model will be evicted and will have to be loaded again when getModel is called.
+     *
+     */
+    public static final Setting<TimeValue> INFERENCE_MODEL_CACHE_TTL =
+        Setting.timeSetting("xpack.ml.inference_model.time_to_live",
+            new TimeValue(5, TimeUnit.MINUTES),
+            new TimeValue(1, TimeUnit.MILLISECONDS),
+            Setting.Property.NodeScope);
+
+    private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
+    private final Cache<String, LocalModel> localModelCache;
+    // Ids of the models referenced by inference processors in the current ingest metadata
+    private final Set<String> referencedModels = new HashSet<>();
+    // Model id -> listeners waiting for an in-flight load; also used as the lock guarding the loading state
+    private final Map<String, Queue<ActionListener<Model>>> loadingListeners = new HashMap<>();
+    private final TrainedModelProvider provider;
+    // Models whose eviction has already been audited, so repeated evictions are only trace-logged
+    private final Set<String> shouldNotAudit;
+    private final ThreadPool threadPool;
+    private final InferenceAuditor auditor;
+    private final ByteSizeValue maxCacheSize;
+
+    public ModelLoadingService(TrainedModelProvider trainedModelProvider,
+                               InferenceAuditor auditor,
+                               ThreadPool threadPool,
+                               ClusterService clusterService,
+                               Settings settings) {
+        this.provider = trainedModelProvider;
+        this.threadPool = threadPool;
+        this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings);
+        this.auditor = auditor;
+        this.shouldNotAudit = new HashSet<>();
+        this.localModelCache = CacheBuilder.<String, LocalModel>builder()
+            .setMaximumWeight(this.maxCacheSize.getBytes())
+            .weigher((id, localModel) -> localModel.ramBytesUsed())
+            .removalListener(this::cacheEvictionListener)
+            .setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
+            .build();
+        clusterService.addListener(this);
+    }
+
+    /**
+     * Gets the model referenced by `modelId` and responds to the listener.
+     *
+     * This method first checks the local LRU cache for the model. If it is present, it is returned from cache.
+     *
+     * If it is not present, one of the following occurs:
+     *
+     * - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener`
+     *   is added to the list of listeners to be alerted when the model is fully loaded.
+     * - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the resulting
+     *   model will attempt to be cached for future reference
+     * - If the model is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener.
+     *   It is not cached.
+     *
+     * @param modelId the model to get
+     * @param modelActionListener the listener to alert when the model has been retrieved.
+     */
+    public void getModel(String modelId, ActionListener<Model> modelActionListener) {
+        LocalModel cachedModel = localModelCache.get(modelId);
+        if (cachedModel != null) {
+            modelActionListener.onResponse(cachedModel);
+            logger.trace("[{}] loaded from cache", modelId);
+            return;
+        }
+        if (loadModelIfNecessary(modelId, modelActionListener) == false) {
+            // If the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
+            // by a simulated pipeline
+            logger.trace("[{}] not actively loading, eager loading without cache", modelId);
+            provider.getTrainedModel(modelId, true, ActionListener.wrap(
+                trainedModelConfig ->
+                    modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition())),
+                modelActionListener::onFailure
+            ));
+        } else {
+            logger.trace("[{}] is loading or loaded, added new listener to queue", modelId);
+        }
+    }
+
+    /**
+     * Returns true if the model is loaded and the listener has been given the cached model
+     * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded
+     * Returns false if the model is not loaded or actively being loaded
+     */
+    private boolean loadModelIfNecessary(String modelId, ActionListener<Model> modelActionListener) {
+        synchronized (loadingListeners) {
+            Model cachedModel = localModelCache.get(modelId);
+            if (cachedModel != null) {
+                modelActionListener.onResponse(cachedModel);
+                return true;
+            }
+            // It is referenced by a pipeline, but the cache does not contain it
+            if (referencedModels.contains(modelId)) {
+                // If the loaded model is referenced there but is not present,
+                // that means the previous load attempt failed or the model has been evicted
+                // Attempt to load and cache the model if necessary
+                if (loadingListeners.computeIfPresent(
+                    modelId,
+                    (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) {
+                    logger.trace("[{}] attempting to load and cache", modelId);
+                    loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
+                    loadModel(modelId);
+                }
+                return true;
+            }
+            // if the cachedModel entry is null, but there are listeners present, that means it is being loaded
+            return loadingListeners.computeIfPresent(modelId,
+                (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null;
+        } // synchronized (loadingListeners)
+    }
+
+    /**
+     * Asks the provider for the model and routes the outcome to the success or failure handler.
+     */
+    private void loadModel(String modelId) {
+        provider.getTrainedModel(modelId, true, ActionListener.wrap(
+            trainedModelConfig -> {
+                logger.debug("[{}] successfully loaded model", modelId);
+                handleLoadSuccess(modelId, trainedModelConfig);
+            },
+            failure -> {
+                logger.warn(new ParameterizedMessage("[{}] failed to load model", modelId), failure);
+                handleLoadFailure(modelId, failure);
+            }
+        ));
+    }
+
+    /**
+     * Caches the loaded model and hands it to every listener that was waiting for it.
+     * The listeners are notified after the lock is released.
+     */
+    private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) {
+        Queue<ActionListener<Model>> listeners;
+        LocalModel loadedModel = new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition());
+        synchronized (loadingListeners) {
+            listeners = loadingListeners.remove(modelId);
+            // If there is no loadingListener that means the loading was canceled and the listener was already notified as such
+            // Consequently, we should not store the retrieved model
+            if (listeners == null) {
+                return;
+            }
+            localModelCache.put(modelId, loadedModel);
+            shouldNotAudit.remove(modelId);
+        } // synchronized (loadingListeners)
+        for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
+            listener.onResponse(loadedModel);
+        }
+    }
+
+    /**
+     * Removes the pending listeners for the model and, if there were any, notifies each of them of the failure.
+     */
+    private void handleLoadFailure(String modelId, Exception failure) {
+        Queue<ActionListener<Model>> listeners;
+        synchronized (loadingListeners) {
+            listeners = loadingListeners.remove(modelId);
+            if (listeners == null) {
+                return;
+            }
+        } // synchronized (loadingListeners)
+        // If we failed to load and there were listeners present, that means that this model is referenced by a processor
+        // Alert the listeners to the failure
+        for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
+            listener.onFailure(failure);
+        }
+    }
+
+    /**
+     * Audits cache removals whose reason is EVICTED; removals for any other reason are ignored.
+     */
+    private void cacheEvictionListener(RemovalNotification<String, LocalModel> notification) {
+        if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
+            String msg = new ParameterizedMessage(
+                "model cache entry evicted. " +
+                    "current cache [{}] current max [{}] model size [{}]. " +
+                    "If this is undesired, consider updating setting [{}] or [{}].",
+                new ByteSizeValue(localModelCache.weight()).getStringRep(),
+                maxCacheSize.getStringRep(),
+                new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(),
+                INFERENCE_MODEL_CACHE_SIZE.getKey(),
+                INFERENCE_MODEL_CACHE_TTL.getKey()).getFormattedMessage();
+            auditIfNecessary(notification.getKey(), msg);
+        }
+    }
+
+    /**
+     * Reconciles the loading state with the models referenced by the new ingest metadata:
+     * loads of models that are no longer referenced are cancelled and their listeners failed,
+     * unreferenced models are dropped from the cache, and newly referenced models are scheduled for loading.
+     */
+    @Override
+    public void clusterChanged(ClusterChangedEvent event) {
+        // If ingest data has not changed or if the current node is not an ingest node, don't bother caching models
+        if (event.changedCustomMetaDataSet().contains(IngestMetadata.TYPE) == false ||
+            event.state().nodes().getLocalNode().isIngestNode() == false) {
+            return;
+        }
+
+        ClusterState state = event.state();
+        IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE);
+        Set<String> allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata);
+        if (allReferencedModelKeys.equals(referencedModels)) {
+            return;
+        }
+        // The listeners still waiting for a model and we are canceling the load?
+        List<Tuple<String, List<ActionListener<Model>>>> drainWithFailure = new ArrayList<>();
+        Set<String> referencedModelsBeforeClusterState = null;
+        Set<String> loadingModelBeforeClusterState = null;
+        Set<String> removedModels = null;
+        synchronized (loadingListeners) {
+            referencedModelsBeforeClusterState = new HashSet<>(referencedModels);
+            if (logger.isTraceEnabled()) {
+                loadingModelBeforeClusterState = new HashSet<>(loadingListeners.keySet());
+            }
+            // If we had models still loading here but are no longer referenced
+            // we should remove them from loadingListeners and alert the listeners
+            // Iterate over a snapshot of the keys: removing entries while iterating the live keySet
+            // would throw a ConcurrentModificationException
+            for (String modelId : new ArrayList<>(loadingListeners.keySet())) {
+                if (allReferencedModelKeys.contains(modelId) == false) {
+                    drainWithFailure.add(Tuple.tuple(modelId, new ArrayList<>(loadingListeners.remove(modelId))));
+                }
+            }
+            removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);
+
+            // Remove all cached models that are not referenced by any processors
+            removedModels.forEach(localModelCache::invalidate);
+            // Remove the models that are no longer referenced
+            referencedModels.removeAll(removedModels);
+            shouldNotAudit.removeAll(removedModels);
+
+            // Remove all that are still referenced, i.e. the intersection of allReferencedModelKeys and referencedModels
+            allReferencedModelKeys.removeAll(referencedModels);
+            referencedModels.addAll(allReferencedModelKeys);
+
+            // Populate loadingListeners key so we know that we are currently loading the model
+            for (String modelId : allReferencedModelKeys) {
+                loadingListeners.put(modelId, new ArrayDeque<>());
+            }
+        } // synchronized (loadingListeners)
+        if (logger.isTraceEnabled()) {
+            if (loadingListeners.keySet().equals(loadingModelBeforeClusterState) == false) {
+                logger.trace("cluster state event changed loading models: before {} after {}", loadingModelBeforeClusterState,
+                    loadingListeners.keySet());
+            }
+            if (referencedModels.equals(referencedModelsBeforeClusterState) == false) {
+                logger.trace("cluster state event changed referenced models: before {} after {}", referencedModelsBeforeClusterState,
+                    referencedModels);
+            }
+        }
+        for (Tuple<String, List<ActionListener<Model>>> modelAndListeners : drainWithFailure) {
+            // getFormattedMessage() substitutes the model id; getFormat() would return the raw pattern
+            final String msg = new ParameterizedMessage(
+                "Cancelling load of model [{}] as it is no longer referenced by a pipeline",
+                modelAndListeners.v1()).getFormattedMessage();
+            for (ActionListener<Model> listener : modelAndListeners.v2()) {
+                listener.onFailure(new ElasticsearchException(msg));
+            }
+        }
+        removedModels.forEach(this::auditUnreferencedModel);
+        loadModels(allReferencedModelKeys);
+    }
+
+    /**
+     * Writes a warning audit for the model the first time it is called for that model id;
+     * later calls only trace-log until the id is removed from shouldNotAudit.
+     */
+    // NOTE(review): shouldNotAudit is a plain HashSet that is mutated here without holding the loadingListeners lock
+    // used for its other mutations - confirm that the cache's removal callbacks cannot race with those
+    private void auditIfNecessary(String modelId, String msg) {
+        if (shouldNotAudit.contains(modelId)) {
+            logger.trace("[{}] {}", modelId, msg);
+            return;
+        }
+        auditor.warning(modelId, msg);
+        shouldNotAudit.add(modelId);
+        logger.warn("[{}] {}", modelId, msg);
+    }
+
+    /**
+     * Audits and starts loading each of the given models on the ML utility thread pool.
+     */
+    private void loadModels(Set<String> modelIds) {
+        if (modelIds.isEmpty()) {
+            return;
+        }
+        // Execute this on a utility thread as when the callbacks occur we don't want them tying up the cluster listener thread pool
+        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
+            for (String modelId : modelIds) {
+                auditNewReferencedModel(modelId);
+                this.loadModel(modelId);
+            }
+        });
+    }
+
+    private void auditNewReferencedModel(String modelId) {
+        auditor.info(modelId, "referenced by ingest processors. Attempting to load model into cache");
+    }
+
+    private void auditUnreferencedModel(String modelId) {
+        auditor.info(modelId, "no longer referenced by any processors");
+    }
+
+    /**
+     * Adds the object to the queue and returns the same queue, so the call can be used as a map remapping result.
+     */
+    private static <T> Queue<T> addFluently(Queue<T> queue, T object) {
+        queue.add(object);
+        return queue;
+    }
+
+    /**
+     * Collects the model ids configured on inference processors that appear directly in the
+     * top-level "processors" list of each pipeline configuration.
+     *
+     * @return a mutable set of model ids; empty if the ingest metadata is null
+     */
+    private static Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata) {
+        Set<String> allReferencedModelKeys = new HashSet<>();
+        if (ingestMetadata == null) {
+            return allReferencedModelKeys;
+        }
+        ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
+            Object processors = pipelineConfiguration.getConfigAsMap().get("processors");
+            if (processors instanceof List<?>) {
+                for(Object processor : (List<?>)processors) {
+                    if (processor instanceof Map<?, ?>) {
+                        Object processorConfig = ((Map<?, ?>)processor).get(InferenceProcessor.TYPE);
+                        if (processorConfig instanceof Map<?, ?>) {
+                            Object modelId = ((Map<?, ?>)processorConfig).get(InferenceProcessor.MODEL_ID);
+                            if (modelId != null) {
+                                assert modelId instanceof String;
+                                allReferencedModelKeys.add(modelId.toString());
+                            }
+                        }
+                    }
+                }
+            }
+        });
+        return allReferencedModelKeys;
+    }
+
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java
index 19d5d33abe4..aa80807aae8 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java
@@ -24,6 +24,7 @@ import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappi
 import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DYNAMIC;
 import static
org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.ENABLED; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.KEYWORD; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.LONG; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.PROPERTIES; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TEXT; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TYPE; @@ -103,6 +104,12 @@ public final class InferenceInternalIndex { .endObject() .startObject(TrainedModelConfig.METADATA.getPreferredName()) .field(ENABLED, false) + .endObject() + .startObject(TrainedModelConfig.ESTIMATED_OPERATIONS.getPreferredName()) + .field(TYPE, LONG) + .endObject() + .startObject(TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName()) + .field(TYPE, LONG) .endObject(); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 6307b0d3ccc..16c911a0689 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -20,10 +20,19 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.MultiSearchAction; import org.elasticsearch.action.search.MultiSearchRequestBuilder; import org.elasticsearch.action.search.MultiSearchResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; import 
org.elasticsearch.client.Client; import org.elasticsearch.common.CheckedBiFunction; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; @@ -32,12 +41,21 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.reindex.DeleteByQueryAction; +import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; +import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; @@ -47,10 +65,17 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; import 
java.io.InputStream; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_FAILED_TO_DESERIALIZE; public class TrainedModelProvider { @@ -189,6 +214,166 @@ public class TrainedModelProvider { multiSearchResponseActionListener); } + /** + * Gets all the provided trained config model objects + * + * NOTE: + * This does no expansion on the ids. + * It assumes that there are fewer than 10k. + */ + public void getTrainedModels(Set modelIds, boolean allowNoResources, final ActionListener> listener) { + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0]))); + + SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .addSort(TrainedModelConfig.MODEL_ID.getPreferredName(), SortOrder.ASC) + .addSort("_index", SortOrder.DESC) + .setQuery(queryBuilder) + .request(); + + ActionListener configSearchHandler = ActionListener.wrap( + searchResponse -> { + Set observedIds = new HashSet<>(searchResponse.getHits().getHits().length, 1.0f); + List configs = new ArrayList<>(searchResponse.getHits().getHits().length); + for(SearchHit searchHit : searchResponse.getHits().getHits()) { + try { + if (observedIds.contains(searchHit.getId()) == false) { + configs.add( + parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build() + ); + observedIds.add(searchHit.getId()); + } + } catch (IOException ex) { + listener.onFailure( + ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId())); + return; + } + } + // We previously expanded the IDs. 
+ // If the config has gone missing between then and now we should throw if allowNoResources is false + // Otherwise, treat it as if it was never expanded to begin with. + Set missingConfigs = Sets.difference(modelIds, observedIds); + if (missingConfigs.isEmpty() == false && allowNoResources == false) { + listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); + return; + } + listener.onResponse(configs); + }, + listener::onFailure + ); + + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler); + } + + public void deleteTrainedModel(String modelId, ActionListener listener) { + DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); + + request.indices(InferenceIndexConstants.INDEX_PATTERN); + QueryBuilder query = QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); + request.setQuery(query); + request.setRefresh(true); + + executeAsyncWithOrigin(client, ML_ORIGIN, DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(deleteResponse -> { + if (deleteResponse.getDeleted() == 0) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + return; + } + listener.onResponse(true); + }, e -> { + if (e.getClass() == IndexNotFoundException.class) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + } else { + listener.onFailure(e); + } + })); + } + + public void expandIds(String idExpression, + boolean allowNoResources, + @Nullable PageParams pageParams, + ActionListener>> idsListener) { + String[] tokens = Strings.tokenizeToStringArray(idExpression, ","); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName()) + // If there are no resources, there might be no mapping for the id field. 
+ // This makes sure we don't get an error if that happens. + .unmappedType("long")) + .query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName())); + if (pageParams != null) { + sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize()); + } + sourceBuilder.trackTotalHits(true) + // we only care about the item id's + .fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null); + + IndicesOptions indicesOptions = SearchRequest.DEFAULT_INDICES_OPTIONS; + SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN) + .indicesOptions(IndicesOptions.fromOptions(true, + indicesOptions.allowNoIndices(), + indicesOptions.expandWildcardsOpen(), + indicesOptions.expandWildcardsClosed(), + indicesOptions)) + .source(sourceBuilder); + + executeAsyncWithOrigin(client.threadPool().getThreadContext(), + ML_ORIGIN, + searchRequest, + ActionListener.wrap( + response -> { + Set foundResourceIds = new LinkedHashSet<>(); + long totalHitCount = response.getHits().getTotalHits().value; + for (SearchHit hit : response.getHits().getHits()) { + Map docSource = hit.getSourceAsMap(); + if (docSource == null) { + continue; + } + Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (idValue instanceof String) { + foundResourceIds.add(idValue.toString()); + } + } + ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources); + requiredMatches.filterMatchedIds(foundResourceIds); + if (requiredMatches.hasUnmatchedIds()) { + idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); + } else { + idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds)); + } + }, + idsListener::onFailure + ), + client::search); + + } + + private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() + 
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME)); + + if (Strings.isAllOrWildcard(tokens)) { + return boolQuery; + } + // If the resourceId is not _all or *, we should see if it is a comma delimited string with wild-cards + // e.g. id1,id2*,id3 + BoolQueryBuilder shouldQueries = new BoolQueryBuilder(); + List terms = new ArrayList<>(); + for (String token : tokens) { + if (Regex.isSimpleMatchPattern(token)) { + shouldQueries.should(QueryBuilders.wildcardQuery(resourceIdField, token)); + } else { + terms.add(token); + } + } + if (terms.isEmpty() == false) { + shouldQueries.should(QueryBuilders.termsQuery(resourceIdField, terms)); + } + + if (shouldQueries.should().isEmpty() == false) { + boolQuery.filter(shouldQueries); + } + return boolQuery; + } private static T handleSearchItem(MultiSearchResponse.Item item, String resourceId, @@ -202,23 +387,23 @@ public class TrainedModelProvider { return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId); } - private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws Exception { + private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws IOException { try (InputStream stream = source.streamInput(); XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { return TrainedModelConfig.fromXContent(parser, true); - } catch (Exception e) { + } catch (IOException e) { logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e); throw e; } } - private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws Exception { + private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws IOException { try 
(InputStream stream = source.streamInput(); XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { return TrainedModelDefinition.fromXContent(parser, true).build(); - } catch (Exception e) { + } catch (IOException e) { logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), e); throw e; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java new file mode 100644 index 00000000000..dfce44af7c9 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.notifications; + +import org.elasticsearch.client.Client; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor; +import org.elasticsearch.xpack.core.ml.notifications.AuditorField; +import org.elasticsearch.xpack.core.ml.notifications.InferenceAuditMessage; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +/** AbstractAuditor subclass wired to AuditorField.NOTIFICATIONS_INDEX, ML_ORIGIN and InferenceAuditMessage::new. */ +public class InferenceAuditor extends AbstractAuditor { +/** Forwards the client and node name to AbstractAuditor together with the fixed index, origin and message factory. */ + public InferenceAuditor(Client client, String nodeName) { + super(client, nodeName, AuditorField.NOTIFICATIONS_INDEX, ML_ORIGIN, InferenceAuditMessage::new); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java new file mode 100644 index 00000000000..e9675be4d29 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +import static org.elasticsearch.rest.RestRequest.Method.DELETE; +/** REST handler that turns a DELETE on the inference model path into a DeleteTrainedModelAction request. */ +public class RestDeleteTrainedModelAction extends BaseRestHandler { +/** Registers this handler for DELETE under MachineLearning.BASE_PATH, with the model id as the path parameter. */ + public RestDeleteTrainedModelAction(RestController controller) { + controller.registerHandler( + DELETE, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this); + } + + @Override + public String getName() { + return "ml_delete_trained_models_action"; + } +/** Reads the model id path parameter and executes the delete action, answering through a RestToXContentListener. */ + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + DeleteTrainedModelAction.Request request = new DeleteTrainedModelAction.Request(modelId); + return channel -> client.execute(DeleteTrainedModelAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java new file mode 100644 index 00000000000..578b75fbc07 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.common.Strings; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH; +/** REST handler that turns a GET on the inference path, with or without a model id, into a GetTrainedModelsAction request. */ +public class RestGetTrainedModelsAction extends BaseRestHandler { +/** Registers this handler for GET on both the per-model path and the bare inference path. */ + public RestGetTrainedModelsAction(RestController controller) { + controller.registerHandler( + GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this); + controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference", this); + } + + @Override + public String getName() { + return "ml_get_trained_models_action"; + } +/** Builds the request: a null or empty id becomes MetaData.ALL, and paging is set only when from or size is supplied. */ + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (Strings.isNullOrEmpty(modelId)) { + modelId = MetaData.ALL; + } + boolean includeModelDefinition = restRequest.paramAsBoolean( + GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(), + false + ); + GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition); + if 
(restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { + request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); + } + request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); + return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java new file mode 100644 index 00000000000..100c8cfa2f9 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.common.Strings; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH; + +public class RestGetTrainedModelsStatsAction extends BaseRestHandler { + + public RestGetTrainedModelsStatsAction(RestController controller) { + controller.registerHandler( + GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_stats", this); + controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference/_stats", this); + } + + @Override + public String getName() { + return "ml_get_trained_models_stats_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (Strings.isNullOrEmpty(modelId)) { + modelId = MetaData.ALL; + } + GetTrainedModelsStatsAction.Request request = new GetTrainedModelsStatsAction.Request(modelId); + if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { + request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), + 
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); + } + request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); + return channel -> client.execute(GetTrainedModelsStatsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java index 06ccdb3a299..cc3a9725c09 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java @@ -6,13 +6,21 @@ package org.elasticsearch.license; import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.ingest.PutPipelineAction; +import org.elasticsearch.action.ingest.PutPipelineRequest; +import org.elasticsearch.action.ingest.SimulatePipelineAction; +import org.elasticsearch.action.ingest.SimulatePipelineRequest; +import org.elasticsearch.action.ingest.SimulatePipelineResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.license.License.OperationMode; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.rest.RestStatus; @@ -24,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteDatafeedAction; 
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.OpenJobAction; import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; @@ -31,17 +40,24 @@ import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; import org.elasticsearch.xpack.core.ml.client.MachineLearningClient; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.ml.LocalStateMachineLearning; import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; import org.junit.Before; +import java.nio.charset.StandardCharsets; import java.util.Collections; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; public class MachineLearningLicensingTests extends BaseMlIntegTestCase { @@ -529,6 +545,216 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { } } + public void testMachineLearningCreateInferenceProcessorRestricted() { + String modelId = "modelprocessorlicensetest"; + assertMLAllowed(true); + putInferenceModel(modelId); + + String pipeline = "{" + + " 
\"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"regression_value\",\n" + + " \"model_id\": \"modelprocessorlicensetest\",\n" + + " \"inference_config\": {\"regression\": {}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }]}\n"; + // test that license restricted apis do now work + PlainActionFuture putPipelineListener = PlainActionFuture.newFuture(); + client().execute(PutPipelineAction.INSTANCE, + new PutPipelineRequest("test_infer_license_pipeline", + new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON), + putPipelineListener); + AcknowledgedResponse putPipelineResponse = putPipelineListener.actionGet(); + assertTrue(putPipelineResponse.isAcknowledged()); + + String simulateSource = "{\n" + + " \"pipeline\": \n" + + pipeline + + " ,\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"col1\": \"female\",\n" + + " \"col2\": \"M\",\n" + + " \"col3\": \"none\",\n" + + " \"col4\": 10\n" + + " }}]\n" + + "}"; + PlainActionFuture simulatePipelineListener = PlainActionFuture.newFuture(); + client().execute(SimulatePipelineAction.INSTANCE, + new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON), + simulatePipelineListener); + + assertThat(simulatePipelineListener.actionGet().getResults(), is(not(empty()))); + + + // Pick a license that does not allow machine learning + License.OperationMode mode = randomInvalidLicenseType(); + enableLicensing(mode); + assertMLAllowed(false); + + // creating a new pipeline should fail + ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> { + PlainActionFuture listener = PlainActionFuture.newFuture(); + client().execute(PutPipelineAction.INSTANCE, + new PutPipelineRequest("test_infer_license_pipeline_failure", + new 
BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON), + listener); + listener.actionGet(); + }); + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("non-compliant")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING)); + + // Simulating the pipeline should fail + e = expectThrows(ElasticsearchSecurityException.class, () -> { + PlainActionFuture listener = PlainActionFuture.newFuture(); + client().execute(SimulatePipelineAction.INSTANCE, + new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON), + listener); + listener.actionGet(); + }); + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("non-compliant")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING)); + + // Pick a license that does allow machine learning + mode = randomValidLicenseType(); + enableLicensing(mode); + assertMLAllowed(true); + // test that license restricted apis do now work + PlainActionFuture putPipelineListenerNewLicense = PlainActionFuture.newFuture(); + client().execute(PutPipelineAction.INSTANCE, + new PutPipelineRequest("test_infer_license_pipeline", + new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON), + putPipelineListenerNewLicense); + AcknowledgedResponse putPipelineResponseNewLicense = putPipelineListenerNewLicense.actionGet(); + assertTrue(putPipelineResponseNewLicense.isAcknowledged()); + + PlainActionFuture simulatePipelineListenerNewLicense = PlainActionFuture.newFuture(); + client().execute(SimulatePipelineAction.INSTANCE, + new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON), + simulatePipelineListenerNewLicense); + + assertThat(simulatePipelineListenerNewLicense.actionGet().getResults(), is(not(empty()))); + } + + 
public void testMachineLearningInferModelRestricted() throws Exception { + String modelId = "modelinfermodellicensetest"; + assertMLAllowed(true); + putInferenceModel(modelId); + + + PlainActionFuture inferModelSuccess = PlainActionFuture.newFuture(); + client().execute(InferModelAction.INSTANCE, new InferModelAction.Request( + modelId, + Collections.singletonList(Collections.emptyMap()), + new RegressionConfig() + ), inferModelSuccess); + assertThat(inferModelSuccess.actionGet().getInferenceResults(), is(not(empty()))); + + // Pick a license that does not allow machine learning + License.OperationMode mode = randomInvalidLicenseType(); + enableLicensing(mode); + assertMLAllowed(false); + + // inferring against a model should now fail + ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> { + PlainActionFuture listener = PlainActionFuture.newFuture(); + client().execute(InferModelAction.INSTANCE, new InferModelAction.Request( + modelId, + Collections.singletonList(Collections.emptyMap()), + new RegressionConfig() + ), listener); + listener.actionGet(); + }); + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("non-compliant")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING)); + + // Pick a license that does allow machine learning + mode = randomValidLicenseType(); + enableLicensing(mode); + assertMLAllowed(true); + + PlainActionFuture listener = PlainActionFuture.newFuture(); + client().execute(InferModelAction.INSTANCE, new InferModelAction.Request( + modelId, + Collections.singletonList(Collections.emptyMap()), + new RegressionConfig() + ), listener); + assertThat(listener.actionGet().getInferenceResults(), is(not(empty()))); + } + + private void putInferenceModel(String modelId) { + String config = "" + + "{\n" + + " \"model_id\": \"" + modelId + "\",\n" + + " 
\"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + + " \"description\": \"test model for classification\",\n" + + " \"version\": \"7.6.0\",\n" + + " \"created_by\": \"benwtrent\",\n" + + " \"estimated_heap_memory_usage_bytes\": 0,\n" + + " \"estimated_operations\": 0,\n" + + " \"created_time\": 0\n" + + "}"; + String definition = "" + + "{" + + " \"trained_model\": {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " }," + + " \"model_id\": \"" + modelId + "\"\n" + + "}"; + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) + .setId(modelId) + .setSource(config, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) + .setId(TrainedModelDefinition.docId(modelId)) + .setSource(definition, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + } + private static OperationMode randomInvalidLicenseType() { return randomFrom(License.OperationMode.GOLD, License.OperationMode.STANDARD, License.OperationMode.BASIC); } diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningFeatureSetTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningFeatureSetTests.java index 8f2ed47794a..6cf9a7a3d60 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningFeatureSetTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningFeatureSetTests.java @@ -8,8 +8,16 @@ package org.elasticsearch.xpack.ml; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -20,13 +28,18 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.env.Environment; import org.elasticsearch.env.TestEnvironment; +import org.elasticsearch.ingest.IngestStats; import 
org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.XPackFeatureSet; import org.elasticsearch.xpack.core.XPackFeatureSet.Usage; import org.elasticsearch.xpack.core.XPackField; @@ -40,6 +53,7 @@ import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; import org.elasticsearch.xpack.core.ml.job.config.Detector; @@ -49,10 +63,12 @@ import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeSta import org.elasticsearch.xpack.core.ml.stats.ForecastStats; import org.elasticsearch.xpack.core.ml.stats.ForecastStatsTests; import org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.junit.Before; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.Date; @@ -61,6 +77,8 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.notNullValue; @@ -98,6 +116,8 @@ public class MachineLearningFeatureSetTests extends ESTestCase { givenJobs(Collections.emptyList(), Collections.emptyList()); 
givenDatafeeds(Collections.emptyList()); givenDataFrameAnalytics(Collections.emptyList()); + givenProcessorStats(Collections.emptyList()); + givenTrainedModelConfigCount(0); } public void testIsRunningOnMlPlatform() { @@ -175,14 +195,54 @@ public class MachineLearningFeatureSetTests extends ESTestCase { buildDatafeedStats(DatafeedState.STARTED), buildDatafeedStats(DatafeedState.STOPPED) )); + givenDataFrameAnalytics(Arrays.asList( buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STOPPED), buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STOPPED), buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STARTED) )); + givenProcessorStats(Arrays.asList( + buildNodeStats( + Arrays.asList("pipeline1", "pipeline2", "pipeline3"), + Arrays.asList( + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)), + new IngestStats.ProcessorStat( + InferenceProcessor.TYPE, + InferenceProcessor.TYPE, + new IngestStats.Stats(100, 10, 0, 1)) + ), + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)) + ) + )), + buildNodeStats( + Arrays.asList("pipeline1", "pipeline2", "pipeline3"), + Arrays.asList( + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(0, 0, 0, 0)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(0, 0, 0, 0)), + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 
1, 0, 0)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)) + ) + )) + )); + givenTrainedModelConfigCount(100); + MachineLearningFeatureSet featureSet = new MachineLearningFeatureSet(TestEnvironment.newEnvironment(settings.build()), - clusterService, client, licenseState, jobManagerHolder); + clusterService, client, licenseState, jobManagerHolder); PlainActionFuture future = new PlainActionFuture<>(); featureSet.usage(future); XPackFeatureSet.Usage mlUsage = future.get(); @@ -258,6 +318,18 @@ public class MachineLearningFeatureSetTests extends ESTestCase { assertThat(source.getValue("jobs.opened.forecasts.total"), equalTo(11)); assertThat(source.getValue("jobs.opened.forecasts.forecasted_jobs"), equalTo(2)); + + assertThat(source.getValue("inference.trained_models._all.count"), equalTo(100)); + assertThat(source.getValue("inference.ingest_processors._all.pipelines.count"), equalTo(2)); + assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.sum"), equalTo(130)); + assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.min"), equalTo(0)); + assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.max"), equalTo(100)); + assertThat(source.getValue("inference.ingest_processors._all.time_ms.sum"), equalTo(14)); + assertThat(source.getValue("inference.ingest_processors._all.time_ms.min"), equalTo(0)); + assertThat(source.getValue("inference.ingest_processors._all.time_ms.max"), equalTo(10)); + assertThat(source.getValue("inference.ingest_processors._all.num_failures.sum"), equalTo(1)); + assertThat(source.getValue("inference.ingest_processors._all.num_failures.min"), equalTo(0)); + assertThat(source.getValue("inference.ingest_processors._all.num_failures.max"), equalTo(1)); } } @@ -444,6 +516,34 @@ public class MachineLearningFeatureSetTests extends 
ESTestCase { }).when(client).execute(same(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any()); } + private void givenProcessorStats(List stats) { + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + ActionListener listener = + (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(new NodesStatsResponse(new ClusterName("_name"), stats, Collections.emptyList())); + return Void.TYPE; + }).when(client).execute(same(NodesStatsAction.INSTANCE), any(), any()); + } + + private void givenTrainedModelConfigCount(long count) { + when(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)) + .thenReturn(new SearchRequestBuilder(client, SearchAction.INSTANCE)); + ThreadPool pool = mock(ThreadPool.class); + when(pool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + when(client.threadPool()).thenReturn(pool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + ActionListener listener = + (ActionListener) invocationOnMock.getArguments()[1]; + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(count, TotalHits.Relation.EQUAL_TO), (float)0.0); + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(searchHits); + listener.onResponse(searchResponse); + return Void.TYPE; + }).when(client).search(any(), any()); + } + private static Detector buildMinDetector(String fieldName) { Detector.Builder detectorBuilder = new Detector.Builder(); detectorBuilder.setFunction("min"); @@ -490,6 +590,17 @@ public class MachineLearningFeatureSetTests extends ESTestCase { return stats; } + private static NodeStats buildNodeStats(List pipelineNames, List> processorStats) { + IngestStats ingestStats = new IngestStats( + new IngestStats.Stats(0,0,0,0), + Collections.emptyList(), + IntStream.range(0, pipelineNames.size()).boxed().collect(Collectors.toMap(pipelineNames::get, processorStats::get))); + return new NodeStats(mock(DiscoveryNode.class), + 
Instant.now().toEpochMilli(), null, null, null, null, null, null, null, null, + null, null, null, ingestStats, null); + + } + private static ForecastStats buildForecastStats(long numberOfForecasts) { return new ForecastStatsTests().createForecastStats(numberOfForecasts, numberOfForecasts); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java new file mode 100644 index 00000000000..bc86512e3ba --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java @@ -0,0 +1,289 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import 
org.elasticsearch.ingest.IngestStats; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.ingest.Processor; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.plugins.IngestPlugin; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.junit.Before; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TransportGetTrainedModelsStatsActionTests extends ESTestCase { + + private static class NotInferenceProcessor implements Processor { + + @Override + public IngestDocument execute(IngestDocument ingestDocument) throws Exception { + return ingestDocument; + } + + @Override + public String getType() { + return "not_inference"; + } + + @Override + public String getTag() { + return null; + } + + static class Factory implements Processor.Factory { + + @Override + public Processor create(Map processorFactories, String tag, Map config) { + return new NotInferenceProcessor(); + } + } + } + + private static final IngestPlugin SKINNY_INGEST_PLUGIN = new IngestPlugin() { + @Override + public Map getProcessors(Processor.Parameters parameters) { + Map factoryMap = new HashMap<>(); + XPackLicenseState licenseState = mock(XPackLicenseState.class); + when(licenseState.isMachineLearningAllowed()).thenReturn(true); + factoryMap.put(InferenceProcessor.TYPE, + new InferenceProcessor.Factory(parameters.client, + parameters.ingestService.getClusterService(), + 
Settings.EMPTY, + parameters.ingestService, + licenseState)); + + factoryMap.put("not_inference", new NotInferenceProcessor.Factory()); + + return factoryMap; + } + }; + + private ClusterService clusterService; + private IngestService ingestService; + private Client client; + + @Before + public void setUpVariables() { + ThreadPool tp = mock(ThreadPool.class); + client = mock(Client.class); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, + Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS)); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ingestService = new IngestService(clusterService, tp, null, null, + null, Collections.singletonList(SKINNY_INGEST_PLUGIN), client); + } + + + public void testInferenceIngestStatsByPipelineId() throws IOException { + List nodeStatsList = Arrays.asList( + buildNodeStats( + new IngestStats.Stats(2, 2, 3, 4), + Arrays.asList( + new IngestStats.PipelineStat( + "pipeline1", + new IngestStats.Stats(0, 0, 3, 1)), + new IngestStats.PipelineStat( + "pipeline2", + new IngestStats.Stats(1, 1, 0, 1)), + new IngestStats.PipelineStat( + "pipeline3", + new IngestStats.Stats(2, 1, 1, 1)) + ), + Arrays.asList( + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)), + new IngestStats.ProcessorStat( + InferenceProcessor.TYPE, + InferenceProcessor.TYPE, + new IngestStats.Stats(100, 10, 0, 1)) + ), + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)) + ) + )), + buildNodeStats( + new IngestStats.Stats(15, 5, 3, 
4), + Arrays.asList( + new IngestStats.PipelineStat( + "pipeline1", + new IngestStats.Stats(10, 1, 3, 1)), + new IngestStats.PipelineStat( + "pipeline2", + new IngestStats.Stats(1, 1, 0, 1)), + new IngestStats.PipelineStat( + "pipeline3", + new IngestStats.Stats(2, 1, 1, 1)) + ), + Arrays.asList( + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(0, 0, 0, 0)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(0, 0, 0, 0)), + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)) + ) + )) + ); + + NodesStatsResponse response = new NodesStatsResponse(new ClusterName("_name"), nodeStatsList, Collections.emptyList()); + + Map> pipelineIdsByModelIds = new HashMap>(){{ + put("trained_model_1", Collections.singleton("pipeline1")); + put("trained_model_2", new HashSet<>(Arrays.asList("pipeline1", "pipeline2"))); + }}; + Map ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByPipelineId(response, + pipelineIdsByModelIds); + + assertThat(ingestStatsMap.keySet(), equalTo(new HashSet<>(Arrays.asList("trained_model_1", "trained_model_2")))); + + IngestStats expectedStatsModel1 = new IngestStats( + new IngestStats.Stats(10, 1, 6, 2), + Collections.singletonList(new IngestStats.PipelineStat("pipeline1", new IngestStats.Stats(10, 1, 6, 2))), + Collections.singletonMap("pipeline1", Arrays.asList( + new IngestStats.ProcessorStat("inference", "inference", new IngestStats.Stats(120, 12, 0, 1)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)))) + ); + + 
IngestStats expectedStatsModel2 = new IngestStats( + new IngestStats.Stats(12, 3, 6, 4), + Arrays.asList( + new IngestStats.PipelineStat("pipeline1", new IngestStats.Stats(10, 1, 6, 2)), + new IngestStats.PipelineStat("pipeline2", new IngestStats.Stats(2, 2, 0, 2))), + new HashMap>() {{ + put("pipeline2", Arrays.asList( + new IngestStats.ProcessorStat("inference", "inference", new IngestStats.Stats(10, 2, 0, 0)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(20, 2, 0, 0)))); + put("pipeline1", Arrays.asList( + new IngestStats.ProcessorStat("inference", "inference", new IngestStats.Stats(120, 12, 0, 1)), + new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)))); + }} + ); + + assertThat(ingestStatsMap, hasEntry("trained_model_1", expectedStatsModel1)); + assertThat(ingestStatsMap, hasEntry("trained_model_2", expectedStatsModel2)); + } + + public void testPipelineIdsByModelIds() throws IOException { + String modelId1 = "trained_model_1"; + String modelId2 = "trained_model_2"; + String modelId3 = "trained_model_3"; + Set modelIds = new HashSet<>(Arrays.asList(modelId1, modelId2, modelId3)); + + ClusterState clusterState = buildClusterStateWithModelReferences(modelId1, modelId2, modelId3); + + Map> pipelineIdsByModelIds = + TransportGetTrainedModelsStatsAction.pipelineIdsByModelIds(clusterState, ingestService, modelIds); + + assertThat(pipelineIdsByModelIds.keySet(), equalTo(modelIds)); + assertThat(pipelineIdsByModelIds, + hasEntry(modelId1, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId1 + 0, "pipeline_with_model_" + modelId1 + 1)))); + assertThat(pipelineIdsByModelIds, + hasEntry(modelId2, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId2 + 0, "pipeline_with_model_" + modelId2 + 1)))); + assertThat(pipelineIdsByModelIds, + hasEntry(modelId3, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId3 + 0, "pipeline_with_model_" + modelId3 + 1)))); + + } + + private static ClusterState 
buildClusterStateWithModelReferences(String... modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id + 0, newConfigurationWithInferenceProcessor(id, 0)); + configurations.put("pipeline_with_model_" + id + 1, newConfigurationWithInferenceProcessor(id, 1)); + } + for (int i = 0; i < 3; i++) { + configurations.put("pipeline_without_model_" + i, newConfigurationWithOutInferenceProcessor(i)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId, int num) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList( + Collections.singletonMap(InferenceProcessor.TYPE, + new HashMap() {{ + put(InferenceProcessor.MODEL_ID, modelId); + put("inference_config", Collections.singletonMap("regression", Collections.emptyMap())); + put("field_mappings", Collections.emptyMap()); + put("target_field", randomAlphaOfLength(10)); + }}))))) { + return new PipelineConfiguration("pipeline_with_model_" + modelId + num, + BytesReference.bytes(xContentBuilder), + XContentType.JSON); + } + } + + private static PipelineConfiguration newConfigurationWithOutInferenceProcessor(int i) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList(Collections.singletonMap("not_inference", Collections.emptyMap()))))) { + return new PipelineConfiguration("pipeline_without_model_" + i, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } + + private static NodeStats buildNodeStats(IngestStats.Stats overallStats, + List 
pipelineNames, + List> processorStats) { + List pipelineids = pipelineNames.stream().map(IngestStats.PipelineStat::getPipelineId).collect(Collectors.toList()); + IngestStats ingestStats = new IngestStats( + overallStats, + pipelineNames, + IntStream.range(0, pipelineids.size()).boxed().collect(Collectors.toMap(pipelineids::get, processorStats::get))); + return new NodeStats(mock(DiscoveryNode.class), + Instant.now().toEpochMilli(), null, null, null, null, null, null, null, null, + null, null, null, ingestStats, null); + + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index cb90b39772a..bdccdf8c672 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -145,6 +145,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); assertThat(storedModel.getDefinition(), equalTo(inferenceModel.build())); assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames)); + assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed())); + assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations())); Map metadata = storedModel.getMetadata(); assertThat(metadata.size(), equalTo(1)); assertThat(metadata, hasKey("analytics_config")); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index 36bc727d2b9..0d3121d8681 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -40,7 +40,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase getProcessors(Processor.Parameters parameters) { + XPackLicenseState licenseState = mock(XPackLicenseState.class); + when(licenseState.isMachineLearningAllowed()).thenReturn(true); + return Collections.singletonMap(InferenceProcessor.TYPE, + new InferenceProcessor.Factory(parameters.client, + parameters.ingestService.getClusterService(), + Settings.EMPTY, + parameters.ingestService, + licenseState)); + } + }; + private Client client; + private XPackLicenseState licenseState; + private ClusterService clusterService; + private IngestService ingestService; + + @Before + public void setUpVariables() { + ThreadPool tp = mock(ThreadPool.class); + client = mock(Client.class); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, + Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS)); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ingestService = new IngestService(clusterService, tp, null, null, + null, Collections.singletonList(SKINNY_PLUGIN), client); + licenseState = mock(XPackLicenseState.class); + when(licenseState.isMachineLearningAllowed()).thenReturn(true); + } + + public void testNumInferenceProcessors() throws Exception { + MetaData metaData = null; + + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService, + licenseState); + processorFactory.accept(buildClusterState(metaData)); + + assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); + metaData = MetaData.builder().build(); + + processorFactory.accept(buildClusterState(metaData)); + 
assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); + + processorFactory.accept(buildClusterStateWithModelReferences("model1", "model2", "model3")); + assertThat(processorFactory.numInferenceProcessors(), equalTo(3)); + } + + public void testCreateProcessorWithTooManyExisting() throws Exception { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.builder().put(InferenceProcessor.MAX_INFERENCE_PROCESSORS.getKey(), 1).build(), + ingestService, + licenseState); + + processorFactory.accept(buildClusterStateWithModelReferences("model1")); + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", Collections.emptyMap())); + + assertThat(ex.getMessage(), equalTo("Max number of inference processors reached, total inference processors [1]. " + + "Adjust the setting [xpack.ml.max_inference_processors]: [1] if a greater number is desired.")); + } + + public void testCreateProcessorWithInvalidInferenceConfig() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService, + licenseState); + + Map config = new HashMap() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap())); + }}; + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config)); + assertThat(ex.getMessage(), + equalTo("unrecognized inference configuration type [unknown_type]. 
Supported types [classification, regression]")); + + Map config2 = new HashMap() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom")); + }}; + ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config2)); + assertThat(ex.getMessage(), + equalTo("inference_config must be an object with one inference type mapped to an object.")); + + Map config3 = new HashMap() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap()); + }}; + ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config3)); + assertThat(ex.getMessage(), + equalTo("inference_config must be an object with one inference type mapped to an object.")); + } + + public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService, + licenseState); + processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1")); + + Map regression = new HashMap() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression); + fail("Should not have successfully created"); + } catch 
(ElasticsearchException ex) { + assertThat(ex.getMessage(), + equalTo("Configuration [regression] requires minimum node version [7.6.0] (current minimum node version [7.5.0]")); + } catch (Exception ex) { + fail(ex.getMessage()); + } + + Map classification = new HashMap() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME, + Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1))); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification); + fail("Should not have successfully created"); + } catch (ElasticsearchException ex) { + assertThat(ex.getMessage(), + equalTo("Configuration [classification] requires minimum node version [7.6.0] (current minimum node version [7.5.0]")); + } catch (Exception ex) { + fail(ex.getMessage()); + } + } + + public void testCreateProcessor() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService, + licenseState); + + Map regression = new HashMap() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression); + } catch (Exception ex) { + fail(ex.getMessage()); + } + + Map classification = new HashMap() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, 
Collections.singletonMap(ClassificationConfig.NAME, + Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1))); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification); + } catch (Exception ex) { + fail(ex.getMessage()); + } + } + + private static ClusterState buildClusterState(MetaData metaData) { + return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build(); + } + + private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { + return builderClusterStateWithModelReferences(Version.CURRENT, modelId); + } + + private static ClusterState builderClusterStateWithModelReferences(Version minNodeVersion, String... modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .nodes(DiscoveryNodes.builder() + .add(new DiscoveryNode("min_node", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + minNodeVersion)) + .add(new DiscoveryNode("current_node", + new TransportAddress(InetAddress.getLoopbackAddress(), 9302), + Version.CURRENT)) + .localNodeId("_node_id") + .masterNodeId("_node_id")) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList( + Collections.singletonMap(InferenceProcessor.TYPE, + new HashMap() {{ + put(InferenceProcessor.MODEL_ID, modelId); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, 
Collections.emptyMap())); + put(InferenceProcessor.TARGET_FIELD, "new_field"); + put(InferenceProcessor.FIELD_MAPPINGS, Collections.singletonMap("source", "dest")); + }}))))) { + return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java new file mode 100644 index 00000000000..ac8c0bb3dd6 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -0,0 +1,230 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.ingest; + +import org.elasticsearch.client.Client; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.Is.is; +import static org.mockito.Mockito.mock; + +public class InferenceProcessorTests extends ESTestCase { + + private Client 
client; + + @Before + public void setUpVariables() { + client = mock(Client.class); + } + + public void testMutateDocumentWithClassification() { + String targetField = "classification_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "classification_model", + new ClassificationConfig(0), + Collections.emptyMap(), + "ml.my_processor", + true); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", null))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, String.class), equalTo("foo")); + assertThat(document.getFieldValue("ml", Map.class), + equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model")))); + } + + @SuppressWarnings("unchecked") + public void testMutateDocumentClassificationTopNClasses() { + String targetField = "classification_value_probabilities"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "classification_model", + new ClassificationConfig(2), + Collections.emptyMap(), + "ml.my_processor", + true); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + List classes = new ArrayList<>(2); + classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6)); + classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4)); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes))); + inferenceProcessor.mutateDocument(response, document); + + 
assertThat((List>)document.getFieldValue(targetField, List.class), + contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new))); + assertThat(document.getFieldValue("ml", Map.class), + equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model")))); + } + + public void testMutateDocumentRegression() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "regression_model", + new RegressionConfig(), + Collections.emptyMap(), + "ml.my_processor", + true); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("ml", Map.class), + equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "regression_model")))); + } + + public void testMutateDocumentNoModelMetaData() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "regression_model", + new RegressionConfig(), + Collections.emptyMap(), + "ml.my_processor", + false); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + 
assertThat(document.hasField("ml"), is(false)); + } + + public void testMutateDocumentModelMetaDataExistingField() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "regression_model", + new RegressionConfig(), + Collections.emptyMap(), + "ml.my_processor", + true); + + //cannot use singleton map as attempting to mutate later + Map ml = new HashMap(){{ + put("regression_prediction", 0.55); + }}; + Map source = new HashMap(){{ + put("ml", ml); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("ml", Map.class), + equalTo(new HashMap(){{ + put("my_processor", Collections.singletonMap("model_id", "regression_model")); + put("regression_prediction", 0.55); + }})); + } + + public void testGenerateRequestWithEmptyMapping() { + String modelId = "model"; + Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); + + InferenceProcessor processor = new InferenceProcessor(client, + "my_processor", + "my_field", + modelId, + new ClassificationConfig(topNClasses), + Collections.emptyMap(), + "ml.my_processor", + false); + + Map source = new HashMap(){{ + put("value1", 1); + put("value2", 4); + put("categorical", "foo"); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(source)); + } + + public void testGenerateWithMapping() { + String modelId = "model"; + Integer topNClasses = randomBoolean() ? 
null : randomIntBetween(1, 10); + + Map fieldMapping = new HashMap(3) {{ + put("value1", "new_value1"); + put("value2", "new_value2"); + put("categorical", "new_categorical"); + }}; + + InferenceProcessor processor = new InferenceProcessor(client, + "my_processor", + "my_field", + modelId, + new ClassificationConfig(topNClasses), + fieldMapping, + "ml.my_processor", + false); + + Map source = new HashMap(3){{ + put("value1", 1); + put("categorical", "foo"); + put("un_touched", "bar"); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + Map expectedMap = new HashMap(2) {{ + put("new_value1", 1); + put("new_categorical", "foo"); + put("un_touched", "bar"); + }}; + assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java new file mode 100644 index 00000000000..ceaa2d33dd4 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -0,0 +1,212 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class LocalModelTests extends ESTestCase { + + public void testClassificationInfer() throws Exception { + String modelId = "classification_model"; + TrainedModelDefinition definition = new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new 
OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildClassification(false)) + .build(); + + Model model = new LocalModel(modelId, definition); + Map fields = new HashMap() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", "dog"); + }}; + + SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0)); + assertThat(result.value(), equalTo(0.0)); + assertThat(result.valueAsString(), is("0.0")); + + ClassificationInferenceResults classificationResult = + (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1)); + assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); + assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); + + // Test with labels + definition = new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildClassification(true)) + .build(); + model = new LocalModel(modelId, definition); + result = getSingleValue(model, fields, new ClassificationConfig(0)); + assertThat(result.value(), equalTo(0.0)); + assertThat(result.valueAsString(), equalTo("not_to_be")); + + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1)); + assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); + assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); + + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(2)); + assertThat(classificationResult.getTopClasses(), hasSize(2)); + + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(-1)); + assertThat(classificationResult.getTopClasses(), hasSize(2)); + } + + public void testRegression() 
throws Exception { + TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildRegression()) + .build(); + Model model = new LocalModel("regression_model", trainedModelDefinition); + + Map fields = new HashMap() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", "dog"); + }}; + + SingleValueInferenceResults results = getSingleValue(model, fields, new RegressionConfig()); + assertThat(results.value(), equalTo(1.3)); + + PlainActionFuture failedFuture = new PlainActionFuture<>(); + model.infer(fields, new ClassificationConfig(2), failedFuture); + ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); + assertThat(ex.getCause().getMessage(), + equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]")); + } + + private static SingleValueInferenceResults getSingleValue(Model model, + Map fields, + InferenceConfig config) throws Exception { + PlainActionFuture future = new PlainActionFuture<>(); + model.infer(fields, config, future); + return (SingleValueInferenceResults)future.get(); + } + + private static Map oneHotMap() { + Map oneHotEncoding = new HashMap<>(); + oneHotEncoding.put("cat", "animal_cat"); + oneHotEncoding.put("dog", "animal_dog"); + return oneHotEncoding; + } + + public static TrainedModel buildClassification(boolean includeLabels) { + List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.0)) + 
.addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(3) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(0.0)) + .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .build(); + return Ensemble.builder() + .setClassificationLabels(includeLabels ? Arrays.asList("not_to_be", "to_be") : null) + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0})) + .build(); + } + + public static TrainedModel buildRegression() { + List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.3)) + .addNode(TreeNode.builder(2) + .setThreshold(0.0) + .setSplitFeature(3) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.1)) + .addNode(TreeNode.builder(4).setLeafValue(0.2)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(2) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + 
.setThreshold(0.2)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + return Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5, 0.5})) + .build(); + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java new file mode 100644 index 00000000000..272628f4c12 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -0,0 +1,363 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeRole; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; +import org.junit.After; +import org.junit.Before; +import 
org.mockito.Mockito; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.atMost; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ModelLoadingServiceTests extends ESTestCase { + + private TrainedModelProvider trainedModelProvider; + private ThreadPool threadPool; + private ClusterService clusterService; + private InferenceAuditor auditor; + + @Before + public void setUpComponents() { + threadPool = new TestThreadPool("ModelLoadingServiceTests", new ScalingExecutorBuilder(UTILITY_THREAD_POOL_NAME, + 1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.utility_thread_pool")); + trainedModelProvider = mock(TrainedModelProvider.class); + clusterService = mock(ClusterService.class); + auditor = mock(InferenceAuditor.class); + doAnswer(a -> null).when(auditor).error(any(String.class), any(String.class)); + doAnswer(a -> null).when(auditor).info(any(String.class), any(String.class)); + doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class)); + doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class)); + when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build()); + } + + @After + public void terminateThreadPool() { + terminate(threadPool); + } + + public void testGetCachedModels() throws Exception { + String model1 = "test-load-model-1"; 
+ String model2 = "test-load-model-2"; + String model3 = "test-load-model-3"; + withTrainedModel(model1, 1L); + withTrainedModel(model2, 1L); + withTrainedModel(model3, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.EMPTY); + + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); + + String[] modelIds = new String[]{model1, model2, model3}; + for(int i = 0; i < 10; i++) { + String model = modelIds[i%3]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); + + // Test invalidate cache for model3 + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); + for(int i = 0; i < 10; i++) { + String model = modelIds[i%3]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any()); + // It is not referenced, so called eagerly + verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), eq(true), any()); + } + + public void testMaxCachedLimitReached() throws Exception { + String model1 = "test-cached-limit-load-model-1"; + String model2 = "test-cached-limit-load-model-2"; + String model3 = "test-cached-limit-load-model-3"; + withTrainedModel(model1, 10L); + withTrainedModel(model2, 5L); + withTrainedModel(model3, 15L); + + ModelLoadingService modelLoadingService = new 
ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build()); + + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); + + // Should have been loaded from the cluster change event + // Verify that we have at least loaded all three so that evictions occur in the following loop + assertBusy(() -> { + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); + }); + + String[] modelIds = new String[]{model1, model2, model3}; + for(int i = 0; i < 10; i++) { + // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) + String model = modelIds[i%2]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model1), eq(true), any()); + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any()); + // Only loaded requested once on the initial load from the change event + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); + + // Load model 3, should invalidate 1 + for(int i = 0; i < 10; i++) { + PlainActionFuture future3 = new PlainActionFuture<>(); + modelLoadingService.getModel(model3, future3); + assertThat(future3.get(), is(not(nullValue()))); + } + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model3), eq(true), any()); + + // Load model 1, should invalidate 2 + for(int i = 0; i < 10; i++) { + PlainActionFuture future1 = new PlainActionFuture<>(); + modelLoadingService.getModel(model1, future1); + assertThat(future1.get(), 
is(not(nullValue()))); + } + verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any()); + + // Load model 2, should invalidate 3 + for(int i = 0; i < 10; i++) { + PlainActionFuture future2 = new PlainActionFuture<>(); + modelLoadingService.getModel(model2, future2); + assertThat(future2.get(), is(not(nullValue()))); + } + verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any()); + + + // Test invalidate cache for model3 + // Now both model 1 and 2 should fit in cache without issues + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); + for(int i = 0; i < 10; i++) { + String model = modelIds[i%3]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any()); + verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any()); + verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any()); + verify(trainedModelProvider, Mockito.atMost(5)).getTrainedModel(eq(model3), eq(true), any()); + } + + + public void testWhenCacheEnabledButNotIngestNode() throws Exception { + String model1 = "test-uncached-not-ingest-model-1"; + withTrainedModel(model1, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.EMPTY); + + modelLoadingService.clusterChanged(ingestChangedEvent(false, model1)); + + for(int i = 0; i < 10; i++) { + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model1, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), eq(true), any()); + } + + public void testGetCachedMissingModel() throws Exception { + String model = 
"test-load-cached-missing-model"; + withMissingModel(model); + + ModelLoadingService modelLoadingService =new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.EMPTY); + modelLoadingService.clusterChanged(ingestChangedEvent(model)); + + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, future); + + try { + future.get(); + fail("Should not have succeeded in loaded model"); + } catch (Exception ex) { + assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model))); + } + + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(true), any()); + } + + public void testGetMissingModel() { + String model = "test-load-missing-model"; + withMissingModel(model); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.EMPTY); + + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, future); + try { + future.get(); + fail("Should not have succeeded"); + } catch (Exception ex) { + assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model))); + } + } + + public void testGetModelEagerly() throws Exception { + String model = "test-get-model-eagerly"; + withTrainedModel(model, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.EMPTY); + + for(int i = 0; i < 3; i++) { + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(true), any()); + } + + @SuppressWarnings("unchecked") + private void withTrainedModel(String modelId, long size) { + TrainedModelDefinition definition = 
mock(TrainedModelDefinition.class); + when(definition.ramBytesUsed()).thenReturn(size); + TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); + when(trainedModelConfig.getDefinition()).thenReturn(definition); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(trainedModelConfig); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(true), any()); + } + + private void withMissingModel(String modelId) { + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(true), any()); + } + + private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws IOException { + return ingestChangedEvent(true, modelId); + } + + private static ClusterChangedEvent ingestChangedEvent(boolean isIngestNode, String... modelId) throws IOException { + ClusterChangedEvent event = mock(ClusterChangedEvent.class); + when(event.changedCustomMetaDataSet()).thenReturn(Collections.singleton(IngestMetadata.TYPE)); + when(event.state()).thenReturn(buildClusterStateWithModelReferences(isIngestNode, modelId)); + return event; + } + + private static ClusterState buildClusterStateWithModelReferences(boolean isIngestNode, String... 
modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .nodes(DiscoveryNodes.builder().add( + new DiscoveryNode("node_name", + "node_id", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + isIngestNode ? Collections.singleton(DiscoveryNodeRole.INGEST_ROLE) : Collections.emptySet(), + Version.CURRENT)) + .localNodeId("node_id") + .build()) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList( + Collections.singletonMap(InferenceProcessor.TYPE, + Collections.singletonMap(InferenceProcessor.MODEL_ID, + modelId)))))) { + return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java new file mode 100644 index 00000000000..1f4d0f8b7ba --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -0,0 +1,196 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.junit.Before; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildClassification; +import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildRegression; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.contains; +import static 
org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.nullValue; + +public class ModelInferenceActionIT extends MlSingleNodeTestCase { + + private TrainedModelProvider trainedModelProvider; + + @Before + public void createComponents() throws Exception { + trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry()); + waitForMlTemplates(); + } + + public void testInferModels() throws Exception { + String modelId1 = "test-load-models-regression"; + String modelId2 = "test-load-models-classification"; + Map oneHotEncoding = new HashMap<>(); + oneHotEncoding.put("cat", "animal_cat"); + oneHotEncoding.put("dog", "animal_dog"); + TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2) + .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) + .setDefinition(new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setTrainedModel(buildClassification(true)) + .setModelId(modelId1)) + .setVersion(Version.CURRENT) + .setCreateTime(Instant.now()) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) + .build(); + TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1) + .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) + .setDefinition(new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setTrainedModel(buildRegression()) + .setModelId(modelId2)) + .setVersion(Version.CURRENT) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) + .setCreateTime(Instant.now()) + .build(); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config1, listener), putConfigHolder, exceptionHolder); + 
assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config2, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + + List> toInfer = new ArrayList<>(); + toInfer.add(new HashMap() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", "dog"); + }}); + toInfer.add(new HashMap() {{ + put("foo", 0.9); + put("bar", 1.5); + put("categorical", "cat"); + }}); + + List> toInfer2 = new ArrayList<>(); + toInfer2.add(new HashMap() {{ + put("foo", 0.0); + put("bar", 0.01); + put("categorical", "dog"); + }}); + toInfer2.add(new HashMap() {{ + put("foo", 1.0); + put("bar", 0.0); + put("categorical", "cat"); + }}); + + // Test regression + InferModelAction.Request request = new InferModelAction.Request(modelId1, toInfer, new RegressionConfig()); + InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), + contains(1.3, 1.25)); + + request = new InferModelAction.Request(modelId1, toInfer2, new RegressionConfig()); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), + contains(1.65, 1.55)); + + + // Test classification + request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(0)); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getInferenceResults() + .stream() + .map(i -> ((SingleValueInferenceResults)i).valueAsString()) + .collect(Collectors.toList()), + contains("not_to_be", "to_be")); + + // Get top classes + request = new 
InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2));
+        response = client().execute(InferModelAction.INSTANCE, request).actionGet();
+
+        ClassificationInferenceResults classificationInferenceResults =
+            (ClassificationInferenceResults)response.getInferenceResults().get(0);
+
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
+        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("to_be"));
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
+            greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
+
+        classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
+        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("not_to_be"));
+        // they should always be in order of Most probable to least
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
+            greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
+
+        // Test that top classes restrict the number returned
+        request = new InferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1));
+        response = client().execute(InferModelAction.INSTANCE, request).actionGet();
+
+        classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
+        assertThat(classificationInferenceResults.getTopClasses(), hasSize(1));
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
+    }
+
+    public void testInferMissingModel() {
+        String model = "test-infer-missing-model";
+        InferModelAction.Request request = new InferModelAction.Request(model, Collections.emptyList(), new RegressionConfig());
+        // expectThrows fails the test when no exception is raised; the previous bare try/catch
+        // passed silently if inference against the missing model unexpectedly succeeded.
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+            () -> client().execute(InferModelAction.INSTANCE, request).actionGet());
+        assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
+    }
+
+    private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
+        return TrainedModelConfig.builder()
+            .setCreatedBy("ml_test")
+            .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
+            .setDescription("trained model config for test")
+            .setModelId(modelId);
+    }
+
+    @Override
+    public NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
index 84bb1d106fe..d7dd25ec40e 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
@@ -144,6 +144,8 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
             .setDescription("trained model config for test")
             .setModelId(modelId)
             .setVersion(Version.CURRENT)
+            .setEstimatedHeapMemory(0)
+            .setEstimatedOperations(0)
             .setInput(TrainedModelInputTests.createRandomInput());
     }
 
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json
new file mode 100644
index 00000000000..edfc157646f
--- /dev/null
+++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json
@@ -0,0 +1,24 @@
+{
+ "ml.delete_trained_model":{ + "documentation":{ + "url":"TODO" + }, + "stability":"experimental", + "url":{ + "paths":[ + { + "path":"/_ml/inference/{model_id}", + "methods":[ + "DELETE" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained model to delete" + } + } + } + ] + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json new file mode 100644 index 00000000000..22d16a6c369 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json @@ -0,0 +1,54 @@ +{ + "ml.get_trained_models":{ + "documentation":{ + "url":"TODO" + }, + "stability":"experimental", + "url":{ + "paths":[ + { + "path":"/_ml/inference/{model_id}", + "methods":[ + "GET" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained models to fetch" + } + } + }, + { + "path":"/_ml/inference", + "methods":[ + "GET" + ] + } + ] + }, + "params":{ + "allow_no_match":{ + "type":"boolean", + "required":false, + "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)", + "default":true + }, + "include_model_definition":{ + "type":"boolean", + "required":false, + "description":"Should the full model definition be included in the results. 
These definitions can be large", + "default":false + }, + "from":{ + "type":"int", + "description":"skips a number of trained models", + "default":0 + }, + "size":{ + "type":"int", + "description":"specifies a max number of trained models to get", + "default":100 + } + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_stats.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_stats.json new file mode 100644 index 00000000000..703380c7087 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_stats.json @@ -0,0 +1,48 @@ +{ + "ml.get_trained_models_stats":{ + "documentation":{ + "url":"TODO" + }, + "stability":"experimental", + "url":{ + "paths":[ + { + "path":"/_ml/inference/{model_id}/_stats", + "methods":[ + "GET" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained models stats to fetch" + } + } + }, + { + "path":"/_ml/inference/_stats", + "methods":[ + "GET" + ] + } + ] + }, + "params":{ + "allow_no_match":{ + "type":"boolean", + "required":false, + "description":"Whether to ignore if a wildcard expression matches no trained models. 
(This includes `_all` string or when no trained models have been specified)", + "default":true + }, + "from":{ + "type":"int", + "description":"skips a number of trained models", + "default":0 + }, + "size":{ + "type":"int", + "description":"specifies a max number of trained models to get", + "default":100 + } + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml new file mode 100644 index 00000000000..a8b199a7a3b --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -0,0 +1,113 @@ +--- +"Test get-all given no trained models exist": + + - do: + ml.get_trained_models: + model_id: "_all" + - match: { count: 0 } + - match: { trained_model_configs: [] } + + - do: + ml.get_trained_models: + model_id: "*" + - match: { count: 0 } + - match: { trained_model_configs: [] } + +--- +"Test get given missing trained model": + + - do: + catch: missing + ml.get_trained_models: + model_id: "missing-trained-model" +--- +"Test get given expression without matches and allow_no_match is false": + + - do: + catch: missing + ml.get_trained_models: + model_id: "missing-trained-model*" + allow_no_match: false + +--- +"Test get given expression without matches and allow_no_match is true": + + - do: + ml.get_trained_models: + model_id: "missing-trained-model*" + allow_no_match: true + - match: { count: 0 } + - match: { trained_model_configs: [] } +--- +"Test delete given unused trained model": + + - do: + index: + id: trained_model_config-unused-regression-model-0 + index: .ml-inference-000001 + body: > + { + "model_id": "unused-regression-model", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local" + } + - do: + indices.refresh: {} + + - do: + ml.delete_trained_model: + model_id: "unused-regression-model" + - match: { 
acknowledged: true } + +--- +"Test delete with missing model": + - do: + catch: missing + ml.delete_trained_model: + model_id: "missing-trained-model" + +--- +"Test delete given used trained model": + - do: + index: + id: trained_model_config-used-regression-model-0 + index: .ml-inference-000001 + body: > + { + "model_id": "used-regression-model", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local" + } + - do: + indices.refresh: {} + + - do: + ingest.put_pipeline: + id: "regression-model-pipeline" + body: > + { + "processors": [ + { + "inference" : { + "model_id" : "used-regression-model", + "inference_config": {"regression": {}}, + "target_field": "regression_field", + "field_mappings": {} + } + } + ] + } + - match: { acknowledged: true } + + - do: + catch: conflict + ml.delete_trained_model: + model_id: "used-regression-model" diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml new file mode 100644 index 00000000000..6062f651906 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml @@ -0,0 +1,230 @@ +setup: + - skip: + features: headers + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + index: + id: trained_model_config-unused-regression-model1-0 + index: .ml-inference-000001 + body: > + { + "model_id": "unused-regression-model1", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local", + "doc_type": "trained_model_config" + } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser + index: + id: trained_model_config-unused-regression-model-0 + index: .ml-inference-000001 + body: > + { + "model_id": "unused-regression-model", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local", + "doc_type": "trained_model_config" + } + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + index: + id: trained_model_config-used-regression-model-0 + index: .ml-inference-000001 + body: > + { + "model_id": "used-regression-model", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local", + "doc_type": "trained_model_config" + } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + indices.refresh: {} + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ingest.put_pipeline: + id: "regression-model-pipeline" + body: > + { + "processors": [ + { + "inference" : { + "model_id" : "used-regression-model", + "inference_config": {"regression": {}}, + "target_field": "regression_field", + "field_mappings": {} + } + } + ] + } + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser + ingest.put_pipeline: + id: "regression-model-pipeline-1" + body: > + { + "processors": [ + { + "inference" : { + "model_id" : "used-regression-model", + "inference_config": {"regression": {}}, + "target_field": "regression_field", + "field_mappings": {} + } + } + ] + } +--- +"Test get stats given missing trained model": + + - do: + catch: missing + ml.get_trained_models_stats: + model_id: "missing-trained-model" +--- +"Test get stats given expression without matches and allow_no_match is false": + + - do: + catch: missing + ml.get_trained_models_stats: + model_id: "missing-trained-model*" + allow_no_match: false + +--- +"Test get stats given expression without matches and allow_no_match is true": + + - do: + ml.get_trained_models_stats: + model_id: "missing-trained-model*" + allow_no_match: true + - match: { count: 0 } + - match: { trained_model_stats: [] } +--- +"Test get stats given trained models": + + - do: + ml.get_trained_models_stats: + model_id: "unused-regression-model" + + - match: { count: 1 } + + - do: + ml.get_trained_models_stats: + model_id: "_all" + - match: { count: 3 } + - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { trained_model_stats.0.pipeline_count: 0 } + - is_false: trained_model_stats.0.ingest + - match: { trained_model_stats.1.model_id: unused-regression-model1 } + - match: { trained_model_stats.1.pipeline_count: 0 } + - is_false: trained_model_stats.1.ingest + - match: { trained_model_stats.2.pipeline_count: 2 } + - is_true: trained_model_stats.2.ingest + + - do: + ml.get_trained_models_stats: + model_id: "*" + - match: { count: 3 } + - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { trained_model_stats.0.pipeline_count: 0 } + - is_false: trained_model_stats.0.ingest + - match: { trained_model_stats.1.model_id: unused-regression-model1 } + - match: { trained_model_stats.1.pipeline_count: 0 } + - is_false: trained_model_stats.1.ingest + - match: { 
trained_model_stats.2.pipeline_count: 2 } + - is_true: trained_model_stats.2.ingest + + - do: + ml.get_trained_models_stats: + model_id: "unused-regression-model*" + - match: { count: 2 } + - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { trained_model_stats.0.pipeline_count: 0 } + - is_false: trained_model_stats.0.ingest + - match: { trained_model_stats.1.model_id: unused-regression-model1 } + - match: { trained_model_stats.1.pipeline_count: 0 } + - is_false: trained_model_stats.1.ingest + + - do: + ml.get_trained_models_stats: + model_id: "unused-regression-model*" + size: 1 + - match: { count: 2 } + - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { trained_model_stats.0.pipeline_count: 0 } + - is_false: trained_model_stats.0.ingest + + - do: + ml.get_trained_models_stats: + model_id: "unused-regression-model*" + from: 1 + size: 1 + - match: { count: 2 } + - match: { trained_model_stats.0.model_id: unused-regression-model1 } + - match: { trained_model_stats.0.pipeline_count: 0 } + - is_false: trained_model_stats.0.ingest + + - do: + ml.get_trained_models_stats: + model_id: "used-regression-model" + + - match: { count: 1 } + - match: { trained_model_stats.0.model_id: used-regression-model } + - match: { trained_model_stats.0.pipeline_count: 2 } + - match: + trained_model_stats.0.ingest.total: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 + + - match: + trained_model_stats.0.ingest.pipelines.regression-model-pipeline: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 + processors: + - inference: + type: inference + stats: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 + + - match: + trained_model_stats.0.ingest.pipelines.regression-model-pipeline-1: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 + processors: + - inference: + type: inference + stats: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0