[ML] adding new inference model size estimate handling from native process (#57930) (#57999)

Adds support for reading in `model_size_info` objects.

These objects contain numeric values indicating the model definition size and complexity.

Additionally, these objects are not stored or serialized to any other node. They are to be used for calculating and storing model metadata. They are much smaller on heap than the true model definition and should help prevent the analytics process from using too much memory.

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Benjamin Trent 2020-06-11 15:59:23 -04:00 committed by GitHub
parent ffc3c77f75
commit 2881995a45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 1339 additions and 42 deletions

View File

@ -27,7 +27,7 @@ import java.util.Objects;
*/ */
public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class); public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class);
public static final ParseField NAME = new ParseField("frequency_encoding"); public static final ParseField NAME = new ParseField("frequency_encoding");
public static final ParseField FIELD = new ParseField("field"); public static final ParseField FIELD = new ParseField("field");

View File

@ -27,7 +27,7 @@ import java.util.stream.Collectors;
*/ */
public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class); public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class);
public static final ParseField NAME = new ParseField("one_hot_encoding"); public static final ParseField NAME = new ParseField("one_hot_encoding");
public static final ParseField FIELD = new ParseField("field"); public static final ParseField FIELD = new ParseField("field");
public static final ParseField HOT_MAP = new ParseField("hot_map"); public static final ParseField HOT_MAP = new ParseField("hot_map");

View File

@ -27,7 +27,7 @@ import java.util.Objects;
*/ */
public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class); public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class);
public static final ParseField NAME = new ParseField("target_mean_encoding"); public static final ParseField NAME = new ParseField("target_mean_encoding");
public static final ParseField FIELD = new ParseField("field"); public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name"); public static final ParseField FEATURE_NAME = new ParseField("feature_name");

View File

@ -25,7 +25,7 @@ import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax
public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class); public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class);
public static final ParseField NAME = new ParseField("logistic_regression"); public static final ParseField NAME = new ParseField("logistic_regression");
public static final ParseField WEIGHTS = new ParseField("weights"); public static final ParseField WEIGHTS = new ParseField("weights");

View File

@ -47,7 +47,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.En
public class EnsembleInferenceModel implements InferenceModel { public class EnsembleInferenceModel implements InferenceModel {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class); public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>( private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
@ -275,4 +275,21 @@ public class EnsembleInferenceModel implements InferenceModel {
size += outputAggregator.ramBytesUsed(); size += outputAggregator.ramBytesUsed();
return size; return size;
} }
public List<InferenceModel> getModels() {
return models;
}
public OutputAggregator getOutputAggregator() {
return outputAggregator;
}
public TargetType getTargetType() {
return targetType;
}
public double[] getClassificationWeights() {
return classificationWeights;
}
} }

View File

@ -26,7 +26,7 @@ import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition.T
public class InferenceDefinition { public class InferenceDefinition {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InferenceDefinition.class); public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InferenceDefinition.class);
public static final String NAME = "inference_model_definition"; public static final String NAME = "inference_model_definition";
private final InferenceModel trainedModel; private final InferenceModel trainedModel;

View File

@ -7,7 +7,6 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Numbers; import org.elasticsearch.common.Numbers;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -27,13 +26,15 @@ import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.job.config.Operator; import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
import static org.apache.lucene.util.RamUsageEstimator.sizeOf;
import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
@ -53,7 +54,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo
public class TreeInferenceModel implements InferenceModel { public class TreeInferenceModel implements InferenceModel {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TreeInferenceModel.class); public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>( private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
@ -75,7 +76,7 @@ public class TreeInferenceModel implements InferenceModel {
private final Node[] nodes; private final Node[] nodes;
private String[] featureNames; private String[] featureNames;
private final TargetType targetType; private final TargetType targetType;
private final List<String> classificationLabels; private List<String> classificationLabels;
private final double highOrderCategory; private final double highOrderCategory;
private final int maxDepth; private final int maxDepth;
private final int leafSize; private final int leafSize;
@ -302,18 +303,16 @@ public class TreeInferenceModel implements InferenceModel {
treeNode.splitFeature = newSplitFeatureIndex; treeNode.splitFeature = newSplitFeatureIndex;
} }
this.featureNames = new String[0]; this.featureNames = new String[0];
// Since we are not top level, we no longer need local classification labels
this.classificationLabels = null;
} }
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
long size = SHALLOW_SIZE; long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOfCollection(classificationLabels); size += sizeOfCollection(classificationLabels);
size += RamUsageEstimator.sizeOf(featureNames); size += sizeOf(featureNames);
size += RamUsageEstimator.shallowSizeOf(nodes); size += sizeOf(nodes);
for (Node node : nodes) {
size += node.ramBytesUsed();
}
size += RamUsageEstimator.sizeOfCollection(Arrays.asList(nodes));
return size; return size;
} }
@ -335,6 +334,10 @@ public class TreeInferenceModel implements InferenceModel {
return max; return max;
} }
public Node[] getNodes() {
return nodes;
}
private static int getDepth(Node[] nodes, int nodeIndex, int depth) { private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
Node node = nodes[nodeIndex]; Node node = nodes[nodeIndex];
if (node instanceof LeafNode) { if (node instanceof LeafNode) {
@ -433,21 +436,21 @@ public class TreeInferenceModel implements InferenceModel {
} }
} }
private abstract static class Node implements Accountable { public abstract static class Node implements Accountable {
int compare(double[] features) { int compare(double[] features) {
throw new IllegalArgumentException("cannot call compare against a leaf node."); throw new IllegalArgumentException("cannot call compare against a leaf node.");
} }
abstract long getNumberSamples(); abstract long getNumberSamples();
boolean isLeaf() { public boolean isLeaf() {
return this instanceof LeafNode; return this instanceof LeafNode;
} }
} }
private static class InnerNode extends Node { public static class InnerNode extends Node {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InnerNode.class); public static final long SHALLOW_SIZE = shallowSizeOfInstance(InnerNode.class);
private final Operator operator; private final Operator operator;
private final double threshold; private final double threshold;
@ -498,8 +501,8 @@ public class TreeInferenceModel implements InferenceModel {
} }
} }
private static class LeafNode extends Node { public static class LeafNode extends Node {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LeafNode.class); public static final long SHALLOW_SIZE = shallowSizeOfInstance(LeafNode.class);
private final double[] leafValue; private final double[] leafValue;
private final long numberSamples; private final long numberSamples;
@ -510,12 +513,16 @@ public class TreeInferenceModel implements InferenceModel {
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
return SHALLOW_SIZE; return SHALLOW_SIZE + sizeOf(leafValue);
} }
@Override @Override
long getNumberSamples() { long getNumberSamples() {
return numberSamples; return numberSamples;
} }
public double[] getLeafValue() {
return leafValue;
}
} }
} }

