diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java index 70047744415..eb8d148ad1e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java @@ -27,7 +27,7 @@ import java.util.Objects; */ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class); public static final ParseField NAME = new ParseField("frequency_encoding"); public static final ParseField FIELD = new ParseField("field"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index cd92a6fe22a..f4f98af125b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -27,7 +27,7 @@ import java.util.stream.Collectors; */ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class); public static final ParseField NAME = new ParseField("one_hot_encoding"); public static final ParseField FIELD = new 
ParseField("field"); public static final ParseField HOT_MAP = new ParseField("hot_map"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index 3ae12e207f1..aa6495848a6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -27,7 +27,7 @@ import java.util.Objects; */ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class); public static final ParseField NAME = new ParseField("target_mean_encoding"); public static final ParseField FIELD = new ParseField("field"); public static final ParseField FEATURE_NAME = new ParseField("feature_name"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java index ccd6adbc9a3..d4995b4b24c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java @@ -25,7 +25,7 @@ import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { - private static final long SHALLOW_SIZE = 
RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class); public static final ParseField NAME = new ParseField("logistic_regression"); public static final ParseField WEIGHTS = new ParseField("weights"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java index 4c2ffc5dc25..29bd8e4579f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java @@ -47,7 +47,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.En public class EnsembleInferenceModel implements InferenceModel { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -275,4 +275,21 @@ public class EnsembleInferenceModel implements InferenceModel { size += outputAggregator.ramBytesUsed(); return size; } + + public List getModels() { + return models; + } + + public OutputAggregator getOutputAggregator() { + return outputAggregator; + } + + public TargetType getTargetType() { + return targetType; + } + + public double[] getClassificationWeights() { + return classificationWeights; + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java index d218854c355..111f0785f7a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java @@ -26,7 +26,7 @@ import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition.T public class InferenceDefinition { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InferenceDefinition.class); + public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InferenceDefinition.class); public static final String NAME = "inference_model_definition"; private final InferenceModel trainedModel; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index 7c8aed0d016..68bbbf84a20 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.Numbers; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -27,13 +26,15 @@ import org.elasticsearch.xpack.core.ml.inference.utils.Statistics; import org.elasticsearch.xpack.core.ml.job.config.Operator; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import 
java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.apache.lucene.util.RamUsageEstimator.sizeOf; +import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; @@ -53,7 +54,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo public class TreeInferenceModel implements InferenceModel { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TreeInferenceModel.class); + public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -75,7 +76,7 @@ public class TreeInferenceModel implements InferenceModel { private final Node[] nodes; private String[] featureNames; private final TargetType targetType; - private final List classificationLabels; + private List classificationLabels; private final double highOrderCategory; private final int maxDepth; private final int leafSize; @@ -302,18 +303,16 @@ public class TreeInferenceModel implements InferenceModel { treeNode.splitFeature = newSplitFeatureIndex; } this.featureNames = new String[0]; + // Since we are not top level, we no longer need local classification labels + this.classificationLabels = null; } @Override public long ramBytesUsed() { long size = SHALLOW_SIZE; - size += RamUsageEstimator.sizeOfCollection(classificationLabels); - size += RamUsageEstimator.sizeOf(featureNames); - size += 
RamUsageEstimator.shallowSizeOf(nodes); - for (Node node : nodes) { - size += node.ramBytesUsed(); - } - size += RamUsageEstimator.sizeOfCollection(Arrays.asList(nodes)); + size += sizeOfCollection(classificationLabels); + size += sizeOf(featureNames); + size += sizeOf(nodes); return size; } @@ -335,6 +334,10 @@ public class TreeInferenceModel implements InferenceModel { return max; } + public Node[] getNodes() { + return nodes; + } + private static int getDepth(Node[] nodes, int nodeIndex, int depth) { Node node = nodes[nodeIndex]; if (node instanceof LeafNode) { @@ -433,21 +436,21 @@ public class TreeInferenceModel implements InferenceModel { } } - private abstract static class Node implements Accountable { + public abstract static class Node implements Accountable { int compare(double[] features) { throw new IllegalArgumentException("cannot call compare against a leaf node."); } abstract long getNumberSamples(); - boolean isLeaf() { + public boolean isLeaf() { return this instanceof LeafNode; } } - private static class InnerNode extends Node { + public static class InnerNode extends Node { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InnerNode.class); + public static final long SHALLOW_SIZE = shallowSizeOfInstance(InnerNode.class); private final Operator operator; private final double threshold; @@ -498,8 +501,8 @@ public class TreeInferenceModel implements InferenceModel { } } - private static class LeafNode extends Node { - private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LeafNode.class); + public static class LeafNode extends Node { + public static final long SHALLOW_SIZE = shallowSizeOfInstance(LeafNode.class); private final double[] leafValue; private final long numberSamples; @@ -510,12 +513,16 @@ public class TreeInferenceModel implements InferenceModel { @Override public long ramBytesUsed() { - return SHALLOW_SIZE; + return SHALLOW_SIZE + sizeOf(leafValue); } @Override long 
getNumberSamples() { return numberSamples; } + + public double[] getLeafValue() { + return leafValue; + } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 6ddbcc7a1c5..9b1c2088aee 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -91,7 +91,7 @@ public class EnsembleTests extends AbstractSerializingTestCase { .toArray() : null; - return new Ensemble(randomBoolean() ? featureNames : Collections.emptyList(), + return new Ensemble(featureNames, models, outputAggregator, targetType, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java index d50980f0696..a337d7d2f98 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java @@ -28,6 +28,7 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; @@ -35,6 +36,13 @@ public class TreeInferenceModelTests extends ESTestCase { private final double eps = 1.0E-8; + public static TreeInferenceModel serializeFromTrainedModel(Tree tree) throws 
IOException { + NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + return deserializeFromTrainedModel(tree, + registry, + TreeInferenceModel::fromXContent); + } + @Override protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); @@ -48,7 +56,7 @@ public class TreeInferenceModelTests extends ESTestCase { builder.setFeatureNames(Collections.emptyList()); Tree treeObject = builder.build(); - TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); List featureNames = Arrays.asList("foo", "bar"); @@ -71,7 +79,7 @@ public class TreeInferenceModelTests extends ESTestCase { List featureNames = Arrays.asList("foo", "bar"); Tree treeObject = builder.setFeatureNames(featureNames).build(); - TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); // This feature vector should hit the right child of the root node @@ -126,7 +134,7 @@ public class TreeInferenceModelTests extends ESTestCase { List featureNames = Arrays.asList("foo", "bar"); Tree treeObject = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build(); - TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); double eps = 0.000001; @@ -200,7 +208,7 @@ public class TreeInferenceModelTests extends ESTestCase { TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L), TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build(); - TreeInferenceModel tree = 
InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 1372db30e5b..aca1dfc6de3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -219,6 +219,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimatio import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; @@ -1003,6 +1004,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); return namedXContent; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index 5b56ebfdc9a..0a05f13c118 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierD import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import java.io.IOException; import java.util.Collections; @@ -31,6 +32,7 @@ public class AnalyticsResult implements ToXContentObject { private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress"); private static final ParseField INFERENCE_MODEL = new ParseField("inference_model"); + private static final ParseField MODEL_SIZE_INFO = new ParseField("model_size_info"); private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage"); private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats"); private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats"); @@ -44,7 +46,8 @@ public class AnalyticsResult implements ToXContentObject { (MemoryUsage) a[3], (OutlierDetectionStats) a[4], (ClassificationStats) a[5], - (RegressionStats) a[6] + (RegressionStats) a[6], + (ModelSizeInfo) a[7] )); static { @@ -56,6 +59,7 @@ public class AnalyticsResult implements ToXContentObject { PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS); PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS); PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS); + PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, 
MODEL_SIZE_INFO); } private final RowResults rowResults; @@ -66,6 +70,7 @@ public class AnalyticsResult implements ToXContentObject { private final OutlierDetectionStats outlierDetectionStats; private final ClassificationStats classificationStats; private final RegressionStats regressionStats; + private final ModelSizeInfo modelSizeInfo; public AnalyticsResult(@Nullable RowResults rowResults, @Nullable PhaseProgress phaseProgress, @@ -73,7 +78,8 @@ public class AnalyticsResult implements ToXContentObject { @Nullable MemoryUsage memoryUsage, @Nullable OutlierDetectionStats outlierDetectionStats, @Nullable ClassificationStats classificationStats, - @Nullable RegressionStats regressionStats) { + @Nullable RegressionStats regressionStats, + @Nullable ModelSizeInfo modelSizeInfo) { this.rowResults = rowResults; this.phaseProgress = phaseProgress; this.inferenceModelBuilder = inferenceModelBuilder; @@ -82,6 +88,7 @@ public class AnalyticsResult implements ToXContentObject { this.outlierDetectionStats = outlierDetectionStats; this.classificationStats = classificationStats; this.regressionStats = regressionStats; + this.modelSizeInfo = modelSizeInfo; } public RowResults getRowResults() { @@ -112,6 +119,10 @@ public class AnalyticsResult implements ToXContentObject { return regressionStats; } + public ModelSizeInfo getModelSizeInfo() { + return modelSizeInfo; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -138,6 +149,9 @@ public class AnalyticsResult implements ToXContentObject { if (regressionStats != null) { builder.field(REGRESSION_STATS.getPreferredName(), regressionStats, params); } + if (modelSizeInfo != null) { + builder.field(MODEL_SIZE_INFO.getPreferredName(), modelSizeInfo); + } builder.endObject(); return builder; } @@ -158,6 +172,7 @@ public class AnalyticsResult implements ToXContentObject { && Objects.equals(memoryUsage, that.memoryUsage) && 
Objects.equals(outlierDetectionStats, that.outlierDetectionStats) && Objects.equals(classificationStats, that.classificationStats) + && Objects.equals(modelSizeInfo, that.modelSizeInfo) && Objects.equals(regressionStats, that.regressionStats); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/EnsembleSizeInfo.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/EnsembleSizeInfo.java new file mode 100644 index 00000000000..c6c2a2d818b --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/EnsembleSizeInfo.java @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModel; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfDoubleArray; +import static 
org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfStringCollection; + +public class EnsembleSizeInfo implements TrainedModelSizeInfo { + + public static final ParseField NAME = new ParseField("ensemble_model_size"); + private static final ParseField TREE_SIZES = new ParseField("tree_sizes"); + private static final ParseField INPUT_FIELD_NAME_LENGHTS = new ParseField("input_field_name_lengths"); + private static final ParseField NUM_OUTPUT_PROCESSOR_WEIGHTS = new ParseField("num_output_processor_weights"); + private static final ParseField NUM_CLASSIFICATION_WEIGHTS = new ParseField("num_classification_weights"); + private static final ParseField NUM_OPERATIONS = new ParseField("num_operations"); + private static final ParseField NUM_CLASSES = new ParseField("num_classes"); + + @SuppressWarnings("unchecked") + static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "ensemble_size", + false, + a -> new EnsembleSizeInfo((List)a[0], + (Integer)a[1], + (List)a[2], + a[3] == null ? 0 : (Integer)a[3], + a[4] == null ? 0 : (Integer)a[4], + a[5] == null ? 
0 : (Integer)a[5]) + ); + static { + PARSER.declareObjectArray(constructorArg(), TreeSizeInfo.PARSER::apply, TREE_SIZES); + PARSER.declareInt(constructorArg(), NUM_OPERATIONS); + PARSER.declareIntArray(constructorArg(), INPUT_FIELD_NAME_LENGHTS); + PARSER.declareInt(optionalConstructorArg(), NUM_OUTPUT_PROCESSOR_WEIGHTS); + PARSER.declareInt(optionalConstructorArg(), NUM_CLASSIFICATION_WEIGHTS); + PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES); + } + + public static EnsembleSizeInfo fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + + private final List treeSizeInfos; + private final int numOperations; + private final int[] inputFieldNameLengths; + private final int numOutputProcessorWeights; + private final int numClassificationWeights; + private final int numClasses; + + public EnsembleSizeInfo(List treeSizeInfos, + int numOperations, + List inputFieldNameLengths, + int numOutputProcessorWeights, + int numClassificationWeights, + int numClasses) { + this.treeSizeInfos = treeSizeInfos; + this.numOperations = numOperations; + this.inputFieldNameLengths = inputFieldNameLengths.stream().mapToInt(Integer::intValue).toArray(); + this.numOutputProcessorWeights = numOutputProcessorWeights; + this.numClassificationWeights = numClassificationWeights; + this.numClasses = numClasses; + } + + public int getNumOperations() { + return numOperations; + } + + @Override + public long ramBytesUsed() { + long size = EnsembleInferenceModel.SHALLOW_SIZE; + treeSizeInfos.forEach(t -> t.setNumClasses(numClasses).ramBytesUsed()); + size += sizeOfCollection(treeSizeInfos); + size += sizeOfStringCollection(inputFieldNameLengths); + size += LogisticRegression.SHALLOW_SIZE + sizeOfDoubleArray(numOutputProcessorWeights); + size += sizeOfDoubleArray(numClassificationWeights); + return alignObjectSize(size); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + 
builder.field(TREE_SIZES.getPreferredName(), treeSizeInfos); + builder.field(NUM_OPERATIONS.getPreferredName(), numOperations); + builder.field(NUM_CLASSES.getPreferredName(), numClasses); + builder.field(INPUT_FIELD_NAME_LENGHTS.getPreferredName(), inputFieldNameLengths); + builder.field(NUM_CLASSIFICATION_WEIGHTS.getPreferredName(), numClassificationWeights); + builder.field(NUM_OUTPUT_PROCESSOR_WEIGHTS.getPreferredName(), numOutputProcessorWeights); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EnsembleSizeInfo that = (EnsembleSizeInfo) o; + return numOperations == that.numOperations && + numOutputProcessorWeights == that.numOutputProcessorWeights && + numClassificationWeights == that.numClassificationWeights && + numClasses == that.numClasses && + Objects.equals(treeSizeInfos, that.treeSizeInfos) && + Arrays.equals(inputFieldNameLengths, that.inputFieldNameLengths); + } + + @Override + public int hashCode() { + int result = Objects.hash(treeSizeInfos, numOperations, numOutputProcessorWeights, numClassificationWeights, numClasses); + result = 31 * result + Arrays.hashCode(inputFieldNameLengths); + return result; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSize.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSize.java new file mode 100644 index 00000000000..f12fcc4733a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSize.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString; + +public class FrequencyEncodingSize implements PreprocessorSize { + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "frequency_encoding_size", + false, + a -> new FrequencyEncodingSize((Integer)a[0], (Integer)a[1], (List)a[2]) + ); + static { + PARSER.declareInt(constructorArg(), FIELD_LENGTH); + PARSER.declareInt(constructorArg(), FEATURE_NAME_LENGTH); + PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS); + } + + public static FrequencyEncodingSize fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final int fieldLength; + private final int featureNameLength; + private final int[] fieldValueLengths; + + FrequencyEncodingSize(int fieldLength, int featureNameLength, List fieldValueLengths) { + this.fieldLength = fieldLength; + this.featureNameLength = featureNameLength; + this.fieldValueLengths = 
fieldValueLengths.stream().mapToInt(Integer::intValue).toArray(); + } + + @Override + public long ramBytesUsed() { + final long sizeOfDoubleObject = shallowSizeOfInstance(Double.class); + long size = FrequencyEncoding.SHALLOW_SIZE; + size += sizeOfString(fieldLength); + size += sizeOfString(featureNameLength); + size += sizeOfHashMap( + Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()), + Stream.generate(() -> sizeOfDoubleObject).limit(fieldValueLengths.length).collect(Collectors.toList())); + return alignObjectSize(size); + } + + @Override + public String getName() { + return FrequencyEncoding.NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_LENGTH.getPreferredName(), fieldLength); + builder.field(FEATURE_NAME_LENGTH.getPreferredName(), featureNameLength); + builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FrequencyEncodingSize that = (FrequencyEncodingSize) o; + return fieldLength == that.fieldLength && + featureNameLength == that.featureNameLength && + Arrays.equals(fieldValueLengths, that.fieldValueLengths); + } + + @Override + public int hashCode() { + int result = Objects.hash(fieldLength, featureNameLength); + result = 31 * result + Arrays.hashCode(fieldValueLengths); + return result; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/MlModelSizeNamedXContentProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/MlModelSizeNamedXContentProvider.java new file mode 100644 index 00000000000..7274940a500 --- /dev/null +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/MlModelSizeNamedXContentProvider.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; + +import java.util.Arrays; +import java.util.List; + +public class MlModelSizeNamedXContentProvider implements NamedXContentProvider { + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + new NamedXContentRegistry.Entry(PreprocessorSize.class, + FrequencyEncoding.NAME, + FrequencyEncodingSize::fromXContent), + new NamedXContentRegistry.Entry(PreprocessorSize.class, + OneHotEncoding.NAME, + OneHotEncodingSize::fromXContent), + new NamedXContentRegistry.Entry(PreprocessorSize.class, + TargetMeanEncoding.NAME, + TargetMeanEncodingSize::fromXContent), + new NamedXContentRegistry.Entry(TrainedModelSizeInfo.class, + EnsembleSizeInfo.NAME, + EnsembleSizeInfo::fromXContent) + ); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfo.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfo.java new file mode 100644 index 00000000000..449fc28f71f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfo.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.apache.lucene.util.Accountable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class ModelSizeInfo implements Accountable, ToXContentObject { + + private static final ParseField PREPROCESSORS = new ParseField("preprocessors"); + private static final ParseField TRAINED_MODEL_SIZE = new ParseField("trained_model_size"); + + @SuppressWarnings("unchecked") + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "model_size", + false, + a -> new ModelSizeInfo((EnsembleSizeInfo)a[0], (List)a[1]) + ); + static { + PARSER.declareNamedObject(constructorArg(), + (p, c, n) -> p.namedObject(TrainedModelSizeInfo.class, n, null), + TRAINED_MODEL_SIZE); + PARSER.declareNamedObjects(optionalConstructorArg(), + (p, c, n) -> p.namedObject(PreprocessorSize.class, n, null), + (val) -> {}, + PREPROCESSORS); + } + + private final EnsembleSizeInfo ensembleSizeInfo; + private final List preprocessorSizes; + + public ModelSizeInfo(EnsembleSizeInfo 
ensembleSizeInfo, List preprocessorSizes) { + this.ensembleSizeInfo = ensembleSizeInfo; + this.preprocessorSizes = preprocessorSizes == null ? Collections.emptyList() : preprocessorSizes; + } + + public int numOperations() { + return this.preprocessorSizes.size() + this.ensembleSizeInfo.getNumOperations(); + } + + @Override + public long ramBytesUsed() { + long size = InferenceDefinition.SHALLOW_SIZE; + size += ensembleSizeInfo.ramBytesUsed(); + size += preprocessorSizes.stream().mapToLong(PreprocessorSize::ramBytesUsed).sum(); + return alignObjectSize(size); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + NamedXContentObjectHelper.writeNamedObject(builder, params, TRAINED_MODEL_SIZE.getPreferredName(), ensembleSizeInfo); + if (preprocessorSizes.size() > 0) { + NamedXContentObjectHelper.writeNamedObjects(builder, params, true, PREPROCESSORS.getPreferredName(), preprocessorSizes); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ModelSizeInfo modelSizeInfo = (ModelSizeInfo) o; + return Objects.equals(ensembleSizeInfo, modelSizeInfo.ensembleSizeInfo) && + Objects.equals(preprocessorSizes, modelSizeInfo.preprocessorSizes); + } + + @Override + public int hashCode() { + return Objects.hash(ensembleSizeInfo, preprocessorSizes); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSize.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSize.java new file mode 100644 index 00000000000..0255075c45b --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/OneHotEncodingSize.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString;

/**
 * Estimate of the memory a parsed {@link OneHotEncoding} preprocessor will occupy,
 * derived from the lengths of its field name, field values, and output feature names.
 */
public class OneHotEncodingSize implements PreprocessorSize {

    private static final ParseField FEATURE_NAME_LENGTHS = new ParseField("feature_name_lengths");

    // Type parameters restored; a[1] / a[2] casts to List<Integer> remain unchecked.
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<OneHotEncodingSize, Void> PARSER = new ConstructingObjectParser<>(
        "one_hot_encoding_size",
        false,
        a -> new OneHotEncodingSize((Integer) a[0], (List<Integer>) a[1], (List<Integer>) a[2])
    );
    static {
        PARSER.declareInt(constructorArg(), FIELD_LENGTH);
        PARSER.declareIntArray(constructorArg(), FEATURE_NAME_LENGTHS);
        PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS);
    }

    public static OneHotEncodingSize fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final int fieldLength;          // length of the source field name
    private final int[] featureNameLengths; // lengths of the hot-encoded feature names
    private final int[] fieldValueLengths;  // lengths of the distinct field values (map keys)

    /**
     * @param fieldLength        length of the encoded field's name
     * @param featureNameLengths lengths of the generated feature names; must pair 1:1
     *                           with {@code fieldValueLengths} (the hot map is value -> feature)
     * @param fieldValueLengths  lengths of the distinct field values
     */
    OneHotEncodingSize(int fieldLength, List<Integer> featureNameLengths, List<Integer> fieldValueLengths) {
        assert featureNameLengths.size() == fieldValueLengths.size();
        this.fieldLength = fieldLength;
        this.featureNameLengths = featureNameLengths.stream().mapToInt(Integer::intValue).toArray();
        this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Accounts for the encoding's shallow size, the field name string, and a HashMap of
     * field-value keys to feature-name values.
     */
    @Override
    public long ramBytesUsed() {
        long size = OneHotEncoding.SHALLOW_SIZE;
        size += sizeOfString(fieldLength);
        size += sizeOfHashMap(
            Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()),
            Arrays.stream(featureNameLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList())
        );
        return alignObjectSize(size);
    }

    @Override
    public String getName() {
        return OneHotEncoding.NAME.getPreferredName();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD_LENGTH.getPreferredName(), fieldLength);
        builder.field(FEATURE_NAME_LENGTHS.getPreferredName(), featureNameLengths);
        builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        OneHotEncodingSize that = (OneHotEncodingSize) o;
        return fieldLength == that.fieldLength
            && Arrays.equals(featureNameLengths, that.featureNameLengths)
            && Arrays.equals(fieldValueLengths, that.fieldValueLengths);
    }

    @Override
    public int hashCode() {
        int result = Objects.hash(fieldLength);
        result = 31 * result + Arrays.hashCode(featureNameLengths);
        result = 31 * result + Arrays.hashCode(fieldValueLengths);
        return result;
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;


/**
 * A named, accountable size estimate for a trained-model preprocessor.
 * Implementations report via {@link Accountable#ramBytesUsed()} the memory a parsed
 * preprocessor would occupy, computed from string-length statistics rather than the
 * actual preprocessor content.
 */
public interface PreprocessorSize extends Accountable, NamedXContentObject {
    // XContent field names shared by all preprocessor size estimates.
    ParseField FIELD_LENGTH = new ParseField("field_length");
    ParseField FEATURE_NAME_LENGTH = new ParseField("feature_name_length");
    ParseField FIELD_VALUE_LENGTHS = new ParseField("field_value_lengths");

    /** The preprocessor type name this estimate corresponds to. */
    String getName();
}

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
+ */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.stream.Stream; + +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; + +final class SizeEstimatorHelper { + + private SizeEstimatorHelper() {} + + private static final int STRING_SIZE = (int) shallowSizeOfInstance(String.class); + + static long sizeOfString(int stringLength) { + // Technically, each value counted in a String.length is 2 bytes. But, this is how `RamUsageEstimator` calculates it + return alignObjectSize(STRING_SIZE + (long)NUM_BYTES_ARRAY_HEADER + (long)(Character.BYTES) * stringLength); + } + + static long sizeOfStringCollection(int[] stringSizes) { + long shallow = alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) NUM_BYTES_OBJECT_REF * stringSizes.length); + return shallow + Arrays.stream(stringSizes).mapToLong(SizeEstimatorHelper::sizeOfString).sum(); + } + + static long sizeOfDoubleArray(int arrayLength) { + return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) Double.BYTES * arrayLength); + } + + static long sizeOfHashMap(List sizeOfKeys, List sizeOfValues) { + assert sizeOfKeys.size() == sizeOfValues.size(); + long mapsize = shallowSizeOfInstance(HashMap.class); + final long mapEntrySize = shallowSizeOfInstance(HashMap.Entry.class); + mapsize += Stream.concat(sizeOfKeys.stream(), sizeOfValues.stream()).mapToLong(Long::longValue).sum(); + mapsize += mapEntrySize * sizeOfKeys.size(); + mapsize += mapEntrySize * sizeOfValues.size(); + return mapsize; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSize.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSize.java new file mode 100644 index 00000000000..f5367a2a6fd --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/modelsize/TargetMeanEncodingSize.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString; + +public class TargetMeanEncodingSize implements PreprocessorSize { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "target_mean_encoding_size", + false, + a -> new TargetMeanEncodingSize((Integer)a[0], (Integer)a[1], (List)a[2]) + ); + static { + PARSER.declareInt(constructorArg(), FIELD_LENGTH); + PARSER.declareInt(constructorArg(), FEATURE_NAME_LENGTH); + PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS); + } 
+ + public static TargetMeanEncodingSize fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final int fieldLength; + private final int featureNameLength; + private final int[] fieldValueLengths; + + TargetMeanEncodingSize(int fieldLength, int featureNameLength, List fieldValueLengths) { + this.fieldLength = fieldLength; + this.featureNameLength = featureNameLength; + this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray(); + } + + @Override + public long ramBytesUsed() { + final long sizeOfDoubleObject = shallowSizeOfInstance(Double.class); + long size = TargetMeanEncoding.SHALLOW_SIZE; + size += sizeOfString(fieldLength); + size += sizeOfString(featureNameLength); + size += sizeOfHashMap( + Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()), + Stream.generate(() -> sizeOfDoubleObject).limit(fieldValueLengths.length).collect(Collectors.toList()) + ); + return alignObjectSize(size); + } + + @Override + public String getName() { + return TargetMeanEncoding.NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_LENGTH.getPreferredName(), fieldLength); + builder.field(FEATURE_NAME_LENGTH.getPreferredName(), featureNameLength); + builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TargetMeanEncodingSize that = (TargetMeanEncodingSize) o; + return fieldLength == that.fieldLength && + featureNameLength == that.featureNameLength && + Arrays.equals(fieldValueLengths, that.fieldValueLengths); + } + + @Override + public int hashCode() { + int result = Objects.hash(fieldLength, featureNameLength); + result 
= 31 * result + Arrays.hashCode(fieldValueLengths);
        return result;
    }
}

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

/**
 * Marker interface for a named, accountable size estimate of a trained model
 * (e.g. an ensemble), used to register model-level entries in the named
 * XContent registry separately from preprocessor estimates.
 */
interface TrainedModelSizeInfo extends Accountable, NamedXContentObject {
}

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
+ */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.apache.lucene.util.Accountable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel; + + +import java.io.IOException; +import java.util.Objects; + +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfDoubleArray; + +public class TreeSizeInfo implements Accountable, ToXContentObject { + + private static final ParseField NUM_NODES = new ParseField("num_nodes"); + private static final ParseField NUM_LEAVES = new ParseField("num_leaves"); + private static final ParseField NUM_CLASSES = new ParseField("num_classes"); + + static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tree_size", + false, + a -> new TreeSizeInfo((Integer)a[0], a[1] == null ? 0 : (Integer)a[1], a[2] == null ? 
0 : (Integer)a[2]) + ); + static { + PARSER.declareInt(constructorArg(), NUM_LEAVES); + PARSER.declareInt(optionalConstructorArg(), NUM_NODES); + PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES); + } + + public static TreeSizeInfo fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final int numNodes; + private final int numLeaves; + private int numClasses; + + TreeSizeInfo(int numLeaves, int numNodes, int numClasses) { + this.numLeaves = numLeaves; + this.numNodes = numNodes; + this.numClasses = numClasses; + } + + public TreeSizeInfo setNumClasses(int numClasses) { + this.numClasses = numClasses; + return this; + } + + @Override + public long ramBytesUsed() { + long size = TreeInferenceModel.SHALLOW_SIZE; + // Node shallow sizes, covers most information as elements are primitive + size += NUM_BYTES_ARRAY_HEADER + ((numLeaves + numNodes) * NUM_BYTES_OBJECT_REF); + size += numLeaves * TreeInferenceModel.LeafNode.SHALLOW_SIZE; + size += numNodes * TreeInferenceModel.InnerNode.SHALLOW_SIZE; + // This handles the values within the leaf value array + int numLeafVals = numClasses <= 2 ? 
1 : numClasses; + size += sizeOfDoubleArray(numLeafVals) * numLeaves; + return alignObjectSize(size); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NUM_LEAVES.getPreferredName(), numLeaves); + builder.field(NUM_NODES.getPreferredName(), numNodes); + builder.field(NUM_CLASSES.getPreferredName(), numClasses); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TreeSizeInfo treeSizeInfo = (TreeSizeInfo) o; + return numNodes == treeSizeInfo.numNodes && + numLeaves == treeSizeInfo.numLeaves && + numClasses == treeSizeInfo.numClasses; + } + + @Override + public int hashCode() { + return Objects.hash(numNodes, numLeaves, numClasses); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index d8994fe3d8f..025c42e7c0e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -59,7 +59,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase { private static final String CONFIG_ID = "config-id"; private static final int NUM_ROWS = 100; private static final int NUM_COLS = 4; - private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null); + private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null, null); private Client client; private DataFrameAnalyticsAuditor auditor; diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 530bf280aaf..9ae440fdeeb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -106,8 +106,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { public void testProcess_GivenEmptyResults() { givenDataFrameRows(2); givenProcessResults(Arrays.asList( - new AnalyticsResult(null, null,null, null, null, null, null), - new AnalyticsResult(null, null, null, null, null, null, null))); + new AnalyticsResult(null, null, null,null, null, null, null, null), + new AnalyticsResult(null, null, null, null, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -122,8 +122,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null,null, null, null, null, null), - new AnalyticsResult(rowResults2, null, null, null, null, null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null), + new AnalyticsResult(rowResults2, null, null, null, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -140,8 +140,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new 
AnalyticsResult(rowResults1, null,null, null, null, null, null), - new AnalyticsResult(rowResults2, null, null, null, null, null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null), + new AnalyticsResult(rowResults2, null, null, null, null, null, null, null))); doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class)); @@ -175,7 +175,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { extractedFieldList.add(new DocValueField("baz", Collections.emptySet())); TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); resultProcessor.process(process); @@ -239,7 +239,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? 
TargetType.REGRESSION : TargetType.CLASSIFICATION; TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index 50bdaba060a..2bf72318295 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -24,6 +24,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; import java.util.ArrayList; import java.util.Collections; @@ -36,10 +39,10 @@ public class AnalyticsResultTests extends AbstractXContentTestCase namedXContent = new ArrayList<>(); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); return new 
NamedXContentRegistry(namedXContent); } - @Override protected AnalyticsResult createTestInstance() { RowResults rowResults = null; PhaseProgress phaseProgress = null; @@ -48,6 +51,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase { + + static EnsembleSizeInfo createRandom() { + return new EnsembleSizeInfo( + Stream.generate(TreeSizeInfoTests::createRandom).limit(randomIntBetween(1, 100)).collect(Collectors.toList()), + randomIntBetween(1, 10000), + Stream.generate(() -> randomIntBetween(1, 10)).limit(randomIntBetween(1, 10)).collect(Collectors.toList()), + randomIntBetween(0, 10), + randomIntBetween(0, 10), + randomIntBetween(0, 10) + ); + } + + static EnsembleSizeInfo translateToEstimate(EnsembleInferenceModel ensemble) { + TreeInferenceModel tree = (TreeInferenceModel)ensemble.getModels().get(0); + int numClasses = Arrays.stream(tree.getNodes()) + .filter(TreeInferenceModel.Node::isLeaf) + .map(n -> (TreeInferenceModel.LeafNode)n) + .findFirst() + .get() + .getLeafValue() + .length; + return new EnsembleSizeInfo( + ensemble.getModels() + .stream() + .map(m -> TreeSizeInfoTests.translateToEstimate((TreeInferenceModel)m)) + .collect(Collectors.toList()), + randomIntBetween(0, 10), + Arrays.stream(ensemble.getFeatureNames()).map(String::length).collect(Collectors.toList()), + ensemble.getOutputAggregator().expectedValueSize() == null ? 0 : ensemble.getOutputAggregator().expectedValueSize(), + ensemble.getClassificationWeights() == null ? 
0 : ensemble.getClassificationWeights().length, + numClasses); + } + + @Override + protected EnsembleSizeInfo createTestInstance() { + return createRandom(); + } + + @Override + protected EnsembleSizeInfo doParseInstance(XContentParser parser) { + return EnsembleSizeInfo.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + EnsembleInferenceModel generateTrueObject() { + try { + Ensemble model = EnsembleTests.createRandom(); + EnsembleInferenceModel inferenceModel = EnsembleInferenceModelTests.serializeFromTrainedModel(model); + inferenceModel.rewriteFeatureIndices(Collections.emptyMap()); + return inferenceModel; + } catch (IOException ex) { + throw new ElasticsearchException(ex); + } + } + + @Override + EnsembleSizeInfo translateObject(EnsembleInferenceModel originalObject) { + return translateToEstimate(originalObject); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSizeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSizeTests.java new file mode 100644 index 00000000000..b49384111a3 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/FrequencyEncodingSizeTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; + +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class FrequencyEncodingSizeTests extends SizeEstimatorTestCase { + + static FrequencyEncodingSize createRandom() { + return new FrequencyEncodingSize(randomInt(100), + randomInt(100), + Stream.generate(() -> randomIntBetween(5, 10)) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toList())); + } + + static FrequencyEncodingSize translateToEstimate(FrequencyEncoding encoding) { + return new FrequencyEncodingSize(encoding.getField().length(), + encoding.getFeatureName().length(), + encoding.getFrequencyMap().keySet().stream().map(String::length).collect(Collectors.toList())); + } + + @Override + protected FrequencyEncodingSize createTestInstance() { + return createRandom(); + } + + @Override + protected FrequencyEncodingSize doParseInstance(XContentParser parser) { + return FrequencyEncodingSize.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + FrequencyEncoding generateTrueObject() { + return FrequencyEncodingTests.createRandom(); + } + + @Override + FrequencyEncodingSize translateObject(FrequencyEncoding originalObject) { + return translateToEstimate(originalObject); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfoTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfoTests.java new file mode 100644 index 00000000000..dfb500c72e7 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/ModelSizeInfoTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** XContent round-trip tests for the top-level {@link ModelSizeInfo} document. */
// Generic argument restored: raw `AbstractXContentTestCase` loses the typed contract
// between createTestInstance() and doParseInstance().
public class ModelSizeInfoTests extends AbstractXContentTestCase<ModelSizeInfo> {

    /** Random instance with an ensemble size and, half the time, a list of preprocessor sizes. */
    public static ModelSizeInfo createRandom() {
        return new ModelSizeInfo(EnsembleSizeInfoTests.createRandom(),
            randomBoolean() ?
                null :
                Stream.generate(() -> randomFrom(
                    FrequencyEncodingSizeTests.createRandom(),
                    OneHotEncodingSizeTests.createRandom(),
                    TargetMeanEncodingSizeTests.createRandom()))
                    .limit(randomIntBetween(1, 10))
                    .collect(Collectors.toList()));
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        // The preprocessor size entries are named XContent; register their parsers.
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
        return new NamedXContentRegistry(namedXContent);
    }

    @Override
    protected ModelSizeInfo createTestInstance() {
        return createRandom();
    }

    @Override
    protected ModelSizeInfo doParseInstance(XContentParser parser) {
        return ModelSizeInfo.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return false;
    }

    /** Verifies the documented wire format (as produced by the C++ side) parses cleanly. */
    public void testParseDescribedFormat() throws IOException {
        XContentParser parser = XContentHelper.createParser(xContentRegistry(),
            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
            new BytesArray(FORMAT),
            XContentType.JSON);
        // Shouldn't throw
        doParseInstance(parser);
    }

    // Canonical example of the model-size document; field names and values must stay in
    // sync with the parsers exercised above.
    private static final String FORMAT = "" +
        "{\n" +
        "    \"trained_model_size\": {\n" +
        "        \"ensemble_model_size\": {\n" +
        "            \"tree_sizes\": [\n" +
        "                {\"num_nodes\": 7, \"num_leaves\": 8},\n" +
        "                {\"num_nodes\": 3, \"num_leaves\": 4},\n" +
        "                {\"num_leaves\": 1}\n" +
        "            ],\n" +
        "            \"input_field_name_lengths\": [\n" +
        "                14,\n" +
        "                10,\n" +
        "                11\n" +
        "            ],\n" +
        "            \"num_output_processor_weights\": 3,\n" +
        "            \"num_classification_weights\": 0,\n" +
        "            \"num_classes\": 0,\n" +
        "            \"num_operations\": 3\n" +
        "        }\n" +
        "    },\n" +
        "    \"preprocessors\": [\n" +
        "        {\n" +
        "            \"one_hot_encoding\": {\n" +
        "                \"field_length\": 10,\n" +
        "                \"field_value_lengths\": [\n" +
        "                    10,\n" +
        "                    20\n" +
        "                ],\n" +
        "                \"feature_name_lengths\": [\n" +
        "                    15,\n" +
        "                    25\n" +
        "                ]\n" +
        "            }\n" +
        "        },\n" +
        "        {\n" +
        "            \"frequency_encoding\": {\n" +
        "                \"field_length\": 10,\n" +
        "                \"feature_name_length\": 5,\n" +
        "                \"field_value_lengths\": [\n" +
        "                    10,\n" +
        "                    20\n" +
        "                ]\n" +
        "            }\n" +
        "        },\n" +
        "        {\n" +
        "            \"target_mean_encoding\": {\n" +
        "                \"field_length\": 6,\n" +
        "                \"feature_name_length\": 15,\n" +
        "                \"field_value_lengths\": [\n" +
        "                    10,\n" +
        "                    20\n" +
        "                ]\n" +
        "            }\n" +
        "        }\n" +
        "    ]\n" +
        "}";
}
+ */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; + +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class OneHotEncodingSizeTests extends SizeEstimatorTestCase { + + static OneHotEncodingSize createRandom() { + int numFieldEntries = randomIntBetween(1, 10); + return new OneHotEncodingSize( + randomInt(100), + Stream.generate(() -> randomIntBetween(5, 10)) + .limit(numFieldEntries) + .collect(Collectors.toList()), + Stream.generate(() -> randomIntBetween(5, 10)) + .limit(numFieldEntries) + .collect(Collectors.toList())); + } + + static OneHotEncodingSize translateToEstimate(OneHotEncoding encoding) { + return new OneHotEncodingSize(encoding.getField().length(), + encoding.getHotMap().values().stream().map(String::length).collect(Collectors.toList()), + encoding.getHotMap().keySet().stream().map(String::length).collect(Collectors.toList())); + } + + @Override + protected OneHotEncodingSize createTestInstance() { + return createRandom(); + } + + @Override + protected OneHotEncodingSize doParseInstance(XContentParser parser) { + return OneHotEncodingSize.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + OneHotEncoding generateTrueObject() { + return OneHotEncodingTests.createRandom(); + } + + @Override + OneHotEncodingSize translateObject(OneHotEncoding originalObject) { + return translateToEstimate(originalObject); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/SizeEstimatorTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/modelsize/SizeEstimatorTestCase.java new file mode 100644 index 00000000000..764157ec465 --- /dev/null +++ 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.test.AbstractXContentTestCase;

import static org.hamcrest.Matchers.is;

/**
 * Base class for testing that a size estimate (T) accurately predicts the ram usage of
 * the true object (U) it describes.
 */
// Generic parameters restored: T is the XContent-serializable estimate, U the real
// Accountable object whose ramBytesUsed() the estimate must approximate.
public abstract class SizeEstimatorTestCase<T extends ToXContentObject & Accountable, U extends Accountable>
    extends AbstractXContentTestCase<T> {

    /** Creates a real object whose ram usage is the ground truth. */
    abstract U generateTrueObject();

    /** Derives the size estimate that describes {@code originalObject}. */
    abstract T translateObject(U originalObject);

    /**
     * The estimate must be within 2kb of the true ram usage, or err on the side of
     * overestimating — underestimating by more than the tolerance is a failure.
     */
    public void testRamUsageEstimationAccuracy() {
        final long bytesEps = new ByteSizeValue(2, ByteSizeUnit.KB).getBytes();
        for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
            U obj = generateTrueObject();
            T estimateObj = translateObject(obj);
            long originalBytesUsed = obj.ramBytesUsed();
            long estimateBytesUsed = estimateObj.ramBytesUsed();
            // If we are over by 2kb that is small enough to not be a concern
            boolean condition = (Math.abs(originalBytesUsed - estimateBytesUsed) < bytesEps) ||
                // If the difference is greater than 2kb, it is better to have overestimated.
                originalBytesUsed < estimateBytesUsed;
            assertThat("estimation difference greater than 2048 and the estimation is too small. Object [" +
                    obj.toString() +
                    "] actual [" +
                    originalBytesUsed +
                    "] translated object [" +
                    estimateObj +
                    "] estimated [" +
                    estimateBytesUsed +
                    "]",
                condition,
                is(true));
        }
    }
}

// ---- file: TargetMeanEncodingSizeTests.java ----

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;

import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Round-trip and ram-accuracy tests for {@link TargetMeanEncodingSize} against a real
 * {@link TargetMeanEncoding} preprocessor.
 */
// Generic arguments restored: the raw `extends SizeEstimatorTestCase` form loses the
// type-safe linkage between the size estimate and the preprocessor it describes.
public class TargetMeanEncodingSizeTests extends SizeEstimatorTestCase<TargetMeanEncodingSize, TargetMeanEncoding> {

    /** Builds a randomly populated size estimate. */
    static TargetMeanEncodingSize createRandom() {
        return new TargetMeanEncodingSize(randomInt(100),
            randomInt(100),
            Stream.generate(() -> randomIntBetween(5, 10))
                .limit(randomIntBetween(1, 10))
                .collect(Collectors.toList()));
    }

    /** Derives the size estimate that exactly describes the given encoding. */
    static TargetMeanEncodingSize translateToEstimate(TargetMeanEncoding encoding) {
        return new TargetMeanEncodingSize(encoding.getField().length(),
            encoding.getFeatureName().length(),
            // Only map keys contribute variable string storage; values are doubles.
            encoding.getMeanMap().keySet().stream().map(String::length).collect(Collectors.toList()));
    }

    @Override
    protected TargetMeanEncodingSize createTestInstance() {
        return createRandom();
    }

    @Override
    protected TargetMeanEncodingSize doParseInstance(XContentParser parser) {
        return TargetMeanEncodingSize.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return false;
    }

    @Override
    TargetMeanEncoding generateTrueObject() {
        return TargetMeanEncodingTests.createRandom();
    }

    @Override
    TargetMeanEncodingSize translateObject(TargetMeanEncoding originalObject) {
        return translateToEstimate(originalObject);
    }
}
+ */ + +package org.elasticsearch.xpack.ml.inference.modelsize; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModelTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; + +import java.io.IOException; +import java.util.Arrays; + + +public class TreeSizeInfoTests extends SizeEstimatorTestCase { + + static TreeSizeInfo createRandom() { + return new TreeSizeInfo(randomIntBetween(1, 100), randomIntBetween(0, 100), randomIntBetween(0, 10)); + } + + static TreeSizeInfo translateToEstimate(TreeInferenceModel tree) { + int numClasses = Arrays.stream(tree.getNodes()) + .filter(TreeInferenceModel.Node::isLeaf) + .map(n -> (TreeInferenceModel.LeafNode)n) + .findFirst() + .get() + .getLeafValue() + .length; + return new TreeSizeInfo((int)Arrays.stream(tree.getNodes()).filter(TreeInferenceModel.Node::isLeaf).count(), + (int)Arrays.stream(tree.getNodes()).filter(t -> t.isLeaf() == false).count(), + numClasses); + } + + @Override + protected TreeSizeInfo createTestInstance() { + return createRandom(); + } + + @Override + protected TreeSizeInfo doParseInstance(XContentParser parser) { + return TreeSizeInfo.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + TreeInferenceModel generateTrueObject() { + try { + return TreeInferenceModelTests.serializeFromTrainedModel( + TreeTests.buildRandomTree(Arrays.asList(randomAlphaOfLength(10), randomAlphaOfLength(10)), 6) + ); + } catch (IOException ex) { + throw new ElasticsearchException(ex); + } + } + + @Override + TreeSizeInfo translateObject(TreeInferenceModel originalObject) { + return translateToEstimate(originalObject); + } +}