Adds support for reading in `model_size_info` objects. These objects contain numeric values describing the size and complexity of a model definition. They are not stored or serialized to any other node; they exist only for calculating and storing model metadata. Because they are much smaller on heap than the true model definition, they should help prevent the analytics process from using too much memory.

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
parent ffc3c77f75
commit 2881995a45
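For illustration, a minimal sketch of how one of these `model_size_info` objects can be parsed and queried. This is not code from the commit: the JSON values are invented, the wrapper class and method names are hypothetical, and only ModelSizeInfo, MlModelSizeNamedXContentProvider, and the standard XContent plumbing are real.

import java.io.IOException;

import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;

public final class ModelSizeInfoExample {

    // Hypothetical payload: a two-tree ensemble plus one frequency encoding.
    // Field names match the parsers added in this commit.
    private static final String JSON = "{"
        + "\"trained_model_size\": {\"ensemble_model_size\": {"
        + "  \"tree_sizes\": ["
        + "    {\"num_nodes\": 7, \"num_leaves\": 8},"
        + "    {\"num_nodes\": 3, \"num_leaves\": 4}],"
        + "  \"num_operations\": 2,"
        + "  \"input_field_name_lengths\": [3, 3],"
        + "  \"num_output_processor_weights\": 2,"
        + "  \"num_classification_weights\": 0,"
        + "  \"num_classes\": 0}},"
        + "\"preprocessors\": [{\"frequency_encoding\": {"
        + "  \"field_length\": 3,"
        + "  \"feature_name_length\": 18,"
        + "  \"field_value_lengths\": [4, 7]}}]"
        + "}";

    public static long estimateDefinitionBytes() throws IOException {
        // Register the named object parsers this commit introduces.
        NamedXContentRegistry registry = new NamedXContentRegistry(
            new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
        try (XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                .createParser(registry, LoggingDeprecationHandler.INSTANCE, JSON)) {
            ModelSizeInfo sizeInfo = ModelSizeInfo.PARSER.apply(parser, null);
            // Projected heap cost of the eventual model definition,
            // computed without ever holding the definition itself.
            return sizeInfo.ramBytesUsed();
        }
    }
}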
@@ -27,7 +27,7 @@ import java.util.Objects;
  */
 public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
 
-    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class);
+    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class);
 
     public static final ParseField NAME = new ParseField("frequency_encoding");
     public static final ParseField FIELD = new ParseField("field");

@@ -27,7 +27,7 @@ import java.util.stream.Collectors;
  */
 public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
 
-    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class);
+    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class);
     public static final ParseField NAME = new ParseField("one_hot_encoding");
     public static final ParseField FIELD = new ParseField("field");
     public static final ParseField HOT_MAP = new ParseField("hot_map");

@@ -27,7 +27,7 @@ import java.util.Objects;
  */
 public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
 
-    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class);
+    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class);
     public static final ParseField NAME = new ParseField("target_mean_encoding");
     public static final ParseField FIELD = new ParseField("field");
     public static final ParseField FEATURE_NAME = new ParseField("feature_name");

@@ -25,7 +25,7 @@ import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax
 
 public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
 
-    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class);
+    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class);
     public static final ParseField NAME = new ParseField("logistic_regression");
     public static final ParseField WEIGHTS = new ParseField("weights");
 

@@ -47,7 +47,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.En
 
 public class EnsembleInferenceModel implements InferenceModel {
 
-    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
+    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
 
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
@@ -275,4 +275,21 @@ public class EnsembleInferenceModel implements InferenceModel {
         size += outputAggregator.ramBytesUsed();
         return size;
     }
+
+    public List<InferenceModel> getModels() {
+        return models;
+    }
+
+    public OutputAggregator getOutputAggregator() {
+        return outputAggregator;
+    }
+
+    public TargetType getTargetType() {
+        return targetType;
+    }
+
+    public double[] getClassificationWeights() {
+        return classificationWeights;
+    }
+
 }

@@ -26,7 +26,7 @@ import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition.T
 
 public class InferenceDefinition {
 
-    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InferenceDefinition.class);
+    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InferenceDefinition.class);
 
     public static final String NAME = "inference_model_definition";
     private final InferenceModel trainedModel;

@@ -7,7 +7,6 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
 
 import org.apache.lucene.util.Accountable;
-import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.Numbers;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@@ -27,13 +26,15 @@ import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
 import org.elasticsearch.xpack.core.ml.job.config.Operator;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
+import static org.apache.lucene.util.RamUsageEstimator.sizeOf;
+import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
@@ -53,7 +54,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo
 
 public class TreeInferenceModel implements InferenceModel {
 
-    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TreeInferenceModel.class);
+    public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);
 
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
@@ -75,7 +76,7 @@ public class TreeInferenceModel implements InferenceModel {
     private final Node[] nodes;
     private String[] featureNames;
     private final TargetType targetType;
-    private final List<String> classificationLabels;
+    private List<String> classificationLabels;
     private final double highOrderCategory;
     private final int maxDepth;
     private final int leafSize;
@@ -302,18 +303,16 @@ public class TreeInferenceModel implements InferenceModel {
             treeNode.splitFeature = newSplitFeatureIndex;
         }
         this.featureNames = new String[0];
+        // Since we are not top level, we no longer need local classification labels
+        this.classificationLabels = null;
     }
 
     @Override
     public long ramBytesUsed() {
         long size = SHALLOW_SIZE;
-        size += RamUsageEstimator.sizeOfCollection(classificationLabels);
-        size += RamUsageEstimator.sizeOf(featureNames);
-        size += RamUsageEstimator.shallowSizeOf(nodes);
-        for (Node node : nodes) {
-            size += node.ramBytesUsed();
-        }
-        size += RamUsageEstimator.sizeOfCollection(Arrays.asList(nodes));
+        size += sizeOfCollection(classificationLabels);
+        size += sizeOf(featureNames);
+        size += sizeOf(nodes);
         return size;
     }
 
@@ -335,6 +334,10 @@ public class TreeInferenceModel implements InferenceModel {
         return max;
     }
 
+    public Node[] getNodes() {
+        return nodes;
+    }
+
     private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
         Node node = nodes[nodeIndex];
         if (node instanceof LeafNode) {
@@ -433,21 +436,21 @@ public class TreeInferenceModel implements InferenceModel {
         }
     }
 
-    private abstract static class Node implements Accountable {
+    public abstract static class Node implements Accountable {
         int compare(double[] features) {
             throw new IllegalArgumentException("cannot call compare against a leaf node.");
         }
 
         abstract long getNumberSamples();
 
-        boolean isLeaf() {
+        public boolean isLeaf() {
             return this instanceof LeafNode;
         }
     }
 
-    private static class InnerNode extends Node {
+    public static class InnerNode extends Node {
 
-        private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InnerNode.class);
+        public static final long SHALLOW_SIZE = shallowSizeOfInstance(InnerNode.class);
 
         private final Operator operator;
         private final double threshold;
@@ -498,8 +501,8 @@ public class TreeInferenceModel implements InferenceModel {
         }
     }
 
-    private static class LeafNode extends Node {
-        private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LeafNode.class);
+    public static class LeafNode extends Node {
+        public static final long SHALLOW_SIZE = shallowSizeOfInstance(LeafNode.class);
         private final double[] leafValue;
         private final long numberSamples;
 
@@ -510,12 +513,16 @@ public class TreeInferenceModel implements InferenceModel {
 
         @Override
         public long ramBytesUsed() {
-            return SHALLOW_SIZE;
+            return SHALLOW_SIZE + sizeOf(leafValue);
         }
 
         @Override
         long getNumberSamples() {
             return numberSamples;
         }
+
+        public double[] getLeafValue() {
+            return leafValue;
+        }
     }
 }

@@ -91,7 +91,7 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             .toArray() :
             null;
 
-        return new Ensemble(randomBoolean() ? featureNames : Collections.emptyList(),
+        return new Ensemble(featureNames,
             models,
             outputAggregator,
             targetType,

@@ -28,6 +28,7 @@ import java.util.Map;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 
@@ -35,6 +36,13 @@ public class TreeInferenceModelTests extends ESTestCase {
 
     private final double eps = 1.0E-8;
 
+    public static TreeInferenceModel serializeFromTrainedModel(Tree tree) throws IOException {
+        NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        return deserializeFromTrainedModel(tree,
+            registry,
+            TreeInferenceModel::fromXContent);
+    }
+
     @Override
     protected NamedXContentRegistry xContentRegistry() {
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
@@ -48,7 +56,7 @@ public class TreeInferenceModelTests extends ESTestCase {
         builder.setFeatureNames(Collections.emptyList());
 
         Tree treeObject = builder.build();
-        TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject,
+        TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
             xContentRegistry(),
             TreeInferenceModel::fromXContent);
         List<String> featureNames = Arrays.asList("foo", "bar");
@@ -71,7 +79,7 @@ public class TreeInferenceModelTests extends ESTestCase {
 
         List<String> featureNames = Arrays.asList("foo", "bar");
         Tree treeObject = builder.setFeatureNames(featureNames).build();
-        TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject,
+        TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
             xContentRegistry(),
             TreeInferenceModel::fromXContent);
         // This feature vector should hit the right child of the root node
@@ -126,7 +134,7 @@ public class TreeInferenceModelTests extends ESTestCase {
 
         List<String> featureNames = Arrays.asList("foo", "bar");
         Tree treeObject = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build();
-        TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject,
+        TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
             xContentRegistry(),
             TreeInferenceModel::fromXContent);
         double eps = 0.000001;
@@ -200,7 +208,7 @@ public class TreeInferenceModelTests extends ESTestCase {
             TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
             TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();
 
-        TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject,
+        TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
             xContentRegistry(),
             TreeInferenceModel::fromXContent);
 

@@ -219,6 +219,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimatio
 import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
+import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.job.JobManager;
 import org.elasticsearch.xpack.ml.job.JobManagerHolder;
@@ -1003,6 +1004,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
         return namedXContent;
     }
 

@@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierD
 import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
+import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -31,6 +32,7 @@ public class AnalyticsResult implements ToXContentObject {
 
     private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress");
     private static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
+    private static final ParseField MODEL_SIZE_INFO = new ParseField("model_size_info");
     private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage");
     private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats");
     private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats");
@@ -44,7 +46,8 @@ public class AnalyticsResult implements ToXContentObject {
             (MemoryUsage) a[3],
             (OutlierDetectionStats) a[4],
             (ClassificationStats) a[5],
-            (RegressionStats) a[6]
+            (RegressionStats) a[6],
+            (ModelSizeInfo) a[7]
         ));
 
     static {
@@ -56,6 +59,7 @@ public class AnalyticsResult implements ToXContentObject {
         PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS);
         PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS);
         PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS);
+        PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO);
     }
 
     private final RowResults rowResults;
@@ -66,6 +70,7 @@ public class AnalyticsResult implements ToXContentObject {
     private final OutlierDetectionStats outlierDetectionStats;
     private final ClassificationStats classificationStats;
     private final RegressionStats regressionStats;
+    private final ModelSizeInfo modelSizeInfo;
 
     public AnalyticsResult(@Nullable RowResults rowResults,
                            @Nullable PhaseProgress phaseProgress,
@@ -73,7 +78,8 @@ public class AnalyticsResult implements ToXContentObject {
                            @Nullable MemoryUsage memoryUsage,
                            @Nullable OutlierDetectionStats outlierDetectionStats,
                            @Nullable ClassificationStats classificationStats,
-                           @Nullable RegressionStats regressionStats) {
+                           @Nullable RegressionStats regressionStats,
+                           @Nullable ModelSizeInfo modelSizeInfo) {
         this.rowResults = rowResults;
         this.phaseProgress = phaseProgress;
         this.inferenceModelBuilder = inferenceModelBuilder;
@@ -82,6 +88,7 @@ public class AnalyticsResult implements ToXContentObject {
         this.outlierDetectionStats = outlierDetectionStats;
         this.classificationStats = classificationStats;
         this.regressionStats = regressionStats;
+        this.modelSizeInfo = modelSizeInfo;
     }
 
     public RowResults getRowResults() {
@@ -112,6 +119,10 @@ public class AnalyticsResult implements ToXContentObject {
         return regressionStats;
     }
 
+    public ModelSizeInfo getModelSizeInfo() {
+        return modelSizeInfo;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -138,6 +149,9 @@ public class AnalyticsResult implements ToXContentObject {
         if (regressionStats != null) {
             builder.field(REGRESSION_STATS.getPreferredName(), regressionStats, params);
         }
+        if (modelSizeInfo != null) {
+            builder.field(MODEL_SIZE_INFO.getPreferredName(), modelSizeInfo);
+        }
         builder.endObject();
         return builder;
     }
@@ -158,6 +172,7 @@ public class AnalyticsResult implements ToXContentObject {
             && Objects.equals(memoryUsage, that.memoryUsage)
             && Objects.equals(outlierDetectionStats, that.outlierDetectionStats)
             && Objects.equals(classificationStats, that.classificationStats)
+            && Objects.equals(modelSizeInfo, that.modelSizeInfo)
             && Objects.equals(regressionStats, that.regressionStats);
     }
 

@@ -0,0 +1,136 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModel;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfDoubleArray;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfStringCollection;

public class EnsembleSizeInfo implements TrainedModelSizeInfo {

    public static final ParseField NAME = new ParseField("ensemble_model_size");
    private static final ParseField TREE_SIZES = new ParseField("tree_sizes");
    private static final ParseField INPUT_FIELD_NAME_LENGHTS = new ParseField("input_field_name_lengths");
    private static final ParseField NUM_OUTPUT_PROCESSOR_WEIGHTS = new ParseField("num_output_processor_weights");
    private static final ParseField NUM_CLASSIFICATION_WEIGHTS = new ParseField("num_classification_weights");
    private static final ParseField NUM_OPERATIONS = new ParseField("num_operations");
    private static final ParseField NUM_CLASSES = new ParseField("num_classes");

    @SuppressWarnings("unchecked")
    static ConstructingObjectParser<EnsembleSizeInfo, Void> PARSER = new ConstructingObjectParser<>(
        "ensemble_size",
        false,
        a -> new EnsembleSizeInfo((List<TreeSizeInfo>)a[0],
            (Integer)a[1],
            (List<Integer>)a[2],
            a[3] == null ? 0 : (Integer)a[3],
            a[4] == null ? 0 : (Integer)a[4],
            a[5] == null ? 0 : (Integer)a[5])
    );
    static {
        PARSER.declareObjectArray(constructorArg(), TreeSizeInfo.PARSER::apply, TREE_SIZES);
        PARSER.declareInt(constructorArg(), NUM_OPERATIONS);
        PARSER.declareIntArray(constructorArg(), INPUT_FIELD_NAME_LENGHTS);
        PARSER.declareInt(optionalConstructorArg(), NUM_OUTPUT_PROCESSOR_WEIGHTS);
        PARSER.declareInt(optionalConstructorArg(), NUM_CLASSIFICATION_WEIGHTS);
        PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES);
    }

    public static EnsembleSizeInfo fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final List<TreeSizeInfo> treeSizeInfos;
    private final int numOperations;
    private final int[] inputFieldNameLengths;
    private final int numOutputProcessorWeights;
    private final int numClassificationWeights;
    private final int numClasses;

    public EnsembleSizeInfo(List<TreeSizeInfo> treeSizeInfos,
                            int numOperations,
                            List<Integer> inputFieldNameLengths,
                            int numOutputProcessorWeights,
                            int numClassificationWeights,
                            int numClasses) {
        this.treeSizeInfos = treeSizeInfos;
        this.numOperations = numOperations;
        this.inputFieldNameLengths = inputFieldNameLengths.stream().mapToInt(Integer::intValue).toArray();
        this.numOutputProcessorWeights = numOutputProcessorWeights;
        this.numClassificationWeights = numClassificationWeights;
        this.numClasses = numClasses;
    }

    public int getNumOperations() {
        return numOperations;
    }

    @Override
    public long ramBytesUsed() {
        long size = EnsembleInferenceModel.SHALLOW_SIZE;
        treeSizeInfos.forEach(t -> t.setNumClasses(numClasses).ramBytesUsed());
        size += sizeOfCollection(treeSizeInfos);
        size += sizeOfStringCollection(inputFieldNameLengths);
        size += LogisticRegression.SHALLOW_SIZE + sizeOfDoubleArray(numOutputProcessorWeights);
        size += sizeOfDoubleArray(numClassificationWeights);
        return alignObjectSize(size);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(TREE_SIZES.getPreferredName(), treeSizeInfos);
        builder.field(NUM_OPERATIONS.getPreferredName(), numOperations);
        builder.field(NUM_CLASSES.getPreferredName(), numClasses);
        builder.field(INPUT_FIELD_NAME_LENGHTS.getPreferredName(), inputFieldNameLengths);
        builder.field(NUM_CLASSIFICATION_WEIGHTS.getPreferredName(), numClassificationWeights);
        builder.field(NUM_OUTPUT_PROCESSOR_WEIGHTS.getPreferredName(), numOutputProcessorWeights);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        EnsembleSizeInfo that = (EnsembleSizeInfo) o;
        return numOperations == that.numOperations &&
            numOutputProcessorWeights == that.numOutputProcessorWeights &&
            numClassificationWeights == that.numClassificationWeights &&
            numClasses == that.numClasses &&
            Objects.equals(treeSizeInfos, that.treeSizeInfos) &&
            Arrays.equals(inputFieldNameLengths, that.inputFieldNameLengths);
    }

    @Override
    public int hashCode() {
        int result = Objects.hash(treeSizeInfos, numOperations, numOutputProcessorWeights, numClassificationWeights, numClasses);
        result = 31 * result + Arrays.hashCode(inputFieldNameLengths);
        return result;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }
}

@@ -0,0 +1,98 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString;

public class FrequencyEncodingSize implements PreprocessorSize {

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<FrequencyEncodingSize, Void> PARSER = new ConstructingObjectParser<>(
        "frequency_encoding_size",
        false,
        a -> new FrequencyEncodingSize((Integer)a[0], (Integer)a[1], (List<Integer>)a[2])
    );
    static {
        PARSER.declareInt(constructorArg(), FIELD_LENGTH);
        PARSER.declareInt(constructorArg(), FEATURE_NAME_LENGTH);
        PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS);
    }

    public static FrequencyEncodingSize fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final int fieldLength;
    private final int featureNameLength;
    private final int[] fieldValueLengths;

    FrequencyEncodingSize(int fieldLength, int featureNameLength, List<Integer> fieldValueLengths) {
        this.fieldLength = fieldLength;
        this.featureNameLength = featureNameLength;
        this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray();
    }

    @Override
    public long ramBytesUsed() {
        final long sizeOfDoubleObject = shallowSizeOfInstance(Double.class);
        long size = FrequencyEncoding.SHALLOW_SIZE;
        size += sizeOfString(fieldLength);
        size += sizeOfString(featureNameLength);
        size += sizeOfHashMap(
            Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()),
            Stream.generate(() -> sizeOfDoubleObject).limit(fieldValueLengths.length).collect(Collectors.toList()));
        return alignObjectSize(size);
    }

    @Override
    public String getName() {
        return FrequencyEncoding.NAME.getPreferredName();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD_LENGTH.getPreferredName(), fieldLength);
        builder.field(FEATURE_NAME_LENGTH.getPreferredName(), featureNameLength);
        builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        FrequencyEncodingSize that = (FrequencyEncodingSize) o;
        return fieldLength == that.fieldLength &&
            featureNameLength == that.featureNameLength &&
            Arrays.equals(fieldValueLengths, that.fieldValueLengths);
    }

    @Override
    public int hashCode() {
        int result = Objects.hash(fieldLength, featureNameLength);
        result = 31 * result + Arrays.hashCode(fieldValueLengths);
        return result;
    }
}

@@ -0,0 +1,36 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;

import java.util.Arrays;
import java.util.List;

public class MlModelSizeNamedXContentProvider implements NamedXContentProvider {
    @Override
    public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
        return Arrays.asList(
            new NamedXContentRegistry.Entry(PreprocessorSize.class,
                FrequencyEncoding.NAME,
                FrequencyEncodingSize::fromXContent),
            new NamedXContentRegistry.Entry(PreprocessorSize.class,
                OneHotEncoding.NAME,
                OneHotEncodingSize::fromXContent),
            new NamedXContentRegistry.Entry(PreprocessorSize.class,
                TargetMeanEncoding.NAME,
                TargetMeanEncodingSize::fromXContent),
            new NamedXContentRegistry.Entry(TrainedModelSizeInfo.class,
                EnsembleSizeInfo.NAME,
                EnsembleSizeInfo::fromXContent)
        );
    }
}

@@ -0,0 +1,91 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class ModelSizeInfo implements Accountable, ToXContentObject {

    private static final ParseField PREPROCESSORS = new ParseField("preprocessors");
    private static final ParseField TRAINED_MODEL_SIZE = new ParseField("trained_model_size");

    @SuppressWarnings("unchecked")
    public static ConstructingObjectParser<ModelSizeInfo, Void> PARSER = new ConstructingObjectParser<>(
        "model_size",
        false,
        a -> new ModelSizeInfo((EnsembleSizeInfo)a[0], (List<PreprocessorSize>)a[1])
    );
    static {
        PARSER.declareNamedObject(constructorArg(),
            (p, c, n) -> p.namedObject(TrainedModelSizeInfo.class, n, null),
            TRAINED_MODEL_SIZE);
        PARSER.declareNamedObjects(optionalConstructorArg(),
            (p, c, n) -> p.namedObject(PreprocessorSize.class, n, null),
            (val) -> {},
            PREPROCESSORS);
    }

    private final EnsembleSizeInfo ensembleSizeInfo;
    private final List<PreprocessorSize> preprocessorSizes;

    public ModelSizeInfo(EnsembleSizeInfo ensembleSizeInfo, List<PreprocessorSize> preprocessorSizes) {
        this.ensembleSizeInfo = ensembleSizeInfo;
        this.preprocessorSizes = preprocessorSizes == null ? Collections.emptyList() : preprocessorSizes;
    }

    public int numOperations() {
        return this.preprocessorSizes.size() + this.ensembleSizeInfo.getNumOperations();
    }

    @Override
    public long ramBytesUsed() {
        long size = InferenceDefinition.SHALLOW_SIZE;
        size += ensembleSizeInfo.ramBytesUsed();
        size += preprocessorSizes.stream().mapToLong(PreprocessorSize::ramBytesUsed).sum();
        return alignObjectSize(size);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        NamedXContentObjectHelper.writeNamedObject(builder, params, TRAINED_MODEL_SIZE.getPreferredName(), ensembleSizeInfo);
        if (preprocessorSizes.size() > 0) {
            NamedXContentObjectHelper.writeNamedObjects(builder, params, true, PREPROCESSORS.getPreferredName(), preprocessorSizes);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ModelSizeInfo modelSizeInfo = (ModelSizeInfo) o;
        return Objects.equals(ensembleSizeInfo, modelSizeInfo.ensembleSizeInfo) &&
            Objects.equals(preprocessorSizes, modelSizeInfo.preprocessorSizes);
    }

    @Override
    public int hashCode() {
        return Objects.hash(ensembleSizeInfo, preprocessorSizes);
    }
}

@@ -0,0 +1,100 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString;

public class OneHotEncodingSize implements PreprocessorSize {

    private static final ParseField FEATURE_NAME_LENGTHS = new ParseField("feature_name_lengths");

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<OneHotEncodingSize, Void> PARSER = new ConstructingObjectParser<>(
        "one_hot_encoding_size",
        false,
        a -> new OneHotEncodingSize((Integer)a[0], (List<Integer>)a[1], (List<Integer>)a[2])
    );
    static {
        PARSER.declareInt(constructorArg(), FIELD_LENGTH);
        PARSER.declareIntArray(constructorArg(), FEATURE_NAME_LENGTHS);
        PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS);
    }

    public static OneHotEncodingSize fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final int fieldLength;
    private final int[] featureNameLengths;
    private final int[] fieldValueLengths;

    OneHotEncodingSize(int fieldLength, List<Integer> featureNameLengths, List<Integer> fieldValueLengths) {
        assert featureNameLengths.size() == fieldValueLengths.size();
        this.fieldLength = fieldLength;
        this.featureNameLengths = featureNameLengths.stream().mapToInt(Integer::intValue).toArray();
        this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray();
    }

    @Override
    public long ramBytesUsed() {
        long size = OneHotEncoding.SHALLOW_SIZE;
        size += sizeOfString(fieldLength);
        size += sizeOfHashMap(
            Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()),
            Arrays.stream(featureNameLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList())
        );
        return alignObjectSize(size);
    }

    @Override
    public String getName() {
        return OneHotEncoding.NAME.getPreferredName();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD_LENGTH.getPreferredName(), fieldLength);
        builder.field(FEATURE_NAME_LENGTHS.getPreferredName(), featureNameLengths);
        builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        OneHotEncodingSize that = (OneHotEncodingSize) o;
        return fieldLength == that.fieldLength &&
            Arrays.equals(featureNameLengths, that.featureNameLengths) &&
            Arrays.equals(fieldValueLengths, that.fieldValueLengths);
    }

    @Override
    public int hashCode() {
        int result = Objects.hash(fieldLength);
        result = 31 * result + Arrays.hashCode(featureNameLengths);
        result = 31 * result + Arrays.hashCode(fieldValueLengths);
        return result;
    }
}

@@ -0,0 +1,20 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

public interface PreprocessorSize extends Accountable, NamedXContentObject {
    ParseField FIELD_LENGTH = new ParseField("field_length");
    ParseField FEATURE_NAME_LENGTH = new ParseField("feature_name_length");
    ParseField FIELD_VALUE_LENGTHS = new ParseField("field_value_lengths");

    String getName();
}

@@ -0,0 +1,48 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Stream;

import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF;
import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;

final class SizeEstimatorHelper {

    private SizeEstimatorHelper() {}

    private static final int STRING_SIZE = (int) shallowSizeOfInstance(String.class);

    static long sizeOfString(int stringLength) {
        // Technically, each value counted in a String.length is 2 bytes. But, this is how `RamUsageEstimator` calculates it
        return alignObjectSize(STRING_SIZE + (long)NUM_BYTES_ARRAY_HEADER + (long)(Character.BYTES) * stringLength);
    }

    static long sizeOfStringCollection(int[] stringSizes) {
        long shallow = alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) NUM_BYTES_OBJECT_REF * stringSizes.length);
        return shallow + Arrays.stream(stringSizes).mapToLong(SizeEstimatorHelper::sizeOfString).sum();
    }

    static long sizeOfDoubleArray(int arrayLength) {
        return alignObjectSize((long) NUM_BYTES_ARRAY_HEADER + (long) Double.BYTES * arrayLength);
    }

    static long sizeOfHashMap(List<Long> sizeOfKeys, List<Long> sizeOfValues) {
        assert sizeOfKeys.size() == sizeOfValues.size();
        long mapsize = shallowSizeOfInstance(HashMap.class);
        final long mapEntrySize = shallowSizeOfInstance(HashMap.Entry.class);
        mapsize += Stream.concat(sizeOfKeys.stream(), sizeOfValues.stream()).mapToLong(Long::longValue).sum();
        mapsize += mapEntrySize * sizeOfKeys.size();
        mapsize += mapEntrySize * sizeOfValues.size();
        return mapsize;
    }
}

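As a worked example of the helper's arithmetic (hypothetical, not part of the commit; a caller would have to sit in the same package because the class is package-private):

// From within org.elasticsearch.xpack.ml.inference.modelsize:
// sizeOfString(15) expands to
//   alignObjectSize(shallowSizeOfInstance(String.class)  // String header
//                   + NUM_BYTES_ARRAY_HEADER             // backing array header
//                   + Character.BYTES * 15)              // two bytes per counted char
long fieldNameBytes = SizeEstimatorHelper.sizeOfString(15);

// sizeOfDoubleArray(3) expands to
//   alignObjectSize(NUM_BYTES_ARRAY_HEADER + Double.BYTES * 3)
long leafValueBytes = SizeEstimatorHelper.sizeOfDoubleArray(3);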
@@ -0,0 +1,98 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfHashMap;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfString;

public class TargetMeanEncodingSize implements PreprocessorSize {
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<TargetMeanEncodingSize, Void> PARSER = new ConstructingObjectParser<>(
        "target_mean_encoding_size",
        false,
        a -> new TargetMeanEncodingSize((Integer)a[0], (Integer)a[1], (List<Integer>)a[2])
    );
    static {
        PARSER.declareInt(constructorArg(), FIELD_LENGTH);
        PARSER.declareInt(constructorArg(), FEATURE_NAME_LENGTH);
        PARSER.declareIntArray(constructorArg(), FIELD_VALUE_LENGTHS);
    }

    public static TargetMeanEncodingSize fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final int fieldLength;
    private final int featureNameLength;
    private final int[] fieldValueLengths;

    TargetMeanEncodingSize(int fieldLength, int featureNameLength, List<Integer> fieldValueLengths) {
        this.fieldLength = fieldLength;
        this.featureNameLength = featureNameLength;
        this.fieldValueLengths = fieldValueLengths.stream().mapToInt(Integer::intValue).toArray();
    }

    @Override
    public long ramBytesUsed() {
        final long sizeOfDoubleObject = shallowSizeOfInstance(Double.class);
        long size = TargetMeanEncoding.SHALLOW_SIZE;
        size += sizeOfString(fieldLength);
        size += sizeOfString(featureNameLength);
        size += sizeOfHashMap(
            Arrays.stream(fieldValueLengths).mapToLong(SizeEstimatorHelper::sizeOfString).boxed().collect(Collectors.toList()),
            Stream.generate(() -> sizeOfDoubleObject).limit(fieldValueLengths.length).collect(Collectors.toList())
        );
        return alignObjectSize(size);
    }

    @Override
    public String getName() {
        return TargetMeanEncoding.NAME.getPreferredName();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD_LENGTH.getPreferredName(), fieldLength);
        builder.field(FEATURE_NAME_LENGTH.getPreferredName(), featureNameLength);
        builder.field(FIELD_VALUE_LENGTHS.getPreferredName(), fieldValueLengths);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TargetMeanEncodingSize that = (TargetMeanEncodingSize) o;
        return fieldLength == that.fieldLength &&
            featureNameLength == that.featureNameLength &&
            Arrays.equals(fieldValueLengths, that.fieldValueLengths);
    }

    @Override
    public int hashCode() {
        int result = Objects.hash(fieldLength, featureNameLength);
        result = 31 * result + Arrays.hashCode(fieldValueLengths);
        return result;
    }
}

@@ -0,0 +1,13 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

interface TrainedModelSizeInfo extends Accountable, NamedXContentObject {
}

@@ -0,0 +1,101 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;

import java.io.IOException;
import java.util.Objects;

import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF;
import static org.apache.lucene.util.RamUsageEstimator.alignObjectSize;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.ml.inference.modelsize.SizeEstimatorHelper.sizeOfDoubleArray;

public class TreeSizeInfo implements Accountable, ToXContentObject {

    private static final ParseField NUM_NODES = new ParseField("num_nodes");
    private static final ParseField NUM_LEAVES = new ParseField("num_leaves");
    private static final ParseField NUM_CLASSES = new ParseField("num_classes");

    static ConstructingObjectParser<TreeSizeInfo, Void> PARSER = new ConstructingObjectParser<>(
        "tree_size",
        false,
        a -> new TreeSizeInfo((Integer)a[0], a[1] == null ? 0 : (Integer)a[1], a[2] == null ? 0 : (Integer)a[2])
    );
    static {
        PARSER.declareInt(constructorArg(), NUM_LEAVES);
        PARSER.declareInt(optionalConstructorArg(), NUM_NODES);
        PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES);
    }

    public static TreeSizeInfo fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final int numNodes;
    private final int numLeaves;
    private int numClasses;

    TreeSizeInfo(int numLeaves, int numNodes, int numClasses) {
        this.numLeaves = numLeaves;
        this.numNodes = numNodes;
        this.numClasses = numClasses;
    }

    public TreeSizeInfo setNumClasses(int numClasses) {
        this.numClasses = numClasses;
        return this;
    }

    @Override
    public long ramBytesUsed() {
        long size = TreeInferenceModel.SHALLOW_SIZE;
        // Node shallow sizes, covers most information as elements are primitive
        size += NUM_BYTES_ARRAY_HEADER + ((numLeaves + numNodes) * NUM_BYTES_OBJECT_REF);
        size += numLeaves * TreeInferenceModel.LeafNode.SHALLOW_SIZE;
        size += numNodes * TreeInferenceModel.InnerNode.SHALLOW_SIZE;
        // This handles the values within the leaf value array
        int numLeafVals = numClasses <= 2 ? 1 : numClasses;
        size += sizeOfDoubleArray(numLeafVals) * numLeaves;
        return alignObjectSize(size);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(NUM_LEAVES.getPreferredName(), numLeaves);
        builder.field(NUM_NODES.getPreferredName(), numNodes);
        builder.field(NUM_CLASSES.getPreferredName(), numClasses);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TreeSizeInfo treeSizeInfo = (TreeSizeInfo) o;
        return numNodes == treeSizeInfo.numNodes &&
            numLeaves == treeSizeInfo.numLeaves &&
            numClasses == treeSizeInfo.numClasses;
    }

    @Override
    public int hashCode() {
        return Objects.hash(numNodes, numLeaves, numClasses);
    }
}

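For intuition, a hypothetical same-package sketch (the constructor is package-private): a balanced binary tree with 8 leaves has 7 inner nodes, and for a 3-class model each leaf carries a 3-element value array.

// Project the heap cost of a small tree without building any TreeNode objects.
TreeSizeInfo treeSize = new TreeSizeInfo(8, 7, 0); // numLeaves, numNodes, numClasses
long bytes = treeSize.setNumClasses(3).ramBytesUsed();
// = alignObjectSize(TreeInferenceModel.SHALLOW_SIZE
//       + NUM_BYTES_ARRAY_HEADER + 15 * NUM_BYTES_OBJECT_REF   // Node[] of 15 refs
//       + 8 * TreeInferenceModel.LeafNode.SHALLOW_SIZE
//       + 7 * TreeInferenceModel.InnerNode.SHALLOW_SIZE
//       + 8 * sizeOfDoubleArray(3))                            // per-leaf values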
@@ -59,7 +59,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
     private static final String CONFIG_ID = "config-id";
     private static final int NUM_ROWS = 100;
     private static final int NUM_COLS = 4;
-    private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null);
+    private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null, null);
 
     private Client client;
     private DataFrameAnalyticsAuditor auditor;

@@ -106,8 +106,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
     public void testProcess_GivenEmptyResults() {
         givenDataFrameRows(2);
         givenProcessResults(Arrays.asList(
-            new AnalyticsResult(null, null, null, null, null, null, null),
-            new AnalyticsResult(null, null, null, null, null, null, null)));
+            new AnalyticsResult(null, null, null, null, null, null, null, null),
+            new AnalyticsResult(null, null, null, null, null, null, null, null)));
         AnalyticsResultProcessor resultProcessor = createResultProcessor();
 
         resultProcessor.process(process);
@@ -122,8 +122,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
         givenDataFrameRows(2);
         RowResults rowResults1 = mock(RowResults.class);
         RowResults rowResults2 = mock(RowResults.class);
-        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null, null, null, null, null),
-            new AnalyticsResult(rowResults2, null, null, null, null, null, null)));
+        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null, null, null, null, null, null),
+            new AnalyticsResult(rowResults2, null, null, null, null, null, null, null)));
         AnalyticsResultProcessor resultProcessor = createResultProcessor();
 
         resultProcessor.process(process);
@@ -140,8 +140,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
         givenDataFrameRows(2);
         RowResults rowResults1 = mock(RowResults.class);
         RowResults rowResults2 = mock(RowResults.class);
-        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null, null, null, null, null),
-            new AnalyticsResult(rowResults2, null, null, null, null, null, null)));
+        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null, null, null, null, null, null),
+            new AnalyticsResult(rowResults2, null, null, null, null, null, null, null)));
 
         doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class));
 
@@ -175,7 +175,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
         extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
         TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
         TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
-        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
+        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null)));
         AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);
 
         resultProcessor.process(process);
@@ -239,7 +239,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
 
         TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
         TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
-        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
+        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null)));
         AnalyticsResultProcessor resultProcessor = createResultProcessor();
 
         resultProcessor.process(process);

@@ -24,6 +24,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;

import java.util.ArrayList;
import java.util.Collections;
@@ -36,10 +39,10 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResult> {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
        namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
        return new NamedXContentRegistry(namedXContent);
    }

    @Override
    protected AnalyticsResult createTestInstance() {
        RowResults rowResults = null;
        PhaseProgress phaseProgress = null;
@@ -48,6 +51,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResult> {
        OutlierDetectionStats outlierDetectionStats = null;
        ClassificationStats classificationStats = null;
        RegressionStats regressionStats = null;
        ModelSizeInfo modelSizeInfo = null;
        if (randomBoolean()) {
            rowResults = RowResultsTests.createRandom();
        }
@@ -69,8 +73,11 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResult> {
        if (randomBoolean()) {
            regressionStats = RegressionStatsTests.createRandom();
        }
        if (randomBoolean()) {
            modelSizeInfo = ModelSizeInfoTests.createRandom();
        }
        return new AnalyticsResult(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats,
            classificationStats, regressionStats);
            classificationStats, regressionStats, modelSizeInfo);
    }

    @Override
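Registering `MlModelSizeNamedXContentProvider` alongside the existing providers is what lets the polymorphic `preprocessors` entries (`one_hot_encoding`, `frequency_encoding`, `target_mean_encoding`) be resolved by name when a result is parsed. A minimal sketch of assembling such a registry, using only calls that appear in this diff:

```java
List<NamedXContentRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
entries.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
// Without the second provider, a result carrying model size info would fail to
// parse, since the size-object names would be unknown to the registry.
NamedXContentRegistry registry = new NamedXContentRegistry(entries);
```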
@@ -0,0 +1,88 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModelTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class EnsembleSizeInfoTests extends SizeEstimatorTestCase<EnsembleSizeInfo, EnsembleInferenceModel> {

    static EnsembleSizeInfo createRandom() {
        return new EnsembleSizeInfo(
            Stream.generate(TreeSizeInfoTests::createRandom).limit(randomIntBetween(1, 100)).collect(Collectors.toList()),
            randomIntBetween(1, 10000),
            Stream.generate(() -> randomIntBetween(1, 10)).limit(randomIntBetween(1, 10)).collect(Collectors.toList()),
            randomIntBetween(0, 10),
            randomIntBetween(0, 10),
            randomIntBetween(0, 10)
        );
    }

    static EnsembleSizeInfo translateToEstimate(EnsembleInferenceModel ensemble) {
        TreeInferenceModel tree = (TreeInferenceModel)ensemble.getModels().get(0);
        int numClasses = Arrays.stream(tree.getNodes())
            .filter(TreeInferenceModel.Node::isLeaf)
            .map(n -> (TreeInferenceModel.LeafNode)n)
            .findFirst()
            .get()
            .getLeafValue()
            .length;
        return new EnsembleSizeInfo(
            ensemble.getModels()
                .stream()
                .map(m -> TreeSizeInfoTests.translateToEstimate((TreeInferenceModel)m))
                .collect(Collectors.toList()),
            randomIntBetween(0, 10),
            Arrays.stream(ensemble.getFeatureNames()).map(String::length).collect(Collectors.toList()),
            ensemble.getOutputAggregator().expectedValueSize() == null ? 0 : ensemble.getOutputAggregator().expectedValueSize(),
            ensemble.getClassificationWeights() == null ? 0 : ensemble.getClassificationWeights().length,
            numClasses);
    }

    @Override
    protected EnsembleSizeInfo createTestInstance() {
        return createRandom();
    }

    @Override
    protected EnsembleSizeInfo doParseInstance(XContentParser parser) {
        return EnsembleSizeInfo.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return false;
    }

    @Override
    EnsembleInferenceModel generateTrueObject() {
        try {
            Ensemble model = EnsembleTests.createRandom();
            EnsembleInferenceModel inferenceModel = EnsembleInferenceModelTests.serializeFromTrainedModel(model);
            inferenceModel.rewriteFeatureIndices(Collections.emptyMap());
            return inferenceModel;
        } catch (IOException ex) {
            throw new ElasticsearchException(ex);
        }
    }

    @Override
    EnsembleSizeInfo translateObject(EnsembleInferenceModel originalObject) {
        return translateToEstimate(originalObject);
    }
}
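`translateToEstimate` derives `numClasses` from the first leaf it finds, relying on the convention that a leaf stores one output value per class, so a regression leaf carries a single value. A small illustration with made-up values:

```java
double[] classificationLeaf = {0.1, 0.7, 0.2}; // hypothetical three-class leaf
double[] regressionLeaf = {4.2};               // regression trees carry a single value
int numClasses = classificationLeaf.length;    // 3
int numClassesForRegression = regressionLeaf.length; // 1
```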
@@ -0,0 +1,56 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;

import java.util.stream.Collectors;
import java.util.stream.Stream;

public class FrequencyEncodingSizeTests extends SizeEstimatorTestCase<FrequencyEncodingSize, FrequencyEncoding> {

    static FrequencyEncodingSize createRandom() {
        return new FrequencyEncodingSize(randomInt(100),
            randomInt(100),
            Stream.generate(() -> randomIntBetween(5, 10))
                .limit(randomIntBetween(1, 10))
                .collect(Collectors.toList()));
    }

    static FrequencyEncodingSize translateToEstimate(FrequencyEncoding encoding) {
        return new FrequencyEncodingSize(encoding.getField().length(),
            encoding.getFeatureName().length(),
            encoding.getFrequencyMap().keySet().stream().map(String::length).collect(Collectors.toList()));
    }

    @Override
    protected FrequencyEncodingSize createTestInstance() {
        return createRandom();
    }

    @Override
    protected FrequencyEncodingSize doParseInstance(XContentParser parser) {
        return FrequencyEncodingSize.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return false;
    }

    @Override
    FrequencyEncoding generateTrueObject() {
        return FrequencyEncodingTests.createRandom();
    }

    @Override
    FrequencyEncodingSize translateObject(FrequencyEncoding originalObject) {
        return translateToEstimate(originalObject);
    }
}
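Only string lengths cross into the size object; the field values themselves are never copied, which is what keeps these descriptions small on heap. A hedged sketch of the translation, assuming a `FrequencyEncoding(field, featureName, frequencyMap)` constructor (the constructor is not shown in this diff):

```java
Map<String, Double> frequencyMap = new HashMap<>();
frequencyMap.put("AAL", 0.7);
frequencyMap.put("JAL", 0.3);
FrequencyEncoding encoding = new FrequencyEncoding("airline", "airline_freq", frequencyMap);
// field_length = 7, feature_name_length = 12, field_value_lengths = [3, 3]
FrequencyEncodingSize size = FrequencyEncodingSizeTests.translateToEstimate(encoding);
```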
@@ -0,0 +1,124 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ModelSizeInfoTests extends AbstractXContentTestCase<ModelSizeInfo> {

    public static ModelSizeInfo createRandom() {
        return new ModelSizeInfo(EnsembleSizeInfoTests.createRandom(),
            randomBoolean() ?
                null :
                Stream.generate(() -> randomFrom(
                    FrequencyEncodingSizeTests.createRandom(),
                    OneHotEncodingSizeTests.createRandom(),
                    TargetMeanEncodingSizeTests.createRandom()))
                    .limit(randomIntBetween(1, 10))
                    .collect(Collectors.toList()));
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
        return new NamedXContentRegistry(namedXContent);
    }

    @Override
    protected ModelSizeInfo createTestInstance() {
        return createRandom();
    }

    @Override
    protected ModelSizeInfo doParseInstance(XContentParser parser) {
        return ModelSizeInfo.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return false;
    }

    public void testParseDescribedFormat() throws IOException {
        XContentParser parser = XContentHelper.createParser(xContentRegistry(),
            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
            new BytesArray(FORMAT),
            XContentType.JSON);
        // Shouldn't throw
        doParseInstance(parser);
    }

    private static final String FORMAT = "" +
        "{\n" +
        "  \"trained_model_size\": {\n" +
        "    \"ensemble_model_size\": {\n" +
        "      \"tree_sizes\": [\n" +
        "        {\"num_nodes\": 7, \"num_leaves\": 8},\n" +
        "        {\"num_nodes\": 3, \"num_leaves\": 4},\n" +
        "        {\"num_leaves\": 1}\n" +
        "      ],\n" +
        "      \"input_field_name_lengths\": [\n" +
        "        14,\n" +
        "        10,\n" +
        "        11\n" +
        "      ],\n" +
        "      \"num_output_processor_weights\": 3,\n" +
        "      \"num_classification_weights\": 0,\n" +
        "      \"num_classes\": 0,\n" +
        "      \"num_operations\": 3\n" +
        "    }\n" +
        "  },\n" +
        "  \"preprocessors\": [\n" +
        "    {\n" +
        "      \"one_hot_encoding\": {\n" +
        "        \"field_length\": 10,\n" +
        "        \"field_value_lengths\": [\n" +
        "          10,\n" +
        "          20\n" +
        "        ],\n" +
        "        \"feature_name_lengths\": [\n" +
        "          15,\n" +
        "          25\n" +
        "        ]\n" +
        "      }\n" +
        "    },\n" +
        "    {\n" +
        "      \"frequency_encoding\": {\n" +
        "        \"field_length\": 10,\n" +
        "        \"feature_name_length\": 5,\n" +
        "        \"field_value_lengths\": [\n" +
        "          10,\n" +
        "          20\n" +
        "        ]\n" +
        "      }\n" +
        "    },\n" +
        "    {\n" +
        "      \"target_mean_encoding\": {\n" +
        "        \"field_length\": 6,\n" +
        "        \"feature_name_length\": 15,\n" +
        "        \"field_value_lengths\": [\n" +
        "          10,\n" +
        "          20\n" +
        "        ]\n" +
        "      }\n" +
        "    }\n" +
        "  ]\n" +
        "} ";
}
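`FORMAT` documents the JSON shape the analytics process writes. A minimal sketch of parsing such a payload outside the test harness, following `testParseDescribedFormat`, where `registry` and `json` stand in for the registry built above and the raw payload (the `ramBytesUsed()` call assumes `ModelSizeInfo` is `Accountable`, which this diff does not show directly):

```java
try (XContentParser parser = XContentHelper.createParser(registry,
        DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
        new BytesArray(json),
        XContentType.JSON)) {
    ModelSizeInfo info = ModelSizeInfo.PARSER.apply(parser, null);
    long estimatedHeapBytes = info.ramBytesUsed(); // assumption: Accountable
}
```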
@@ -0,0 +1,60 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;

import java.util.stream.Collectors;
import java.util.stream.Stream;

public class OneHotEncodingSizeTests extends SizeEstimatorTestCase<OneHotEncodingSize, OneHotEncoding> {

    static OneHotEncodingSize createRandom() {
        int numFieldEntries = randomIntBetween(1, 10);
        return new OneHotEncodingSize(
            randomInt(100),
            Stream.generate(() -> randomIntBetween(5, 10))
                .limit(numFieldEntries)
                .collect(Collectors.toList()),
            Stream.generate(() -> randomIntBetween(5, 10))
                .limit(numFieldEntries)
                .collect(Collectors.toList()));
    }

    static OneHotEncodingSize translateToEstimate(OneHotEncoding encoding) {
        return new OneHotEncodingSize(encoding.getField().length(),
            encoding.getHotMap().values().stream().map(String::length).collect(Collectors.toList()),
            encoding.getHotMap().keySet().stream().map(String::length).collect(Collectors.toList()));
    }

    @Override
    protected OneHotEncodingSize createTestInstance() {
        return createRandom();
    }

    @Override
    protected OneHotEncodingSize doParseInstance(XContentParser parser) {
        return OneHotEncodingSize.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return false;
    }

    @Override
    OneHotEncoding generateTrueObject() {
        return OneHotEncodingTests.createRandom();
    }

    @Override
    OneHotEncodingSize translateObject(OneHotEncoding originalObject) {
        return translateToEstimate(originalObject);
    }
}
@@ -0,0 +1,48 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.test.AbstractXContentTestCase;

import static org.hamcrest.Matchers.is;

public abstract class SizeEstimatorTestCase<T extends ToXContentObject & Accountable, U extends Accountable>
    extends AbstractXContentTestCase<T> {

    abstract U generateTrueObject();

    abstract T translateObject(U originalObject);

    public void testRamUsageEstimationAccuracy() {
        final long bytesEps = new ByteSizeValue(2, ByteSizeUnit.KB).getBytes();
        for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
            U obj = generateTrueObject();
            T estimateObj = translateObject(obj);
            long originalBytesUsed = obj.ramBytesUsed();
            long estimateBytesUsed = estimateObj.ramBytesUsed();
            // If we are over by 2kb that is small enough to not be a concern
            boolean condition = (Math.abs(originalBytesUsed - estimateBytesUsed) < bytesEps) ||
                // If the difference is greater than 2kb, it is better to have overestimated.
                originalBytesUsed < estimateBytesUsed;
            assertThat("estimation difference greater than 2048 and the estimation is too small. Object ["
                    + obj.toString()
                    + "] estimated ["
                    + originalBytesUsed
                    + "] translated object ["
                    + estimateObj
                    + "] estimated ["
                    + estimateBytesUsed
                    + "]",
                condition,
                is(true));
        }
    }
}
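The tolerance rule reads: an estimate passes if it lands within 2 KB of the true footprint, or if it errs on the high side. Restated with concrete, made-up numbers:

```java
long eps = new ByteSizeValue(2, ByteSizeUnit.KB).getBytes(); // 2048 bytes
long trueBytes = 10_240;  // hypothetical ramBytesUsed() of the real object
long estimate = 13_000;   // hypothetical ramBytesUsed() of the size-info object
// |10_240 - 13_000| = 2_760 > 2_048, but trueBytes < estimate, so it still passes:
boolean ok = Math.abs(trueBytes - estimate) < eps || trueBytes < estimate;
```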
@@ -0,0 +1,56 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;

import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TargetMeanEncodingSizeTests extends SizeEstimatorTestCase<TargetMeanEncodingSize, TargetMeanEncoding> {

    static TargetMeanEncodingSize createRandom() {
        return new TargetMeanEncodingSize(randomInt(100),
            randomInt(100),
            Stream.generate(() -> randomIntBetween(5, 10))
                .limit(randomIntBetween(1, 10))
                .collect(Collectors.toList()));
    }

    static TargetMeanEncodingSize translateToEstimate(TargetMeanEncoding encoding) {
        return new TargetMeanEncodingSize(encoding.getField().length(),
            encoding.getFeatureName().length(),
            encoding.getMeanMap().keySet().stream().map(String::length).collect(Collectors.toList()));
    }

    @Override
    protected TargetMeanEncodingSize createTestInstance() {
        return createRandom();
    }

    @Override
    protected TargetMeanEncodingSize doParseInstance(XContentParser parser) {
        return TargetMeanEncodingSize.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return false;
    }

    @Override
    TargetMeanEncoding generateTrueObject() {
        return TargetMeanEncodingTests.createRandom();
    }

    @Override
    TargetMeanEncodingSize translateObject(TargetMeanEncoding originalObject) {
        return translateToEstimate(originalObject);
    }
}
@@ -0,0 +1,68 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.inference.modelsize;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModelTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;

import java.io.IOException;
import java.util.Arrays;


public class TreeSizeInfoTests extends SizeEstimatorTestCase<TreeSizeInfo, TreeInferenceModel> {

    static TreeSizeInfo createRandom() {
        return new TreeSizeInfo(randomIntBetween(1, 100), randomIntBetween(0, 100), randomIntBetween(0, 10));
    }

    static TreeSizeInfo translateToEstimate(TreeInferenceModel tree) {
        int numClasses = Arrays.stream(tree.getNodes())
            .filter(TreeInferenceModel.Node::isLeaf)
            .map(n -> (TreeInferenceModel.LeafNode)n)
            .findFirst()
            .get()
            .getLeafValue()
            .length;
        return new TreeSizeInfo((int)Arrays.stream(tree.getNodes()).filter(TreeInferenceModel.Node::isLeaf).count(),
            (int)Arrays.stream(tree.getNodes()).filter(t -> t.isLeaf() == false).count(),
            numClasses);
    }

    @Override
    protected TreeSizeInfo createTestInstance() {
        return createRandom();
    }

    @Override
    protected TreeSizeInfo doParseInstance(XContentParser parser) {
        return TreeSizeInfo.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return false;
    }

    @Override
    TreeInferenceModel generateTrueObject() {
        try {
            return TreeInferenceModelTests.serializeFromTrainedModel(
                TreeTests.buildRandomTree(Arrays.asList(randomAlphaOfLength(10), randomAlphaOfLength(10)), 6)
            );
        } catch (IOException ex) {
            throw new ElasticsearchException(ex);
        }
    }

    @Override
    TreeSizeInfo translateObject(TreeInferenceModel originalObject) {
        return translateToEstimate(originalObject);
    }
}
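The leaf and internal-node counts computed here line up with the sample payload in `ModelSizeInfoTests`, where `num_nodes` appears to count internal nodes: a full binary tree with n internal nodes has n + 1 leaves. A tiny check of that identity against the sample:

```java
int numInternal = 7;             // "num_nodes": 7 in the sample tree_sizes entry
int numLeaves = numInternal + 1; // "num_leaves": 8 for a full binary tree
```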