View File

@ -91,7 +91,7 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
.toArray() : .toArray() :
null; null;
return new Ensemble(randomBoolean() ? featureNames : Collections.emptyList(), return new Ensemble(featureNames,
models, models,
outputAggregator, outputAggregator,
targetType, targetType,

View File

@ -28,6 +28,7 @@ import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
@ -35,6 +36,13 @@ public class TreeInferenceModelTests extends ESTestCase {
private final double eps = 1.0E-8; private final double eps = 1.0E-8;
public static TreeInferenceModel serializeFromTrainedModel(Tree tree) throws IOException {
NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return deserializeFromTrainedModel(tree,
registry,
TreeInferenceModel::fromXContent);
}
@Override @Override
protected NamedXContentRegistry xContentRegistry() { protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>(); List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
@ -48,7 +56,7 @@ public class TreeInferenceModelTests extends ESTestCase {
builder.setFeatureNames(Collections.emptyList()); builder.setFeatureNames(Collections.emptyList());
Tree treeObject = builder.build(); Tree treeObject = builder.build();
TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
xContentRegistry(), xContentRegistry(),
TreeInferenceModel::fromXContent); TreeInferenceModel::fromXContent);
List<String> featureNames = Arrays.asList("foo", "bar"); List<String> featureNames = Arrays.asList("foo", "bar");
@ -71,7 +79,7 @@ public class TreeInferenceModelTests extends ESTestCase {
List<String> featureNames = Arrays.asList("foo", "bar"); List<String> featureNames = Arrays.asList("foo", "bar");
Tree treeObject = builder.setFeatureNames(featureNames).build(); Tree treeObject = builder.setFeatureNames(featureNames).build();
TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
xContentRegistry(), xContentRegistry(),
TreeInferenceModel::fromXContent); TreeInferenceModel::fromXContent);
// This feature vector should hit the right child of the root node // This feature vector should hit the right child of the root node
@ -126,7 +134,7 @@ public class TreeInferenceModelTests extends ESTestCase {
List<String> featureNames = Arrays.asList("foo", "bar"); List<String> featureNames = Arrays.asList("foo", "bar");
Tree treeObject = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build(); Tree treeObject = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build();
TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
xContentRegistry(), xContentRegistry(),
TreeInferenceModel::fromXContent); TreeInferenceModel::fromXContent);
double eps = 0.000001; double eps = 0.000001;
@ -200,7 +208,7 @@ public class TreeInferenceModelTests extends ESTestCase {
TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L), TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build(); TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();
TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject, TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
xContentRegistry(), xContentRegistry(),
TreeInferenceModel::fromXContent); TreeInferenceModel::fromXContent);

View File

@ -219,6 +219,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimatio
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.JobManagerHolder;
@ -1003,6 +1004,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
return namedXContent; return namedXContent;
} }

View File

