From 2881995a450a1e07ec939f95aeb61a888b939aed Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 11 Jun 2020 15:59:23 -0400 Subject: [PATCH] [ML] adding new inference model size estimate handling from native process (#57930) (#57999) Adds support for reading in `model_size_info` objects. These objects contain numeric values indicating the model definition size and complexity. Additionally, these objects are not stored or serialized to any other node. They are to be used for calculating and storing model metadata. They are much smaller on heap than the true model definition and should help prevent the analytics process from using too much memory. Co-authored-by: Elastic Machine --- .../preprocessing/FrequencyEncoding.java | 2 +- .../preprocessing/OneHotEncoding.java | 2 +- .../preprocessing/TargetMeanEncoding.java | 2 +- .../ensemble/LogisticRegression.java | 2 +- .../inference/EnsembleInferenceModel.java | 19 ++- .../inference/InferenceDefinition.java | 2 +- .../inference/TreeInferenceModel.java | 43 +++--- .../trainedmodel/ensemble/EnsembleTests.java | 2 +- .../inference/TreeInferenceModelTests.java | 16 ++- .../xpack/ml/MachineLearning.java | 2 + .../process/results/AnalyticsResult.java | 19 ++- .../inference/modelsize/EnsembleSizeInfo.java | 136 ++++++++++++++++++ .../modelsize/FrequencyEncodingSize.java | 98 +++++++++++++ .../MlModelSizeNamedXContentProvider.java | 36 +++++ .../ml/inference/modelsize/ModelSizeInfo.java | 91 ++++++++++++ .../modelsize/OneHotEncodingSize.java | 100 +++++++++++++ .../inference/modelsize/PreprocessorSize.java | 20 +++ .../modelsize/SizeEstimatorHelper.java | 48 +++++++ .../modelsize/TargetMeanEncodingSize.java | 98 +++++++++++++ .../modelsize/TrainedModelSizeInfo.java | 13 ++ .../ml/inference/modelsize/TreeSizeInfo.java | 101 +++++++++++++ .../process/AnalyticsProcessManagerTests.java | 2 +- .../AnalyticsResultProcessorTests.java | 16 +-- .../process/results/AnalyticsResultTests.java | 11 +- .../modelsize/EnsembleSizeInfoTests.java | 88 ++++++++++++ .../modelsize/FrequencyEncodingSizeTests.java | 56 ++++++++ .../modelsize/ModelSizeInfoTests.java | 124 ++++++++++++++++ .../modelsize/OneHotEncodingSizeTests.java | 60 ++++++++ .../modelsize/SizeEstimatorTestCase.java | 48 +++++++ .../TargetMeanEncodingSizeTests.java | 56 ++++++++ .../modelsize/TreeSizeInfoTests.java | 68 +++++++++ 31 files changed, 1339 insertions(+), 42 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/EnsembleSizeInfo.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSize.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/MlModelSizeNamedXContentProvider.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfo.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSize.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/PreprocessorSize.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/SizeEstimatorHelper.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSize.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TrainedModelSizeInfo.java create mode 100644 
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TreeSizeInfo.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/EnsembleSizeInfoTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSizeTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfoTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSizeTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/SizeEstimatorTestCase.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSizeTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/TreeSizeInfoTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java index 70047744415..eb8d148ad1e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java @@ -27,7 +27,7 @@ import java.util.Objects; */ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class); public static final ParseField NAME = new ParseField("frequency_encoding"); public static final ParseField FIELD = new ParseField("field"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index cd92a6fe22a..f4f98af125b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -27,7 +27,7 @@ import java.util.stream.Collectors; */ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class); public static final ParseField NAME = new ParseField("one_hot_encoding"); public static final ParseField FIELD = new ParseField("field"); public static final ParseField HOT_MAP = new ParseField("hot_map"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index 3ae12e207f1..aa6495848a6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -27,7 +27,7 @@ import java.util.Objects; */ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class); public static final ParseField NAME = new ParseField("target_mean_encoding"); public static final ParseField FIELD = new ParseField("field"); public static final ParseField FEATURE_NAME = new ParseField("feature_name"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java index ccd6adbc9a3..d4995b4b24c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java @@ -25,7 +25,7 @@ import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class); public static final ParseField NAME = new ParseField("logistic_regression"); public static final ParseField WEIGHTS = new ParseField("weights"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java index 4c2ffc5dc25..29bd8e4579f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java @@ -47,7 +47,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.En public class EnsembleInferenceModel implements InferenceModel { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -275,4 +275,21 @@ public class EnsembleInferenceModel implements InferenceModel { size += outputAggregator.ramBytesUsed(); return size; } + + public List getModels() { + return models; + } + + public OutputAggregator getOutputAggregator() { + return outputAggregator; + } + + public TargetType getTargetType() { + return targetType; + } + + public double[] getClassificationWeights() { + return classificationWeights; + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java index d218854c355..111f0785f7a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java @@ -26,7 +26,7 @@ import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition.T public class InferenceDefinition { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InferenceDefinition.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InferenceDefinition.class); public static final String NAME = "inference_model_definition"; private final InferenceModel trainedModel; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index 7c8aed0d016..68bbbf84a20 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.Numbers; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -27,13 +26,15 @@ import org.elasticsearch.xpack.core.ml.inference.utils.Statistics; import org.elasticsearch.xpack.core.ml.job.config.Operator; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.apache.lucene.util.RamUsageEstimator.sizeOf; +import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; @@ -53,7 +54,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo public class TreeInferenceModel implements InferenceModel { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TreeInferenceModel.class); + public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -75,7 +76,7 @@ public class TreeInferenceModel implements InferenceModel { private final Node[] nodes; private String[] featureNames; private final TargetType targetType; - private final List classificationLabels; + private List classificationLabels; private final double highOrderCategory; private final int maxDepth; private final int leafSize; @@ -302,18 +303,16 @@ public class TreeInferenceModel 
implements InferenceModel { treeNode.splitFeature = newSplitFeatureIndex; } this.featureNames = new String[0]; + // Since we are not top level, we no longer need local classification labels + this.classificationLabels = null; } @Override public long ramBytesUsed() { long size = SHALLOW_SIZE; - size += RamUsageEstimator.sizeOfCollection(classificationLabels); - size += RamUsageEstimator.sizeOf(featureNames); - size += RamUsageEstimator.shallowSizeOf(nodes); - for (Node node : nodes) { - size += node.ramBytesUsed(); - } - size += RamUsageEstimator.sizeOfCollection(Arrays.asList(nodes)); + size += sizeOfCollection(classificationLabels); + size += sizeOf(featureNames); + size += sizeOf(nodes); return size; } @@ -335,6 +334,10 @@ public class TreeInferenceModel implements InferenceModel { return max; } + public Node[] getNodes() { + return nodes; + } + private static int getDepth(Node[] nodes, int nodeIndex, int depth) { Node node = nodes[nodeIndex]; if (node instanceof LeafNode) { @@ -433,21 +436,21 @@ public class TreeInferenceModel implements InferenceModel { } } - private abstract static class Node implements Accountable { + public abstract static class Node implements Accountable { int compare(double[] features) { throw new IllegalArgumentException("cannot call compare against a leaf node."); } abstract long getNumberSamples(); - boolean isLeaf() { + public boolean isLeaf() { return this instanceof LeafNode; } } - private static class InnerNode extends Node { + public static class InnerNode extends Node { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InnerNode.class); + public static final long SHALLOW_SIZE = shallowSizeOfInstance(InnerNode.class); private final Operator operator; private final double threshold; @@ -498,8 +501,8 @@ public class TreeInferenceModel implements InferenceModel { } } - private static class LeafNode extends Node { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LeafNode.class); + public static class LeafNode extends Node { + public static final long SHALLOW_SIZE = shallowSizeOfInstance(LeafNode.class); private final double[] leafValue; private final long numberSamples; @@ -510,12 +513,16 @@ public class TreeInferenceModel implements InferenceModel { @Override public long ramBytesUsed() { - return SHALLOW_SIZE; + return SHALLOW_SIZE + sizeOf(leafValue); } @Override long getNumberSamples() { return numberSamples; } + + public double[] getLeafValue() { + return leafValue; + } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 6ddbcc7a1c5..9b1c2088aee 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -91,7 +91,7 @@ public class EnsembleTests extends AbstractSerializingTestCase { .toArray() : null; - return new Ensemble(randomBoolean() ? 
featureNames : Collections.emptyList(), + return new Ensemble(featureNames, models, outputAggregator, targetType, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java index d50980f0696..a337d7d2f98 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java @@ -28,6 +28,7 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; @@ -35,6 +36,13 @@ public class TreeInferenceModelTests extends ESTestCase { private final double eps = 1.0E-8; + public static TreeInferenceModel serializeFromTrainedModel(Tree tree) throws IOException { + NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + return deserializeFromTrainedModel(tree, + registry, + TreeInferenceModel::fromXContent); + } + @Override protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); @@ -48,7 +56,7 @@ public class TreeInferenceModelTests extends ESTestCase { builder.setFeatureNames(Collections.emptyList()); Tree treeObject = builder.build(); - TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); List featureNames = Arrays.asList("foo", "bar"); @@ -71,7 +79,7 @@ public class TreeInferenceModelTests extends ESTestCase { List featureNames = Arrays.asList("foo", "bar"); Tree treeObject = builder.setFeatureNames(featureNames).build(); - TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); // This feature vector should hit the right child of the root node @@ -126,7 +134,7 @@ public class TreeInferenceModelTests extends ESTestCase { List featureNames = Arrays.asList("foo", "bar"); Tree treeObject = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build(); - TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); double eps = 0.000001; @@ -200,7 +208,7 @@ public class TreeInferenceModelTests extends ESTestCase { TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L), TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build(); - TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 
1372db30e5b..aca1dfc6de3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -219,6 +219,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimatio import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; @@ -1003,6 +1004,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); return namedXContent; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index 5b56ebfdc9a..0a05f13c118 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierD import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import java.io.IOException; import java.util.Collections; @@ -31,6 +32,7 @@ public class AnalyticsResult implements ToXContentObject { private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress"); private static final ParseField INFERENCE_MODEL = new ParseField("inference_model"); + private static final ParseField MODEL_SIZE_INFO = new ParseField("model_size_info"); private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage"); private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats"); private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats"); @@ -44,7 +46,8 @@ public class AnalyticsResult implements ToXContentObject { (MemoryUsage) a[3], (OutlierDetectionStats) a[4], (ClassificationStats) a[5], - (RegressionStats) a[6] + (RegressionStats) a[6], + (ModelSizeInfo) a[7] )); static { @@ -56,6 +59,7 @@ public class AnalyticsResult implements ToXContentObject { PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS); PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS); PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS); + PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO); 
} private final RowResults rowResults; @@ -66,6 +70,7 @@ public class AnalyticsResult implements ToXContentObject { private final OutlierDetectionStats outlierDetectionStats; private final ClassificationStats classificationStats; private final RegressionStats regressionStats; + private final ModelSizeInfo modelSizeInfo; public AnalyticsResult(@Nullable RowResults rowResults, @Nullable PhaseProgress phaseProgress, @@ -73,7 +78,8 @@ public class AnalyticsResult implements ToXContentObject { @Nullable MemoryUsage memoryUsage, @Nullable OutlierDetectionStats outlierDetectionStats, @Nullable ClassificationStats classificationStats, - @Nullable RegressionStats regressionStats) { + @Nullable RegressionStats regressionStats, + @Nullable ModelSizeInfo modelSizeInfo) { this.rowResults = rowResults; this.phaseProgress = phaseProgress; this.inferenceModelBuilder = inferenceModelBuilder; @@ -82,6 +88,7 @@ public class AnalyticsResult implements ToXContentObject { this.outlierDetectionStats = outlierDetectionStats; this.classificationStats = classificationStats; this.regressionStats = regressionStats; + this.modelSizeInfo = modelSizeInfo; } public RowResults getRowResults() { @@ -112,6 +119,10 @@ public class AnalyticsResult implements ToXContentObject { return regressionStats; } + public ModelSizeInfo getModelSizeInfo() { + return modelSizeInfo; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -138,6 +149,9 @@ public class AnalyticsResult implements ToXContentObject { if (regressionStats != null) { builder.field(REGRESSION_STATS.getPreferredName(), regressionStats, params); } + if (modelSizeInfo != null) { + builder.field(MODEL_SIZE_INFO.getPreferredName(), modelSizeInfo); + } builder.endObject(); return builder; } @@ -158,6 +172,7 @@ public class AnalyticsResult implements ToXContentObject { && Objects.equals(memoryUsage, that.memoryUsage) && Objects.equals(outlierDetectionStats, that.outlierDetectionStats) && Objects.equals(classificationStats, that.classificationStats) + && Objects.equals(modelSizeInfo, that.modelSizeInfo) && Objects.equals(regressionStats, that.regressionStats); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/EnsembleSizeInfo.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/EnsembleSizeInfo.java new file mode 100644 index 00000000000..c6c2a2d818b --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/EnsembleSizeInfo.java @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
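With the AnalyticsResult changes above, a result row from the native process can now carry an optional model_size_info object alongside (or instead of) the full inference model definition. A minimal sketch of the new wiring, assuming a ModelSizeInfo already parsed from that object (the helper name estimatedModelBytes is hypothetical):

    static long estimatedModelBytes(ModelSizeInfo sizeInfo) {
        // Every other AnalyticsResult field is @Nullable and left unset here;
        // the size info is the new, eighth constructor argument.
        AnalyticsResult result = new AnalyticsResult(null, null, null, null, null, null, null, sizeInfo);
        return result.getModelSizeInfo() == null ? 0L : result.getModelSizeInfo().ramBytesUsed();
    }
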
+ */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModel; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfDoubleArray; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfStringCollection; + +public class EnsembleSizeInfo implements TrainedModelSizeInfo { + + public static final ParseField NAME = new ParseField("ensemble_model_size"); + private static final ParseField TREE_SIZES = new ParseField("tree_sizes"); + private static final ParseField INPUT_FIELD_NAME_LENGHTS = new ParseField("input_field_name_lengths"); + private static final ParseField NUM_OUTPUT_PROCESSOR_WEIGHTS = new ParseField("num_output_processor_weights"); + private static final ParseField NUM_CLASSIFICATION_WEIGHTS = new ParseField("num_classification_weights"); + private static final ParseField NUM_OPERATIONS = new ParseField("num_operations"); + private static final ParseField NUM_CLASSES = new ParseField("num_classes"); + + @SuppressWarnings("unchecked") + static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "ensemble_size", + false, + a -> new EnsembleSizeInfo((List)a[0], + (Integer)a[1], + (List)a[2], + a[3] == null ? 0 : (Integer)a[3], + a[4] == null ? 0 : (Integer)a[4], + a[5] == null ? 
0 : (Integer)a[5]) + ); + static { + PARSER.declareObjectArray(constructorArg(), TreeSizeInfo.PARSER::apply, TREE_SIZES); + PARSER.declareInt(constructorArg(), NUM_OPERATIONS); + PARSER.declareIntArray(constructorArg(), INPUT_FIELD_NAME_LENGHTS); + PARSER.declareInt(optionalConstructorArg(), NUM_OUTPUT_PROCESSOR_WEIGHTS); + PARSER.declareInt(optionalConstructorArg(), NUM_CLASSIFICATION_WEIGHTS); + PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES); + } + + public static EnsembleSizeInfo fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + + private final List treeSizeInfos; + private final int numOperations; + private final int[] inputFieldNameLengths; + private final int numOutputProcessorWeights; + private final int numClassificationWeights; + private final int numClasses; + + public EnsembleSizeInfo(List treeSizeInfos, + int numOperations, + List inputFieldNameLengths, + int numOutputProcessorWeights, + int numClassificationWeights, + int numClasses) { + this.treeSizeInfos = treeSizeInfos; + this.numOperations = numOperations; + this.inputFieldNameLengths = inputFieldNameLengths.stream().mapToInt(Integer::intValue).toArray(); + this.numOutputProcessorWeights = numOutputProcessorWeights; + this.numClassificationWeights = numClassificationWeights; + this.numClasses = numClasses; + } + + public int getNumOperations() { + return numOperations; + } + + @Override + public long ramBytesUsed() { + long size = EnsembleInferenceModel.SHALLOW_SIZE; + treeSizeInfos.forEach(t -> t.setNumClasses(numClasses).ramBytesUsed()); + size += sizeOfCollection(treeSizeInfos); + size += sizeOfStringCollection(inputFieldNameLengths); + size += LogisticRegression.SHALLOW_SIZE + sizeOfDoubleArray(numOutputProcessorWeights); + size += sizeOfDoubleArray(numClassificationWeights); + return alignObjectSize(size); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TREE_SIZES.getPreferredName(), treeSizeInfos); + builder.field(NUM_OPERATIONS.getPreferredName(), numOperations); + builder.field(NUM_CLASSES.getPreferredName(), numClasses); + builder.field(INPUT_FIELD_NAME_LENGHTS.getPreferredName(), inputFieldNameLengths); + builder.field(NUM_CLASSIFICATION_WEIGHTS.getPreferredName(), numClassificationWeights); + builder.field(NUM_OUTPUT_PROCESSOR_WEIGHTS.getPreferredName(), numOutputProcessorWeights); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EnsembleSizeInfo that = (EnsembleSizeInfo) o; + return numOperations == that.numOperations && + numOutputProcessorWeights == that.numOutputProcessorWeights && + numClassificationWeights == that.numClassificationWeights && + numClasses == that.numClasses && + Objects.equals(treeSizeInfos, that.treeSizeInfos) && + Arrays.equals(inputFieldNameLengths, that.inputFieldNameLengths); + } + + @Override + public int hashCode() { + int result = Objects.hash(treeSizeInfos, numOperations, numOutputProcessorWeights, numClassificationWeights, numClasses); + result = 31 * result + Arrays.hashCode(inputFieldNameLengths); + return result; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSize.java 
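EnsembleSizeInfo above never holds the trees or field names themselves, only counts and string lengths, and its ramBytesUsed() composes the now-public SHALLOW_SIZE constants with those counts. A minimal usage sketch from inside the modelsize package (all numbers are made up):

    // 100 trees, each with 31 split nodes and 32 leaves, three classes, and
    // two input fields whose names are 11 and 17 characters long.
    List<TreeSizeInfo> trees = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        trees.add(new TreeSizeInfo(32, 31, 3));
    }
    EnsembleSizeInfo ensembleSize = new EnsembleSizeInfo(
        trees,
        100,                    // num_operations
        Arrays.asList(11, 17),  // input_field_name_lengths
        100,                    // num_output_processor_weights
        3,                      // num_classification_weights
        3);                     // num_classes
    long estimatedBytes = ensembleSize.ramBytesUsed();
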
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSize.java new file mode 100644 index 00000000000..f12fcc4733a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSize.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString; + +public class FrequencyEncodingSize implements PreprocessorSize { + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "frequency_encoding_size", + false, + a -> new FrequencyEncodingSize((Integer)a[0], (Integer)a[1], (List)a[2]) + ); + static { + PARSER.declareInt(constructorArg(), FIELD_LENGTH); + PARSER.declareInt(constructorArg(), FEATURE_NAME_LENGTH); + PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS); + } + + public static FrequencyEncodingSize fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final int fieldLength; + private final int featureNameLength; + private final int[] fieldValueLengths; + + FrequencyEncodingSize(int fieldLength, int featureNameLength, List fieldValueLengths) { + this.fieldLength = fieldLength; + this.featureNameLength = featureNameLength; + this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray(); + } + + @Override + public long ramBytesUsed() { + final long sizeOfDoubleObject = shallowSizeOfInstance(Double.class); + long size = FrequencyEncoding.SHALLOW_SIZE; + size += sizeOfString(fieldLength); + size += sizeOfString(featureNameLength); + size += sizeOfHashMap( + Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()), + Stream.generate(() -> sizeOfDoubleObject).limit(fieldValueLengths.length).collect(Collectors.toList())); + return alignObjectSize(size); + } + + @Override + public String getName() { + return FrequencyEncoding.NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_LENGTH.getPreferredName(), fieldLength); + builder.field(FEATURE_NAME_LENGTH.getPreferredName(), featureNameLength); + builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths); + builder.endObject(); + return builder; + } + + @Override + public boolean 
equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FrequencyEncodingSize that = (FrequencyEncodingSize) o; + return fieldLength == that.fieldLength && + featureNameLength == that.featureNameLength && + Arrays.equals(fieldValueLengths, that.fieldValueLengths); + } + + @Override + public int hashCode() { + int result = Objects.hash(fieldLength, featureNameLength); + result = 31 * result + Arrays.hashCode(fieldValueLengths); + return result; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/MlModelSizeNamedXContentProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/MlModelSizeNamedXContentProvider.java new file mode 100644 index 00000000000..7274940a500 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/MlModelSizeNamedXContentProvider.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; + +import java.util.Arrays; +import java.util.List; + +public class MlModelSizeNamedXContentProvider implements NamedXContentProvider { + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + new NamedXContentRegistry.Entry(PreprocessorSize.class, + FrequencyEncoding.NAME, + FrequencyEncodingSize::fromXContent), + new NamedXContentRegistry.Entry(PreprocessorSize.class, + OneHotEncoding.NAME, + OneHotEncodingSize::fromXContent), + new NamedXContentRegistry.Entry(PreprocessorSize.class, + TargetMeanEncoding.NAME, + TargetMeanEncodingSize::fromXContent), + new NamedXContentRegistry.Entry(TrainedModelSizeInfo.class, + EnsembleSizeInfo.NAME, + EnsembleSizeInfo::fromXContent) + ); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfo.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfo.java new file mode 100644 index 00000000000..449fc28f71f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfo.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
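Because the new size objects are NamedXContentObjects, they can only be parsed through a registry that knows their names; MlModelSizeNamedXContentProvider above supplies those entries, and the MachineLearning#getNamedXContent hunk earlier in this patch registers them on the node. A sketch of building such a registry by hand, e.g. for a test:

    List<NamedXContentRegistry.Entry> entries = new ArrayList<>();
    entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    entries.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
    NamedXContentRegistry registry = new NamedXContentRegistry(entries);
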
+ */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.apache.lucene.util.Accountable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class ModelSizeInfo implements Accountable, ToXContentObject { + + private static final ParseField PREPROCESSORS = new ParseField("preprocessors"); + private static final ParseField TRAINED_MODEL_SIZE = new ParseField("trained_model_size"); + + @SuppressWarnings("unchecked") + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "model_size", + false, + a -> new ModelSizeInfo((EnsembleSizeInfo)a[0], (List)a[1]) + ); + static { + PARSER.declareNamedObject(constructorArg(), + (p, c, n) -> p.namedObject(TrainedModelSizeInfo.class, n, null), + TRAINED_MODEL_SIZE); + PARSER.declareNamedObjects(optionalConstructorArg(), + (p, c, n) -> p.namedObject(PreprocessorSize.class, n, null), + (val) -> {}, + PREPROCESSORS); + } + + private final EnsembleSizeInfo ensembleSizeInfo; + private final List preprocessorSizes; + + public ModelSizeInfo(EnsembleSizeInfo ensembleSizeInfo, List preprocessorSizes) { + this.ensembleSizeInfo = ensembleSizeInfo; + this.preprocessorSizes = preprocessorSizes == null ? 
Collections.emptyList() : preprocessorSizes; + } + + public int numOperations() { + return this.preprocessorSizes.size() + this.ensembleSizeInfo.getNumOperations(); + } + + @Override + public long ramBytesUsed() { + long size = InferenceDefinition.SHALLOW_SIZE; + size += ensembleSizeInfo.ramBytesUsed(); + size += preprocessorSizes.stream().mapToLong(PreprocessorSize::ramBytesUsed).sum(); + return alignObjectSize(size); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + NamedXContentObjectHelper.writeNamedObject(builder, params, TRAINED_MODEL_SIZE.getPreferredName(), ensembleSizeInfo); + if (preprocessorSizes.size() > 0) { + NamedXContentObjectHelper.writeNamedObjects(builder, params, true, PREPROCESSORS.getPreferredName(), preprocessorSizes); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ModelSizeInfo modelSizeInfo = (ModelSizeInfo) o; + return Objects.equals(ensembleSizeInfo, modelSizeInfo.ensembleSizeInfo) && + Objects.equals(preprocessorSizes, modelSizeInfo.preprocessorSizes); + } + + @Override + public int hashCode() { + return Objects.hash(ensembleSizeInfo, preprocessorSizes); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSize.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSize.java new file mode 100644 index 00000000000..0255075c45b --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSize.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
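ModelSizeInfo above is the top-level shape the native process emits under model_size_info: a required trained_model_size wrapping one named size object (currently ensemble_model_size) plus an optional preprocessors array of named preprocessor sizes. A hedged parsing sketch using the registry built above; the json payload is illustrative with made-up numbers, but its field names come from the ParseFields declared in this patch:

    String json =
        "{\"trained_model_size\": {\"ensemble_model_size\": {"
        + "   \"tree_sizes\": [{\"num_nodes\": 7, \"num_leaves\": 8}],"
        + "   \"num_operations\": 1, \"input_field_name_lengths\": [14],"
        + "   \"num_output_processor_weights\": 1, \"num_classification_weights\": 0, \"num_classes\": 0}},"
        + " \"preprocessors\": [{\"one_hot_encoding\": {"
        + "   \"field_length\": 14, \"feature_name_lengths\": [18, 20], \"field_value_lengths\": [10, 12]}}]}";
    try (XContentParser parser = XContentFactory.xContent(XContentType.JSON)
             .createParser(registry, LoggingDeprecationHandler.INSTANCE, json)) {
        ModelSizeInfo sizeInfo = ModelSizeInfo.PARSER.apply(parser, null);
        long estimatedDefinitionBytes = sizeInfo.ramBytesUsed(); // heap estimate for the eventual model definition
        int inferenceOperations = sizeInfo.numOperations();      // preprocessors + ensemble operations
    }
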
+ */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString; + +public class OneHotEncodingSize implements PreprocessorSize { + + private static final ParseField FEATURE_NAME_LENGTHS = new ParseField("feature_name_lengths"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "one_hot_encoding_size", + false, + a -> new OneHotEncodingSize((Integer)a[0], (List)a[1], (List)a[2]) + ); + static { + PARSER.declareInt(constructorArg(), FIELD_LENGTH); + PARSER.declareIntArray(constructorArg(), FEATURE_NAME_LENGTHS); + PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS); + } + + public static OneHotEncodingSize fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final int fieldLength; + private final int[] featureNameLengths; + private final int[] fieldValueLengths; + + OneHotEncodingSize(int fieldLength, List featureNameLengths, List fieldValueLengths) { + assert featureNameLengths.size() == fieldValueLengths.size(); + this.fieldLength = fieldLength; + this.featureNameLengths = featureNameLengths.stream().mapToInt(Integer::intValue).toArray(); + this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray(); + } + + @Override + public long ramBytesUsed() { + long size = OneHotEncoding.SHALLOW_SIZE; + size += sizeOfString(fieldLength); + size += sizeOfHashMap( + Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()), + Arrays.stream(featureNameLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()) + ); + return alignObjectSize(size); + } + + @Override + public String getName() { + return OneHotEncoding.NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_LENGTH.getPreferredName(), fieldLength); + builder.field(FEATURE_NAME_LENGTHS.getPreferredName(), featureNameLengths); + builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OneHotEncodingSize that = (OneHotEncodingSize) o; + return fieldLength == that.fieldLength && + Arrays.equals(featureNameLengths, that.featureNameLengths) && + Arrays.equals(fieldValueLengths, that.fieldValueLengths); + } + + @Override + public int hashCode() { + int result = Objects.hash(fieldLength); + result = 31 * result + Arrays.hashCode(featureNameLengths); + result = 31 * result + 
Arrays.hashCode(fieldValueLengths); + return result; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/PreprocessorSize.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/PreprocessorSize.java new file mode 100644 index 00000000000..dfbe10a0464 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/PreprocessorSize.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.apache.lucene.util.Accountable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + + +public interface PreprocessorSize extends Accountable, NamedXContentObject { + ParseField FIELD_LENGTH = new ParseField("field_length"); + ParseField FEATURE_NAME_LENGTH = new ParseField("feature_name_length"); + ParseField FIELD_VALUE_LENGTHS = new ParseField("field_value_lengths"); + + String getName(); +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/SizeEstimatorHelper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/SizeEstimatorHelper.java new file mode 100644 index 00000000000..0449d3c6523 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/SizeEstimatorHelper.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.stream.Stream; + +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; + +final class SizeEstimatorHelper { + + private SizeEstimatorHelper() {} + + private static final int STRING_SIZE = (int) shallowSizeOfInstance(String.class); + + static long sizeOfString(int stringLength) { + // Technically, each value counted in a String.length is 2 bytes. 
But, this is how `RamUsageEstimator` calculates it + return alignObjectSize(STRING_SIZE + (long)NUM_BYTES_ARRAY_HEADER + (long)(Character.BYTES) * stringLength); + } + + static long sizeOfStringCollection(int[] stringSizes) { + long shallow = alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) NUM_BYTES_OBJECT_REF * stringSizes.length); + return shallow + Arrays.stream(stringSizes).mapToLong(SizeEstimatorHelper::sizeOfString).sum(); + } + + static long sizeOfDoubleArray(int arrayLength) { + return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) Double.BYTES * arrayLength); + } + + static long sizeOfHashMap(List sizeOfKeys, List sizeOfValues) { + assert sizeOfKeys.size() == sizeOfValues.size(); + long mapsize = shallowSizeOfInstance(HashMap.class); + final long mapEntrySize = shallowSizeOfInstance(HashMap.Entry.class); + mapsize += Stream.concat(sizeOfKeys.stream(), sizeOfValues.stream()).mapToLong(Long::longValue).sum(); + mapsize += mapEntrySize * sizeOfKeys.size(); + mapsize += mapEntrySize * sizeOfValues.size(); + return mapsize; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSize.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSize.java new file mode 100644 index 00000000000..f5367a2a6fd --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSize.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString; + +public class TargetMeanEncodingSize implements PreprocessorSize { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "target_mean_encoding_size", + false, + a -> new TargetMeanEncodingSize((Integer)a[0], (Integer)a[1], (List)a[2]) + ); + static { + PARSER.declareInt(constructorArg(), FIELD_LENGTH); + PARSER.declareInt(constructorArg(), FEATURE_NAME_LENGTH); + PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS); + } + + public static TargetMeanEncodingSize fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final int fieldLength; + private final int featureNameLength; + private final int[] fieldValueLengths; + + TargetMeanEncodingSize(int fieldLength, int featureNameLength, List fieldValueLengths) { + this.fieldLength = fieldLength; + 
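SizeEstimatorHelper above rebuilds the usual RamUsageEstimator accounting from lengths and counts alone. As a worked composition from inside the modelsize package (lengths are made up): the frequency-encoding estimate for a 9-character field name, a 14-character feature name, and three observed values of lengths 3, 5 and 7 is assembled as below, which is what new FrequencyEncodingSize(9, 14, Arrays.asList(3, 5, 7)).ramBytesUsed() computes, up to the final alignObjectSize rounding:

    long size = FrequencyEncoding.SHALLOW_SIZE                        // fixed per-instance overhead
        + SizeEstimatorHelper.sizeOfString(9)                         // the field name
        + SizeEstimatorHelper.sizeOfString(14)                        // the feature name
        + SizeEstimatorHelper.sizeOfHashMap(
            Arrays.asList(SizeEstimatorHelper.sizeOfString(3),        // map keys: the observed field values
                          SizeEstimatorHelper.sizeOfString(5),
                          SizeEstimatorHelper.sizeOfString(7)),
            Collections.nCopies(3, RamUsageEstimator.shallowSizeOfInstance(Double.class))); // boxed frequencies
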
this.featureNameLength = featureNameLength; + this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray(); + } + + @Override + public long ramBytesUsed() { + final long sizeOfDoubleObject = shallowSizeOfInstance(Double.class); + long size = TargetMeanEncoding.SHALLOW_SIZE; + size += sizeOfString(fieldLength); + size += sizeOfString(featureNameLength); + size += sizeOfHashMap( + Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()), + Stream.generate(() -> sizeOfDoubleObject).limit(fieldValueLengths.length).collect(Collectors.toList()) + ); + return alignObjectSize(size); + } + + @Override + public String getName() { + return TargetMeanEncoding.NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_LENGTH.getPreferredName(), fieldLength); + builder.field(FEATURE_NAME_LENGTH.getPreferredName(), featureNameLength); + builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TargetMeanEncodingSize that = (TargetMeanEncodingSize) o; + return fieldLength == that.fieldLength && + featureNameLength == that.featureNameLength && + Arrays.equals(fieldValueLengths, that.fieldValueLengths); + } + + @Override + public int hashCode() { + int result = Objects.hash(fieldLength, featureNameLength); + result = 31 * result + Arrays.hashCode(fieldValueLengths); + return result; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TrainedModelSizeInfo.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TrainedModelSizeInfo.java new file mode 100644 index 00000000000..3aca41846c2 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TrainedModelSizeInfo.java @@ -0,0 +1,13 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.apache.lucene.util.Accountable; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + +interface TrainedModelSizeInfo extends Accountable, NamedXContentObject { +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TreeSizeInfo.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TreeSizeInfo.java new file mode 100644 index 00000000000..f61fe989654 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TreeSizeInfo.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.apache.lucene.util.Accountable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel; + + +import java.io.IOException; +import java.util.Objects; + +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfDoubleArray; + +public class TreeSizeInfo implements Accountable, ToXContentObject { + + private static final ParseField NUM_NODES = new ParseField("num_nodes"); + private static final ParseField NUM_LEAVES = new ParseField("num_leaves"); + private static final ParseField NUM_CLASSES = new ParseField("num_classes"); + + static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tree_size", + false, + a -> new TreeSizeInfo((Integer)a[0], a[1] == null ? 0 : (Integer)a[1], a[2] == null ? 0 : (Integer)a[2]) + ); + static { + PARSER.declareInt(constructorArg(), NUM_LEAVES); + PARSER.declareInt(optionalConstructorArg(), NUM_NODES); + PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES); + } + + public static TreeSizeInfo fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final int numNodes; + private final int numLeaves; + private int numClasses; + + TreeSizeInfo(int numLeaves, int numNodes, int numClasses) { + this.numLeaves = numLeaves; + this.numNodes = numNodes; + this.numClasses = numClasses; + } + + public TreeSizeInfo setNumClasses(int numClasses) { + this.numClasses = numClasses; + return this; + } + + @Override + public long ramBytesUsed() { + long size = TreeInferenceModel.SHALLOW_SIZE; + // Node shallow sizes, covers most information as elements are primitive + size += NUM_BYTES_ARRAY_HEADER + ((numLeaves + numNodes) * NUM_BYTES_OBJECT_REF); + size += numLeaves * TreeInferenceModel.LeafNode.SHALLOW_SIZE; + size += numNodes * TreeInferenceModel.InnerNode.SHALLOW_SIZE; + // This handles the values within the leaf value array + int numLeafVals = numClasses <= 2 ? 
1 : numClasses; + size += sizeOfDoubleArray(numLeafVals) * numLeaves; + return alignObjectSize(size); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NUM_LEAVES.getPreferredName(), numLeaves); + builder.field(NUM_NODES.getPreferredName(), numNodes); + builder.field(NUM_CLASSES.getPreferredName(), numClasses); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TreeSizeInfo treeSizeInfo = (TreeSizeInfo) o; + return numNodes == treeSizeInfo.numNodes && + numLeaves == treeSizeInfo.numLeaves && + numClasses == treeSizeInfo.numClasses; + } + + @Override + public int hashCode() { + return Objects.hash(numNodes, numLeaves, numClasses); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index d8994fe3d8f..025c42e7c0e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -59,7 +59,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase { private static final String CONFIG_ID = "config-id"; private static final int NUM_ROWS = 100; private static final int NUM_COLS = 4; - private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null); + private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null, null); private Client client; private DataFrameAnalyticsAuditor auditor; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 530bf280aaf..9ae440fdeeb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -106,8 +106,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { public void testProcess_GivenEmptyResults() { givenDataFrameRows(2); givenProcessResults(Arrays.asList( - new AnalyticsResult(null, null,null, null, null, null, null), - new AnalyticsResult(null, null, null, null, null, null, null))); + new AnalyticsResult(null, null, null,null, null, null, null, null), + new AnalyticsResult(null, null, null, null, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -122,8 +122,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null,null, null, null, null, null), - new AnalyticsResult(rowResults2, null, null, null, null, null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null), + new AnalyticsResult(rowResults2, null, null, null, null, null, 
null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -140,8 +140,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null,null, null, null, null, null), - new AnalyticsResult(rowResults2, null, null, null, null, null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null), + new AnalyticsResult(rowResults2, null, null, null, null, null, null, null))); doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class)); @@ -175,7 +175,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { extractedFieldList.add(new DocValueField("baz", Collections.emptySet())); TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); resultProcessor.process(process); @@ -239,7 +239,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index 50bdaba060a..2bf72318295 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -24,6 +24,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; import java.util.ArrayList; import java.util.Collections; @@ -36,10 +39,10 @@ public class AnalyticsResultTests extends AbstractXContentTestCase namedXContent = new ArrayList<>(); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, 
false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(namedXContent); } - @Override protected AnalyticsResult createTestInstance() { RowResults rowResults = null; PhaseProgress phaseProgress = null; @@ -48,6 +51,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase { + + static EnsembleSizeInfo createRandom() { + return new EnsembleSizeInfo( + Stream.generate(TreeSizeInfoTests::createRandom).limit(randomIntBetween(1, 100)).collect(Collectors.toList()), + randomIntBetween(1, 10000), + Stream.generate(() -> randomIntBetween(1, 10)).limit(randomIntBetween(1, 10)).collect(Collectors.toList()), + randomIntBetween(0, 10), + randomIntBetween(0, 10), + randomIntBetween(0, 10) + ); + } + + static EnsembleSizeInfo translateToEstimate(EnsembleInferenceModel ensemble) { + TreeInferenceModel tree = (TreeInferenceModel)ensemble.getModels().get(0); + int numClasses = Arrays.stream(tree.getNodes()) + .filter(TreeInferenceModel.Node::isLeaf) + .map(n -> (TreeInferenceModel.LeafNode)n) + .findFirst() + .get() + .getLeafValue() + .length; + return new EnsembleSizeInfo( + ensemble.getModels() + .stream() + .map(m -> TreeSizeInfoTests.translateToEstimate((TreeInferenceModel)m)) + .collect(Collectors.toList()), + randomIntBetween(0, 10), + Arrays.stream(ensemble.getFeatureNames()).map(String::length).collect(Collectors.toList()), + ensemble.getOutputAggregator().expectedValueSize() == null ? 0 : ensemble.getOutputAggregator().expectedValueSize(), + ensemble.getClassificationWeights() == null ? 0 : ensemble.getClassificationWeights().length, + numClasses); + } + + @Override + protected EnsembleSizeInfo createTestInstance() { + return createRandom(); + } + + @Override + protected EnsembleSizeInfo doParseInstance(XContentParser parser) { + return EnsembleSizeInfo.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + EnsembleInferenceModel generateTrueObject() { + try { + Ensemble model = EnsembleTests.createRandom(); + EnsembleInferenceModel inferenceModel = EnsembleInferenceModelTests.serializeFromTrainedModel(model); + inferenceModel.rewriteFeatureIndices(Collections.emptyMap()); + return inferenceModel; + } catch (IOException ex) { + throw new ElasticsearchException(ex); + } + } + + @Override + EnsembleSizeInfo translateObject(EnsembleInferenceModel originalObject) { + return translateToEstimate(originalObject); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSizeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSizeTests.java new file mode 100644 index 00000000000..b49384111a3 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSizeTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
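To make the EnsembleSizeInfo constructor calls in EnsembleSizeInfoTests above concrete, here is a hypothetical worked example using the same argument order as createRandom() and translateToEstimate(). The per-argument labels in the comments are inferred from those calls and from the documented JSON fields, so treat them as assumptions rather than API documentation; the fragment assumes the same package and imports as the test.

// Hypothetical: a small two-tree regression ensemble over features "price" (5 chars) and "qty" (3 chars).
EnsembleSizeInfo exampleSizeInfo = new EnsembleSizeInfo(
    Arrays.asList(
        new TreeSizeInfo(8, 7, 0),   // 8 leaves, 7 inner nodes, num_classes left at 0
        new TreeSizeInfo(4, 3, 0)),  // 4 leaves, 3 inner nodes
    2,                               // num_operations
    Arrays.asList(5, 3),             // input_field_name_lengths
    2,                               // num_output_processor_weights
    0,                               // num_classification_weights
    0);                              // num_classes
long estimatedEnsembleBytes = exampleSizeInfo.ramBytesUsed();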
+ */
+
+package org.elasticsearch.xpack.ml.inference.modelsize;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
+
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class FrequencyEncodingSizeTests extends SizeEstimatorTestCase<FrequencyEncodingSize, FrequencyEncoding> {
+
+    static FrequencyEncodingSize createRandom() {
+        return new FrequencyEncodingSize(randomInt(100),
+            randomInt(100),
+            Stream.generate(() -> randomIntBetween(5, 10))
+                .limit(randomIntBetween(1, 10))
+                .collect(Collectors.toList()));
+    }
+
+    static FrequencyEncodingSize translateToEstimate(FrequencyEncoding encoding) {
+        return new FrequencyEncodingSize(encoding.getField().length(),
+            encoding.getFeatureName().length(),
+            encoding.getFrequencyMap().keySet().stream().map(String::length).collect(Collectors.toList()));
+    }
+
+    @Override
+    protected FrequencyEncodingSize createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected FrequencyEncodingSize doParseInstance(XContentParser parser) {
+        return FrequencyEncodingSize.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+
+    @Override
+    FrequencyEncoding generateTrueObject() {
+        return FrequencyEncodingTests.createRandom();
+    }
+
+    @Override
+    FrequencyEncodingSize translateObject(FrequencyEncoding originalObject) {
+        return translateToEstimate(originalObject);
+    }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfoTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfoTests.java
new file mode 100644
index 00000000000..dfb500c72e7
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfoTests.java
@@ -0,0 +1,124 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.modelsize;
+
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class ModelSizeInfoTests extends AbstractXContentTestCase<ModelSizeInfo> {
+
+    public static ModelSizeInfo createRandom() {
+        return new ModelSizeInfo(EnsembleSizeInfoTests.createRandom(),
+            randomBoolean() ?
+ null : + Stream.generate(() -> randomFrom( + FrequencyEncodingSizeTests.createRandom(), + OneHotEncodingSizeTests.createRandom(), + TargetMeanEncodingSizeTests.createRandom())) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toList())); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected ModelSizeInfo createTestInstance() { + return createRandom(); + } + + @Override + protected ModelSizeInfo doParseInstance(XContentParser parser) { + return ModelSizeInfo.PARSER.apply(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + public void testParseDescribedFormat() throws IOException { + XContentParser parser = XContentHelper.createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(FORMAT), + XContentType.JSON); + // Shouldn't throw + doParseInstance(parser); + } + + private static final String FORMAT = "" + + "{\n" + + " \"trained_model_size\": {\n" + + " \"ensemble_model_size\": {\n" + + " \"tree_sizes\": [\n" + + " {\"num_nodes\": 7, \"num_leaves\": 8},\n" + + " {\"num_nodes\": 3, \"num_leaves\": 4},\n" + + " {\"num_leaves\": 1}\n" + + " ],\n" + + " \"input_field_name_lengths\": [\n" + + " 14,\n" + + " 10,\n" + + " 11\n" + + " ],\n" + + " \"num_output_processor_weights\": 3,\n" + + " \"num_classification_weights\": 0,\n" + + " \"num_classes\": 0,\n" + + " \"num_operations\": 3\n" + + " }\n" + + " },\n" + + " \"preprocessors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field_length\": 10,\n" + + " \"field_value_lengths\": [\n" + + " 10,\n" + + " 20\n" + + " ],\n" + + " \"feature_name_lengths\": [\n" + + " 15,\n" + + " 25\n" + + " ]\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field_length\": 10,\n" + + " \"feature_name_length\": 5,\n" + + " \"field_value_lengths\": [\n" + + " 10,\n" + + " 20\n" + + " ]\n" + + " }\n" + + " },\n" + + " {\n" + + " \"target_mean_encoding\": {\n" + + " \"field_length\": 6,\n" + + " \"feature_name_length\": 15,\n" + + " \"field_value_lengths\": [\n" + + " 10,\n" + + " 20\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + + "} "; +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSizeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSizeTests.java new file mode 100644 index 00000000000..3aba41c3525 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSizeTests.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
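The FORMAT constant in ModelSizeInfoTests above doubles as documentation of the model_size_info payload emitted by the native process. A minimal sketch of parsing such a payload outside the test, using the entries registered by MlModelSizeNamedXContentProvider (same imports as ModelSizeInfoTests; error handling omitted; assumes ModelSizeInfo.PARSER is visible from the caller's package, as it is for the test):

static ModelSizeInfo parseModelSizeInfo(String json) throws IOException {
    NamedXContentRegistry registry = new NamedXContentRegistry(
        new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
    try (XContentParser parser = XContentHelper.createParser(
            registry,
            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
            new BytesArray(json),
            XContentType.JSON)) {
        // The parsed object is what feeds the model metadata calculations described in the commit message.
        return ModelSizeInfo.PARSER.apply(parser, null);
    }
}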
+ */
+
+package org.elasticsearch.xpack.ml.inference.modelsize;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
+
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class OneHotEncodingSizeTests extends SizeEstimatorTestCase<OneHotEncodingSize, OneHotEncoding> {
+
+    static OneHotEncodingSize createRandom() {
+        int numFieldEntries = randomIntBetween(1, 10);
+        return new OneHotEncodingSize(
+            randomInt(100),
+            Stream.generate(() -> randomIntBetween(5, 10))
+                .limit(numFieldEntries)
+                .collect(Collectors.toList()),
+            Stream.generate(() -> randomIntBetween(5, 10))
+                .limit(numFieldEntries)
+                .collect(Collectors.toList()));
+    }
+
+    static OneHotEncodingSize translateToEstimate(OneHotEncoding encoding) {
+        return new OneHotEncodingSize(encoding.getField().length(),
+            encoding.getHotMap().values().stream().map(String::length).collect(Collectors.toList()),
+            encoding.getHotMap().keySet().stream().map(String::length).collect(Collectors.toList()));
+    }
+
+    @Override
+    protected OneHotEncodingSize createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected OneHotEncodingSize doParseInstance(XContentParser parser) {
+        return OneHotEncodingSize.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+
+    @Override
+    OneHotEncoding generateTrueObject() {
+        return OneHotEncodingTests.createRandom();
+    }
+
+    @Override
+    OneHotEncodingSize translateObject(OneHotEncoding originalObject) {
+        return translateToEstimate(originalObject);
+    }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/SizeEstimatorTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/SizeEstimatorTestCase.java
new file mode 100644
index 00000000000..764157ec465
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/SizeEstimatorTestCase.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.modelsize;
+
+import org.apache.lucene.util.Accountable;
+import org.elasticsearch.common.unit.ByteSizeUnit;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import static org.hamcrest.Matchers.is;
+
+public abstract class SizeEstimatorTestCase<T extends ToXContentObject & Accountable, U extends Accountable>
+    extends AbstractXContentTestCase<T> {
+
+    abstract U generateTrueObject();
+
+    abstract T translateObject(U originalObject);
+
+    public void testRamUsageEstimationAccuracy() {
+        final long bytesEps = new ByteSizeValue(2, ByteSizeUnit.KB).getBytes();
+        for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
+            U obj = generateTrueObject();
+            T estimateObj = translateObject(obj);
+            long originalBytesUsed = obj.ramBytesUsed();
+            long estimateBytesUsed = estimateObj.ramBytesUsed();
+            // If we are over by 2kb that is small enough to not be a concern
+            boolean condition = (Math.abs(obj.ramBytesUsed() - estimateObj.ramBytesUsed()) < bytesEps) ||
+                // If the difference is greater than 2kb, it is better to have overestimated.
+                originalBytesUsed < estimateBytesUsed;
+            assertThat("estimation difference greater than 2048 and the estimation is too small. Object [" +
+                obj.toString() +
+                "] estimated [" +
+                originalBytesUsed +
+                "] translated object [" +
+                estimateObj +
+                "] estimated [" +
+                estimateBytesUsed +
+                "]" ,
+                condition,
+                is(true));
+        }
+    }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSizeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSizeTests.java
new file mode 100644
index 00000000000..fede92ce51c
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSizeTests.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.modelsize;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
+
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class TargetMeanEncodingSizeTests extends SizeEstimatorTestCase<TargetMeanEncodingSize, TargetMeanEncoding> {
+
+    static TargetMeanEncodingSize createRandom() {
+        return new TargetMeanEncodingSize(randomInt(100),
+            randomInt(100),
+            Stream.generate(() -> randomIntBetween(5, 10))
+                .limit(randomIntBetween(1, 10))
+                .collect(Collectors.toList()));
+    }
+
+    static TargetMeanEncodingSize translateToEstimate(TargetMeanEncoding encoding) {
+        return new TargetMeanEncodingSize(encoding.getField().length(),
+            encoding.getFeatureName().length(),
+            encoding.getMeanMap().keySet().stream().map(String::length).collect(Collectors.toList()));
+    }
+
+    @Override
+    protected TargetMeanEncodingSize createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected TargetMeanEncodingSize doParseInstance(XContentParser parser) {
+        return TargetMeanEncodingSize.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+
+    @Override
+    TargetMeanEncoding generateTrueObject() {
+        return TargetMeanEncodingTests.createRandom();
+    }
+
+    @Override
+    TargetMeanEncodingSize translateObject(TargetMeanEncoding originalObject) {
+        return translateToEstimate(originalObject);
+    }
+}
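Written out with concrete numbers, the acceptance rule used by testRamUsageEstimationAccuracy above is the following; this sketch only restates the test's logic.

// Pass if the estimate is within 2 KB of the measured value, or if it overestimates.
static boolean acceptable(long actualBytes, long estimatedBytes) {
    long toleranceBytes = new ByteSizeValue(2, ByteSizeUnit.KB).getBytes(); // 2048
    return Math.abs(actualBytes - estimatedBytes) < toleranceBytes   // close enough
        || actualBytes < estimatedBytes;                              // overestimating is acceptable
}

// acceptable(10_000,  8_500) -> true   (off by 1_500, under the 2_048 tolerance)
// acceptable(10_000, 14_000) -> true   (overestimate)
// acceptable(10_000,  7_000) -> false  (underestimates by 3_000)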
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/TreeSizeInfoTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/TreeSizeInfoTests.java
new file mode 100644
index 00000000000..cc0b2831b3c
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/TreeSizeInfoTests.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.modelsize;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModelTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+
+public class TreeSizeInfoTests extends SizeEstimatorTestCase<TreeSizeInfo, TreeInferenceModel> {
+
+    static TreeSizeInfo createRandom() {
+        return new TreeSizeInfo(randomIntBetween(1, 100), randomIntBetween(0, 100), randomIntBetween(0, 10));
+    }
+
+    static TreeSizeInfo translateToEstimate(TreeInferenceModel tree) {
+        int numClasses = Arrays.stream(tree.getNodes())
+            .filter(TreeInferenceModel.Node::isLeaf)
+            .map(n -> (TreeInferenceModel.LeafNode)n)
+            .findFirst()
+            .get()
+            .getLeafValue()
+            .length;
+        return new TreeSizeInfo((int)Arrays.stream(tree.getNodes()).filter(TreeInferenceModel.Node::isLeaf).count(),
+            (int)Arrays.stream(tree.getNodes()).filter(t -> t.isLeaf() == false).count(),
+            numClasses);
+    }
+
+    @Override
+    protected TreeSizeInfo createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected TreeSizeInfo doParseInstance(XContentParser parser) {
+        return TreeSizeInfo.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+
+    @Override
+    TreeInferenceModel generateTrueObject() {
+        try {
+            return TreeInferenceModelTests.serializeFromTrainedModel(
+                TreeTests.buildRandomTree(Arrays.asList(randomAlphaOfLength(10), randomAlphaOfLength(10)), 6)
+            );
+        } catch (IOException ex) {
+            throw new ElasticsearchException(ex);
+        }
+    }
+
+    @Override
+    TreeSizeInfo translateObject(TreeInferenceModel originalObject) {
+        return translateToEstimate(originalObject);
+    }
+}
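As a worked restatement of TreeSizeInfo.ramBytesUsed() with concrete values, take the first tree in the documented payload ({"num_nodes": 7, "num_leaves": 8}) with num_classes at most 2, i.e. a single value per leaf. The fragment below assumes the same static imports as TreeSizeInfo.java and adds nothing beyond plugging those numbers into the formula shown earlier.

long size = TreeInferenceModel.SHALLOW_SIZE;                       // the tree object itself
size += NUM_BYTES_ARRAY_HEADER + (8 + 7) * NUM_BYTES_OBJECT_REF;   // node array: 15 references
size += 8 * TreeInferenceModel.LeafNode.SHALLOW_SIZE;              // leaf nodes
size += 7 * TreeInferenceModel.InnerNode.SHALLOW_SIZE;             // inner nodes
size += 8 * SizeEstimatorHelper.sizeOfDoubleArray(1);              // one leaf value per leaf
long treeEstimate = alignObjectSize(size);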