@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierD
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
@ -31,6 +32,7 @@ public class AnalyticsResult implements ToXContentObject {
private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress"); private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress");
private static final ParseField INFERENCE_MODEL = new ParseField("inference_model"); private static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
private static final ParseField MODEL_SIZE_INFO = new ParseField("model_size_info");
private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage"); private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage");
private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats"); private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats");
private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats"); private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats");
@ -44,7 +46,8 @@ public class AnalyticsResult implements ToXContentObject {
(MemoryUsage) a[3], (MemoryUsage) a[3],
(OutlierDetectionStats) a[4], (OutlierDetectionStats) a[4],
(ClassificationStats) a[5], (ClassificationStats) a[5],
(RegressionStats) a[6] (RegressionStats) a[6],
(ModelSizeInfo) a[7]
)); ));
static { static {
@ -56,6 +59,7 @@ public class AnalyticsResult implements ToXContentObject {
PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS); PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS);
PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS); PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS);
PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS); PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS);
PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO);
} }
private final RowResults rowResults; private final RowResults rowResults;
@ -66,6 +70,7 @@ public class AnalyticsResult implements ToXContentObject {
private final OutlierDetectionStats outlierDetectionStats; private final OutlierDetectionStats outlierDetectionStats;
private final ClassificationStats classificationStats; private final ClassificationStats classificationStats;
private final RegressionStats regressionStats; private final RegressionStats regressionStats;
private final ModelSizeInfo modelSizeInfo;
public AnalyticsResult(@Nullable RowResults rowResults, public AnalyticsResult(@Nullable RowResults rowResults,
@Nullable PhaseProgress phaseProgress, @Nullable PhaseProgress phaseProgress,
@ -73,7 +78,8 @@ public class AnalyticsResult implements ToXContentObject {
@Nullable MemoryUsage memoryUsage, @Nullable MemoryUsage memoryUsage,
@Nullable OutlierDetectionStats outlierDetectionStats, @Nullable OutlierDetectionStats outlierDetectionStats,
@Nullable ClassificationStats classificationStats, @Nullable ClassificationStats classificationStats,
@Nullable RegressionStats regressionStats) { @Nullable RegressionStats regressionStats,
@Nullable ModelSizeInfo modelSizeInfo) {
this.rowResults = rowResults; this.rowResults = rowResults;
this.phaseProgress = phaseProgress; this.phaseProgress = phaseProgress;
this.inferenceModelBuilder = inferenceModelBuilder; this.inferenceModelBuilder = inferenceModelBuilder;
@ -82,6 +88,7 @@ public class AnalyticsResult implements ToXContentObject {
this.outlierDetectionStats = outlierDetectionStats; this.outlierDetectionStats = outlierDetectionStats;
this.classificationStats = classificationStats; this.classificationStats = classificationStats;
this.regressionStats = regressionStats; this.regressionStats = regressionStats;
this.modelSizeInfo = modelSizeInfo;
} }
public RowResults getRowResults() { public RowResults getRowResults() {
@ -112,6 +119,10 @@ public class AnalyticsResult implements ToXContentObject {
return regressionStats; return regressionStats;
} }
public ModelSizeInfo getModelSizeInfo() {
return modelSizeInfo;
}
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
@ -138,6 +149,9 @@ public class AnalyticsResult implements ToXContentObject {
if (regressionStats != null) { if (regressionStats != null) {
builder.field(REGRESSION_STATS.getPreferredName(), regressionStats, params); builder.field(REGRESSION_STATS.getPreferredName(), regressionStats, params);
} }
if (modelSizeInfo != null) {
builder.field(MODEL_SIZE_INFO.getPreferredName(), modelSizeInfo);
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -158,6 +172,7 @@ public class AnalyticsResult implements ToXContentObject {
&& Objects.equals(memoryUsage, that.memoryUsage) && Objects.equals(memoryUsage, that.memoryUsage)
&& Objects.equals(outlierDetectionStats, that.outlierDetectionStats) && Objects.equals(outlierDetectionStats, that.outlierDetectionStats)
&& Objects.equals(classificationStats, that.classificationStats) && Objects.equals(classificationStats, that.classificationStats)
&& Objects.equals(modelSizeInfo, that.modelSizeInfo)
&& Objects.equals(regressionStats, that.regressionStats); && Objects.equals(regressionStats, that.regressionStats);
} }

View File

@ -0,0 +1,136 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModel;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfDoubleArray;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfStringCollection;
public class EnsembleSizeInfo implements TrainedModelSizeInfo {
public static final ParseField NAME = new ParseField("ensemble_model_size");
private static final ParseField TREE_SIZES = new ParseField("tree_sizes");
private static final ParseField INPUT_FIELD_NAME_LENGHTS = new ParseField("input_field_name_lengths");
private static final ParseField NUM_OUTPUT_PROCESSOR_WEIGHTS = new ParseField("num_output_processor_weights");
private static final ParseField NUM_CLASSIFICATION_WEIGHTS = new ParseField("num_classification_weights");
private static final ParseField NUM_OPERATIONS = new ParseField("num_operations");
private static final ParseField NUM_CLASSES = new ParseField("num_classes");
@SuppressWarnings("unchecked")
static ConstructingObjectParser<EnsembleSizeInfo, Void> PARSER = new ConstructingObjectParser<>(
"ensemble_size",
false,
a -> new EnsembleSizeInfo((List<TreeSizeInfo>)a[0],
(Integer)a[1],
(List<Integer>)a[2],
a[3] == null ? 0 : (Integer)a[3],
a[4] == null ? 0 : (Integer)a[4],
a[5] == null ? 0 : (Integer)a[5])
);
static {
PARSER.declareObjectArray(constructorArg(), TreeSizeInfo.PARSER::apply, TREE_SIZES);
PARSER.declareInt(constructorArg(), NUM_OPERATIONS);
PARSER.declareIntArray(constructorArg(), INPUT_FIELD_NAME_LENGHTS);
PARSER.declareInt(optionalConstructorArg(), NUM_OUTPUT_PROCESSOR_WEIGHTS);
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSIFICATION_WEIGHTS);
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES);
}
public static EnsembleSizeInfo fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final List<TreeSizeInfo> treeSizeInfos;
private final int numOperations;
private final int[] inputFieldNameLengths;
private final int numOutputProcessorWeights;
private final int numClassificationWeights;
private final int numClasses;
public EnsembleSizeInfo(List<TreeSizeInfo> treeSizeInfos,
int numOperations,
List<Integer> inputFieldNameLengths,
int numOutputProcessorWeights,
int numClassificationWeights,
int numClasses) {
this.treeSizeInfos = treeSizeInfos;
this.numOperations = numOperations;
this.inputFieldNameLengths = inputFieldNameLengths.stream().mapToInt(Integer::intValue).toArray();
this.numOutputProcessorWeights = numOutputProcessorWeights;
this.numClassificationWeights = numClassificationWeights;
this.numClasses = numClasses;
}
public int getNumOperations() {
return numOperations;
}
@Override
public long ramBytesUsed() {
long size = EnsembleInferenceModel.SHALLOW_SIZE;
treeSizeInfos.forEach(t -> t.setNumClasses(numClasses).ramBytesUsed());
size += sizeOfCollection(treeSizeInfos);
size += sizeOfStringCollection(inputFieldNameLengths);
size += LogisticRegression.SHALLOW_SIZE + sizeOfDoubleArray(numOutputProcessorWeights);
size += sizeOfDoubleArray(numClassificationWeights);
return alignObjectSize(size);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TREE_SIZES.getPreferredName(), treeSizeInfos);
builder.field(NUM_OPERATIONS.getPreferredName(), numOperations);
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
builder.field(INPUT_FIELD_NAME_LENGHTS.getPreferredName(), inputFieldNameLengths);
builder.field(NUM_CLASSIFICATION_WEIGHTS.getPreferredName(), numClassificationWeights);
builder.field(NUM_OUTPUT_PROCESSOR_WEIGHTS.getPreferredName(), numOutputProcessorWeights);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
EnsembleSizeInfo that = (EnsembleSizeInfo) o;
return numOperations == that.numOperations &&
numOutputProcessorWeights == that.numOutputProcessorWeights &&
numClassificationWeights == that.numClassificationWeights &&
numClasses == that.numClasses &&
Objects.equals(treeSizeInfos, that.treeSizeInfos) &&
Arrays.equals(inputFieldNameLengths, that.inputFieldNameLengths);
}
@Override
public int hashCode() {
int result = Objects.hash(treeSizeInfos, numOperations, numOutputProcessorWeights, numClassificationWeights, numClasses);
result = 31 * result + Arrays.hashCode(inputFieldNameLengths);
return result;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
}

View File

@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString;
public class FrequencyEncodingSize implements PreprocessorSize {
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<FrequencyEncodingSize, Void> PARSER = new ConstructingObjectParser<>(
"frequency_encoding_size",
false,
a -> new FrequencyEncodingSize((Integer)a[0], (Integer)a[1], (List<Integer>)a[2])
);
static {
PARSER.declareInt(constructorArg(), FIELD_LENGTH);
PARSER.declareInt(constructorArg(), FEATURE_NAME_LENGTH);
PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS);
}
public static FrequencyEncodingSize fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final int fieldLength;
private final int featureNameLength;
private final int[] fieldValueLengths;
FrequencyEncodingSize(int fieldLength, int featureNameLength, List<Integer> fieldValueLengths) {
this.fieldLength = fieldLength;
this.featureNameLength = featureNameLength;
this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray();
}
@Override
public long ramBytesUsed() {
final long sizeOfDoubleObject = shallowSizeOfInstance(Double.class);
long size = FrequencyEncoding.SHALLOW_SIZE;
size += sizeOfString(fieldLength);
size += sizeOfString(featureNameLength);
size += sizeOfHashMap(
Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()),
Stream.generate(() -> sizeOfDoubleObject).limit(fieldValueLengths.length).collect(Collectors.toList()));
return alignObjectSize(size);
}
@Override
public String getName() {
return FrequencyEncoding.NAME.getPreferredName();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FIELD_LENGTH.getPreferredName(), fieldLength);
builder.field(FEATURE_NAME_LENGTH.getPreferredName(), featureNameLength);
builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FrequencyEncodingSize that = (FrequencyEncodingSize) o;
return fieldLength == that.fieldLength &&
featureNameLength == that.featureNameLength &&
Arrays.equals(fieldValueLengths, that.fieldValueLengths);
}
@Override
public int hashCode() {
int result = Objects.hash(fieldLength, featureNameLength);
result = 31 * result + Arrays.hashCode(fieldValueLengths);
return result;
}
}

View File

@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import java.util.Arrays;
import java.util.List;
public class MlModelSizeNamedXContentProvider implements NamedXContentProvider {
@Override
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
return Arrays.asList(
new NamedXContentRegistry.Entry(PreprocessorSize.class,
FrequencyEncoding.NAME,
FrequencyEncodingSize::fromXContent),
new NamedXContentRegistry.Entry(PreprocessorSize.class,
OneHotEncoding.NAME,
OneHotEncodingSize::fromXContent),
new NamedXContentRegistry.Entry(PreprocessorSize.class,
TargetMeanEncoding.NAME,
TargetMeanEncodingSize::fromXContent),
new NamedXContentRegistry.Entry(TrainedModelSizeInfo.class,
EnsembleSizeInfo.NAME,
EnsembleSizeInfo::fromXContent)
);
}
}

View File

@ -0,0 +1,91 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class ModelSizeInfo implements Accountable, ToXContentObject {
private static final ParseField PREPROCESSORS = new ParseField("preprocessors");
private static final ParseField TRAINED_MODEL_SIZE = new ParseField("trained_model_size");
@SuppressWarnings("unchecked")
public static ConstructingObjectParser<ModelSizeInfo, Void> PARSER = new ConstructingObjectParser<>(
"model_size",
false,
a -> new ModelSizeInfo((EnsembleSizeInfo)a[0], (List<PreprocessorSize>)a[1])
);
static {
PARSER.declareNamedObject(constructorArg(),
(p, c, n) -> p.namedObject(TrainedModelSizeInfo.class, n, null),
TRAINED_MODEL_SIZE);
PARSER.declareNamedObjects(optionalConstructorArg(),
(p, c, n) -> p.namedObject(PreprocessorSize.class, n, null),
(val) -> {},
PREPROCESSORS);
}
private final EnsembleSizeInfo ensembleSizeInfo;
private final List<PreprocessorSize> preprocessorSizes;
public ModelSizeInfo(EnsembleSizeInfo ensembleSizeInfo, List<PreprocessorSize> preprocessorSizes) {
this.ensembleSizeInfo = ensembleSizeInfo;
this.preprocessorSizes = preprocessorSizes == null ? Collections.emptyList() : preprocessorSizes;
}
public int numOperations() {
return this.preprocessorSizes.size() + this.ensembleSizeInfo.getNumOperations();
}
@Override
public long ramBytesUsed() {
long size = InferenceDefinition.SHALLOW_SIZE;
size += ensembleSizeInfo.ramBytesUsed();
size += preprocessorSizes.stream().mapToLong(PreprocessorSize::ramBytesUsed).sum();
return alignObjectSize(size);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
NamedXContentObjectHelper.writeNamedObject(builder, params, TRAINED_MODEL_SIZE.getPreferredName(), ensembleSizeInfo);
if (preprocessorSizes.size() > 0) {
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, PREPROCESSORS.getPreferredName(), preprocessorSizes);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ModelSizeInfo modelSizeInfo = (ModelSizeInfo) o;
return Objects.equals(ensembleSizeInfo, modelSizeInfo.ensembleSizeInfo) &&
Objects.equals(preprocessorSizes, modelSizeInfo.preprocessorSizes);
}
@Override
public int hashCode() {
return Objects.hash(ensembleSizeInfo, preprocessorSizes);
}
}

View File

@ -0,0 +1,100 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString;
public class OneHotEncodingSize implements PreprocessorSize {
private static final ParseField FEATURE_NAME_LENGTHS = new ParseField("feature_name_lengths");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<OneHotEncodingSize, Void> PARSER = new ConstructingObjectParser<>(
"one_hot_encoding_size",
false,
a -> new OneHotEncodingSize((Integer)a[0], (List<Integer>)a[1], (List<Integer>)a[2])
);
static {
PARSER.declareInt(constructorArg(), FIELD_LENGTH);
PARSER.declareIntArray(constructorArg(), FEATURE_NAME_LENGTHS);
PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS);
}
public static OneHotEncodingSize fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final int fieldLength;
private final int[] featureNameLengths;
private final int[] fieldValueLengths;
OneHotEncodingSize(int fieldLength, List<Integer> featureNameLengths, List<Integer> fieldValueLengths) {
assert featureNameLengths.size() == fieldValueLengths.size();
this.fieldLength = fieldLength;
this.featureNameLengths = featureNameLengths.stream().mapToInt(Integer::intValue).toArray();
this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray();
}
@Override
public long ramBytesUsed() {
long size = OneHotEncoding.SHALLOW_SIZE;
size += sizeOfString(fieldLength);
size += sizeOfHashMap(
Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()),
Arrays.stream(featureNameLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList())
);
return alignObjectSize(size);
}
@Override
public String getName() {
return OneHotEncoding.NAME.getPreferredName();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FIELD_LENGTH.getPreferredName(), fieldLength);
builder.field(FEATURE_NAME_LENGTHS.getPreferredName(), featureNameLengths);
builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
OneHotEncodingSize that = (OneHotEncodingSize) o;
return fieldLength == that.fieldLength &&
Arrays.equals(featureNameLengths, that.featureNameLengths) &&
Arrays.equals(fieldValueLengths, that.fieldValueLengths);
}
@Override
public int hashCode() {
int result = Objects.hash(fieldLength);
result = 31 * result + Arrays.hashCode(featureNameLengths);
result = 31 * result + Arrays.hashCode(fieldValueLengths);
return result;
}
}

View File

@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
public interface PreprocessorSize extends Accountable, NamedXContentObject {
ParseField FIELD_LENGTH = new ParseField("field_length");
ParseField FEATURE_NAME_LENGTH = new ParseField("feature_name_length");
ParseField FIELD_VALUE_LENGTHS = new ParseField("field_value_lengths");
String getName();
}

View File

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Stream;
import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF;
import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
final class SizeEstimatorHelper {
private SizeEstimatorHelper() {}
private static final int STRING_SIZE = (int) shallowSizeOfInstance(String.class);
static long sizeOfString(int stringLength) {
// Technically, each value counted in a String.length is 2 bytes. But, this is how `RamUsageEstimator` calculates it
return alignObjectSize(STRING_SIZE + (long)NUM_BYTES_ARRAY_HEADER + (long)(Character.BYTES) * stringLength);
}
static long sizeOfStringCollection(int[] stringSizes) {
long shallow = alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) NUM_BYTES_OBJECT_REF * stringSizes.length);
return shallow + Arrays.stream(stringSizes).mapToLong(SizeEstimatorHelper::sizeOfString).sum();
}
static long sizeOfDoubleArray(int arrayLength) {
return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) Double.BYTES * arrayLength);
}
static long sizeOfHashMap(List<Long> sizeOfKeys, List<Long> sizeOfValues) {
assert sizeOfKeys.size() == sizeOfValues.size();
long mapsize = shallowSizeOfInstance(HashMap.class);
final long mapEntrySize = shallowSizeOfInstance(HashMap.Entry.class);
mapsize += Stream.concat(sizeOfKeys.stream(), sizeOfValues.stream()).mapToLong(Long::longValue).sum();
mapsize += mapEntrySize * sizeOfKeys.size();
mapsize += mapEntrySize * sizeOfValues.size();
return mapsize;
}
}

View File

@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString;
public class TargetMeanEncodingSize implements PreprocessorSize {
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<TargetMeanEncodingSize, Void> PARSER = new ConstructingObjectParser<>(
"target_mean_encoding_size",
false,
a -> new TargetMeanEncodingSize((Integer)a[0], (Integer)a[1], (List<Integer>)a[2])
);
static {
PARSER.declareInt(constructorArg(), FIELD_LENGTH);
PARSER.declareInt(constructorArg(), FEATURE_NAME_LENGTH);
PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS);
}
public static TargetMeanEncodingSize fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final int fieldLength;
private final int featureNameLength;
private final int[] fieldValueLengths;
TargetMeanEncodingSize(int fieldLength, int featureNameLength, List<Integer> fieldValueLengths) {
this.fieldLength = fieldLength;
this.featureNameLength = featureNameLength;
this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray();
}
@Override
public long ramBytesUsed() {
final long sizeOfDoubleObject = shallowSizeOfInstance(Double.class);
long size = TargetMeanEncoding.SHALLOW_SIZE;
size += sizeOfString(fieldLength);
size += sizeOfString(featureNameLength);
size += sizeOfHashMap(
Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()),
Stream.generate(() -> sizeOfDoubleObject).limit(fieldValueLengths.length).collect(Collectors.toList())
);
return alignObjectSize(size);
}
@Override
public String getName() {
return TargetMeanEncoding.NAME.getPreferredName();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FIELD_LENGTH.getPreferredName(), fieldLength);
builder.field(FEATURE_NAME_LENGTH.getPreferredName(), featureNameLength);
builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TargetMeanEncodingSize that = (TargetMeanEncodingSize) o;
return fieldLength == that.fieldLength &&
featureNameLength == that.featureNameLength &&
Arrays.equals(fieldValueLengths, that.fieldValueLengths);
}
@Override
public int hashCode() {
int result = Objects.hash(fieldLength, featureNameLength);
result = 31 * result + Arrays.hashCode(fieldValueLengths);
return result;
}
}

View File

@ -0,0 +1,13 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
interface TrainedModelSizeInfo extends Accountable, NamedXContentObject {
}

View File

@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;
import java.io.IOException;
import java.util.Objects;
import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF;
import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfDoubleArray;
public class TreeSizeInfo implements Accountable, ToXContentObject {
private static final ParseField NUM_NODES = new ParseField("num_nodes");
private static final ParseField NUM_LEAVES = new ParseField("num_leaves");
private static final ParseField NUM_CLASSES = new ParseField("num_classes");
static ConstructingObjectParser<TreeSizeInfo, Void> PARSER = new ConstructingObjectParser<>(
"tree_size",
false,
a -> new TreeSizeInfo((Integer)a[0], a[1] == null ? 0 : (Integer)a[1], a[2] == null ? 0 : (Integer)a[2])
);
static {
PARSER.declareInt(constructorArg(), NUM_LEAVES);
PARSER.declareInt(optionalConstructorArg(), NUM_NODES);
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES);
}
public static TreeSizeInfo fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final int numNodes;
private final int numLeaves;
private int numClasses;
TreeSizeInfo(int numLeaves, int numNodes, int numClasses) {
this.numLeaves = numLeaves;
this.numNodes = numNodes;
this.numClasses = numClasses;
}
public TreeSizeInfo setNumClasses(int numClasses) {
this.numClasses = numClasses;
return this;
}
@Override
public long ramBytesUsed() {
long size = TreeInferenceModel.SHALLOW_SIZE;
// Node shallow sizes, covers most information as elements are primitive
size += NUM_BYTES_ARRAY_HEADER + ((numLeaves + numNodes) * NUM_BYTES_OBJECT_REF);
size += numLeaves * TreeInferenceModel.LeafNode.SHALLOW_SIZE;
size += numNodes * TreeInferenceModel.InnerNode.SHALLOW_SIZE;
// This handles the values within the leaf value array
int numLeafVals = numClasses <= 2 ? 1 : numClasses;
size += sizeOfDoubleArray(numLeafVals) * numLeaves;
return alignObjectSize(size);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(NUM_LEAVES.getPreferredName(), numLeaves);
builder.field(NUM_NODES.getPreferredName(), numNodes);
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TreeSizeInfo treeSizeInfo = (TreeSizeInfo) o;
return numNodes == treeSizeInfo.numNodes &&
numLeaves == treeSizeInfo.numLeaves &&
numClasses == treeSizeInfo.numClasses;
}
@Override
public int hashCode() {
return Objects.hash(numNodes, numLeaves, numClasses);
}
}

View File

@ -59,7 +59,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
private static final String CONFIG_ID = "config-id"; private static final String CONFIG_ID = "config-id";
private static final int NUM_ROWS = 100; private static final int NUM_ROWS = 100;
private static final int NUM_COLS = 4; private static final int NUM_COLS = 4;
private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null); private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null, null);
private Client client; private Client client;
private DataFrameAnalyticsAuditor auditor; private DataFrameAnalyticsAuditor auditor;

View File

@ -106,8 +106,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
public void testProcess_GivenEmptyResults() { public void testProcess_GivenEmptyResults() {
givenDataFrameRows(2); givenDataFrameRows(2);
givenProcessResults(Arrays.asList( givenProcessResults(Arrays.asList(
new AnalyticsResult(null, null,null, null, null, null, null), new AnalyticsResult(null, null, null,null, null, null, null, null),
new AnalyticsResult(null, null, null, null, null, null, null))); new AnalyticsResult(null, null, null, null, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(); AnalyticsResultProcessor resultProcessor = createResultProcessor();
resultProcessor.process(process); resultProcessor.process(process);
@ -122,8 +122,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
givenDataFrameRows(2); givenDataFrameRows(2);
RowResults rowResults1 = mock(RowResults.class); RowResults rowResults1 = mock(RowResults.class);
RowResults rowResults2 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class);
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null,null, null, null, null, null), givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null),
new AnalyticsResult(rowResults2, null, null, null, null, null, null))); new AnalyticsResult(rowResults2, null, null, null, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(); AnalyticsResultProcessor resultProcessor = createResultProcessor();
resultProcessor.process(process); resultProcessor.process(process);
@ -140,8 +140,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
givenDataFrameRows(2); givenDataFrameRows(2);
RowResults rowResults1 = mock(RowResults.class); RowResults rowResults1 = mock(RowResults.class);
RowResults rowResults2 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class);
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null,null, null, null, null, null), givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null),
new AnalyticsResult(rowResults2, null, null, null, null, null, null))); new AnalyticsResult(rowResults2, null, null, null, null, null, null, null)));
doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class)); doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class));
@ -175,7 +175,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
extractedFieldList.add(new DocValueField("baz", Collections.emptySet())); extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);
resultProcessor.process(process); resultProcessor.process(process);
@ -239,7 +239,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(); AnalyticsResultProcessor resultProcessor = createResultProcessor();
resultProcessor.process(process); resultProcessor.process(process);

View File

@ -24,6 +24,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
@ -36,10 +39,10 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>(); List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent); return new NamedXContentRegistry(namedXContent);
} }
@Override
protected AnalyticsResult createTestInstance() { protected AnalyticsResult createTestInstance() {
RowResults rowResults = null; RowResults rowResults = null;
PhaseProgress phaseProgress = null; PhaseProgress phaseProgress = null;
@ -48,6 +51,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
OutlierDetectionStats outlierDetectionStats = null; OutlierDetectionStats outlierDetectionStats = null;
ClassificationStats classificationStats = null; ClassificationStats classificationStats = null;
RegressionStats regressionStats = null; RegressionStats regressionStats = null;
ModelSizeInfo modelSizeInfo = null;
if (randomBoolean()) { if (randomBoolean()) {
rowResults = RowResultsTests.createRandom(); rowResults = RowResultsTests.createRandom();
} }
@ -69,8 +73,11 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
if (randomBoolean()) { if (randomBoolean()) {
regressionStats = RegressionStatsTests.createRandom(); regressionStats = RegressionStatsTests.createRandom();
} }
if (randomBoolean()) {
modelSizeInfo = ModelSizeInfoTests.createRandom();
}
return new AnalyticsResult(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats, return new AnalyticsResult(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats,
classificationStats, regressionStats); classificationStats, regressionStats, modelSizeInfo);
} }
@Override @Override

View File

@ -0,0 +1,88 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModelTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class EnsembleSizeInfoTests extends SizeEstimatorTestCase<EnsembleSizeInfo, EnsembleInferenceModel> {
static EnsembleSizeInfo createRandom() {
return new EnsembleSizeInfo(
Stream.generate(TreeSizeInfoTests::createRandom).limit(randomIntBetween(1, 100)).collect(Collectors.toList()),
randomIntBetween(1, 10000),
Stream.generate(() -> randomIntBetween(1, 10)).limit(randomIntBetween(1, 10)).collect(Collectors.toList()),
randomIntBetween(0, 10),
randomIntBetween(0, 10),
randomIntBetween(0, 10)
);
}
static EnsembleSizeInfo translateToEstimate(EnsembleInferenceModel ensemble) {
TreeInferenceModel tree = (TreeInferenceModel)ensemble.getModels().get(0);
int numClasses = Arrays.stream(tree.getNodes())
.filter(TreeInferenceModel.Node::isLeaf)
.map(n -> (TreeInferenceModel.LeafNode)n)
.findFirst()
.get()
.getLeafValue()
.length;
return new EnsembleSizeInfo(
ensemble.getModels()
.stream()
.map(m -> TreeSizeInfoTests.translateToEstimate((TreeInferenceModel)m))
.collect(Collectors.toList()),
randomIntBetween(0, 10),
Arrays.stream(ensemble.getFeatureNames()).map(String::length).collect(Collectors.toList()),
ensemble.getOutputAggregator().expectedValueSize() == null ? 0 : ensemble.getOutputAggregator().expectedValueSize(),
ensemble.getClassificationWeights() == null ? 0 : ensemble.getClassificationWeights().length,
numClasses);
}
@Override
protected EnsembleSizeInfo createTestInstance() {
return createRandom();
}
@Override
protected EnsembleSizeInfo doParseInstance(XContentParser parser) {
return EnsembleSizeInfo.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return false;
}
@Override
EnsembleInferenceModel generateTrueObject() {
try {
Ensemble model = EnsembleTests.createRandom();
EnsembleInferenceModel inferenceModel = EnsembleInferenceModelTests.serializeFromTrainedModel(model);
inferenceModel.rewriteFeatureIndices(Collections.emptyMap());
return inferenceModel;
} catch (IOException ex) {
throw new ElasticsearchException(ex);
}
}
@Override
EnsembleSizeInfo translateObject(EnsembleInferenceModel originalObject) {
return translateToEstimate(originalObject);
}
}

View File

@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class FrequencyEncodingSizeTests extends SizeEstimatorTestCase<FrequencyEncodingSize, FrequencyEncoding> {
static FrequencyEncodingSize createRandom() {
return new FrequencyEncodingSize(randomInt(100),
randomInt(100),
Stream.generate(() -> randomIntBetween(5, 10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()));
}
static FrequencyEncodingSize translateToEstimate(FrequencyEncoding encoding) {
return new FrequencyEncodingSize(encoding.getField().length(),
encoding.getFeatureName().length(),
encoding.getFrequencyMap().keySet().stream().map(String::length).collect(Collectors.toList()));
}
@Override
protected FrequencyEncodingSize createTestInstance() {
return createRandom();
}
@Override
protected FrequencyEncodingSize doParseInstance(XContentParser parser) {
return FrequencyEncodingSize.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return false;
}
@Override
FrequencyEncoding generateTrueObject() {
return FrequencyEncodingTests.createRandom();
}
@Override
FrequencyEncodingSize translateObject(FrequencyEncoding originalObject) {
return translateToEstimate(originalObject);
}
}

View File

@ -0,0 +1,124 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class ModelSizeInfoTests extends AbstractXContentTestCase<ModelSizeInfo> {
public static ModelSizeInfo createRandom() {
return new ModelSizeInfo(EnsembleSizeInfoTests.createRandom(),
randomBoolean() ?
null :
Stream.generate(() -> randomFrom(
FrequencyEncodingSizeTests.createRandom(),
OneHotEncodingSizeTests.createRandom(),
TargetMeanEncodingSizeTests.createRandom()))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()));
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
@Override
protected ModelSizeInfo createTestInstance() {
return createRandom();
}
@Override
protected ModelSizeInfo doParseInstance(XContentParser parser) {
return ModelSizeInfo.PARSER.apply(parser, null);
}
@Override
protected boolean supportsUnknownFields() {
return false;
}
public void testParseDescribedFormat() throws IOException {
XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(FORMAT),
XContentType.JSON);
// Shouldn't throw
doParseInstance(parser);
}
private static final String FORMAT = "" +
"{\n" +
" \"trained_model_size\": {\n" +
" \"ensemble_model_size\": {\n" +
" \"tree_sizes\": [\n" +
" {\"num_nodes\": 7, \"num_leaves\": 8},\n" +
" {\"num_nodes\": 3, \"num_leaves\": 4},\n" +
" {\"num_leaves\": 1}\n" +
" ],\n" +
" \"input_field_name_lengths\": [\n" +
" 14,\n" +
" 10,\n" +
" 11\n" +
" ],\n" +
" \"num_output_processor_weights\": 3,\n" +
" \"num_classification_weights\": 0,\n" +
" \"num_classes\": 0,\n" +
" \"num_operations\": 3\n" +
" }\n" +
" },\n" +
" \"preprocessors\": [\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field_length\": 10,\n" +
" \"field_value_lengths\": [\n" +
" 10,\n" +
" 20\n" +
" ],\n" +
" \"feature_name_lengths\": [\n" +
" 15,\n" +
" 25\n" +
" ]\n" +
" }\n" +
" },\n" +
" {\n" +
" \"frequency_encoding\": {\n" +
" \"field_length\": 10,\n" +
" \"feature_name_length\": 5,\n" +
" \"field_value_lengths\": [\n" +
" 10,\n" +
" 20\n" +
" ]\n" +
" }\n" +
" },\n" +
" {\n" +
" \"target_mean_encoding\": {\n" +
" \"field_length\": 6,\n" +
" \"feature_name_length\": 15,\n" +
" \"field_value_lengths\": [\n" +
" 10,\n" +
" 20\n" +
" ]\n" +
" }\n" +
" }\n" +
" ]\n" +
"} ";
}

View File

@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class OneHotEncodingSizeTests extends SizeEstimatorTestCase<OneHotEncodingSize, OneHotEncoding> {
static OneHotEncodingSize createRandom() {
int numFieldEntries = randomIntBetween(1, 10);
return new OneHotEncodingSize(
randomInt(100),
Stream.generate(() -> randomIntBetween(5, 10))
.limit(numFieldEntries)
.collect(Collectors.toList()),
Stream.generate(() -> randomIntBetween(5, 10))
.limit(numFieldEntries)
.collect(Collectors.toList()));
}
static OneHotEncodingSize translateToEstimate(OneHotEncoding encoding) {
return new OneHotEncodingSize(encoding.getField().length(),
encoding.getHotMap().values().stream().map(String::length).collect(Collectors.toList()),
encoding.getHotMap().keySet().stream().map(String::length).collect(Collectors.toList()));
}
@Override
protected OneHotEncodingSize createTestInstance() {
return createRandom();
}
@Override
protected OneHotEncodingSize doParseInstance(XContentParser parser) {
return OneHotEncodingSize.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return false;
}
@Override
OneHotEncoding generateTrueObject() {
return OneHotEncodingTests.createRandom();
}
@Override
OneHotEncodingSize translateObject(OneHotEncoding originalObject) {
return translateToEstimate(originalObject);
}
}

View File

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.test.AbstractXContentTestCase;
import static org.hamcrest.Matchers.is;
public abstract class SizeEstimatorTestCase<T extends ToXContentObject & Accountable, U extends Accountable>
extends AbstractXContentTestCase<T> {
abstract U generateTrueObject();
abstract T translateObject(U originalObject);
public void testRamUsageEstimationAccuracy() {
final long bytesEps = new ByteSizeValue(2, ByteSizeUnit.KB).getBytes();
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
U obj = generateTrueObject();
T estimateObj = translateObject(obj);
long originalBytesUsed = obj.ramBytesUsed();
long estimateBytesUsed = estimateObj.ramBytesUsed();
// If we are over by 2kb that is small enough to not be a concern
boolean condition = (Math.abs(obj.ramBytesUsed() - estimateObj.ramBytesUsed()) < bytesEps) ||
// If the difference is greater than 2kb, it is better to have overestimated.
originalBytesUsed < estimateBytesUsed;
assertThat("estimation difference greater than 2048 and the estimation is too small. Object ["
+ obj.toString()
+ "] estimated ["
+ originalBytesUsed
+ "] translated object ["
+ estimateObj
+ "] estimated ["
+ estimateBytesUsed
+ "]" ,
condition,
is(true));
}
}
}

View File

@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class TargetMeanEncodingSizeTests extends SizeEstimatorTestCase<TargetMeanEncodingSize, TargetMeanEncoding> {
static TargetMeanEncodingSize createRandom() {
return new TargetMeanEncodingSize(randomInt(100),
randomInt(100),
Stream.generate(() -> randomIntBetween(5, 10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()));
}
static TargetMeanEncodingSize translateToEstimate(TargetMeanEncoding encoding) {
return new TargetMeanEncodingSize(encoding.getField().length(),
encoding.getFeatureName().length(),
encoding.getMeanMap().keySet().stream().map(String::length).collect(Collectors.toList()));
}
@Override
protected TargetMeanEncodingSize createTestInstance() {
return createRandom();
}
@Override
protected TargetMeanEncodingSize doParseInstance(XContentParser parser) {
return TargetMeanEncodingSize.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return false;
}
@Override
TargetMeanEncoding generateTrueObject() {
return TargetMeanEncodingTests.createRandom();
}
@Override
TargetMeanEncodingSize translateObject(TargetMeanEncoding originalObject) {
return translateToEstimate(originalObject);
}
}

View File

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.modelsize;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModelTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import java.io.IOException;
import java.util.Arrays;
public class TreeSizeInfoTests extends SizeEstimatorTestCase<TreeSizeInfo, TreeInferenceModel> {
static TreeSizeInfo createRandom() {
return new TreeSizeInfo(randomIntBetween(1, 100), randomIntBetween(0, 100), randomIntBetween(0, 10));
}
static TreeSizeInfo translateToEstimate(TreeInferenceModel tree) {
int numClasses = Arrays.stream(tree.getNodes())
.filter(TreeInferenceModel.Node::isLeaf)
.map(n -> (TreeInferenceModel.LeafNode)n)
.findFirst()
.get()
.getLeafValue()
.length;
return new TreeSizeInfo((int)Arrays.stream(tree.getNodes()).filter(TreeInferenceModel.Node::isLeaf).count(),
(int)Arrays.stream(tree.getNodes()).filter(t -> t.isLeaf() == false).count(),
numClasses);
}
@Override
protected TreeSizeInfo createTestInstance() {
return createRandom();
}
@Override
protected TreeSizeInfo doParseInstance(XContentParser parser) {
return TreeSizeInfo.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return false;
}
@Override
TreeInferenceModel generateTrueObject() {
try {
return TreeInferenceModelTests.serializeFromTrainedModel(
TreeTests.buildRandomTree(Arrays.asList(randomAlphaOfLength(10), randomAlphaOfLength(10)), 6)
);
} catch (IOException ex) {
throw new ElasticsearchException(ex);
}
}
@Override
TreeSizeInfo translateObject(TreeInferenceModel originalObject) {
return translateToEstimate(originalObject);
}
}