This commit is contained in:
parent
41af7f5455
commit
2de242f80e
|
@ -29,7 +29,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
||||||
|
|
||||||
public static final ParseField NAME = new ParseField("ensemble_model_size");
|
public static final ParseField NAME = new ParseField("ensemble_model_size");
|
||||||
private static final ParseField TREE_SIZES = new ParseField("tree_sizes");
|
private static final ParseField TREE_SIZES = new ParseField("tree_sizes");
|
||||||
private static final ParseField INPUT_FIELD_NAME_LENGHTS = new ParseField("input_field_name_lengths");
|
private static final ParseField FEATURE_NAME_LENGTHS = new ParseField("feature_name_lengths");
|
||||||
private static final ParseField NUM_OUTPUT_PROCESSOR_WEIGHTS = new ParseField("num_output_processor_weights");
|
private static final ParseField NUM_OUTPUT_PROCESSOR_WEIGHTS = new ParseField("num_output_processor_weights");
|
||||||
private static final ParseField NUM_CLASSIFICATION_WEIGHTS = new ParseField("num_classification_weights");
|
private static final ParseField NUM_CLASSIFICATION_WEIGHTS = new ParseField("num_classification_weights");
|
||||||
private static final ParseField NUM_OPERATIONS = new ParseField("num_operations");
|
private static final ParseField NUM_OPERATIONS = new ParseField("num_operations");
|
||||||
|
@ -49,7 +49,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
||||||
static {
|
static {
|
||||||
PARSER.declareObjectArray(constructorArg(), TreeSizeInfo.PARSER::apply, TREE_SIZES);
|
PARSER.declareObjectArray(constructorArg(), TreeSizeInfo.PARSER::apply, TREE_SIZES);
|
||||||
PARSER.declareInt(constructorArg(), NUM_OPERATIONS);
|
PARSER.declareInt(constructorArg(), NUM_OPERATIONS);
|
||||||
PARSER.declareIntArray(constructorArg(), INPUT_FIELD_NAME_LENGHTS);
|
PARSER.declareIntArray(constructorArg(), FEATURE_NAME_LENGTHS);
|
||||||
PARSER.declareInt(optionalConstructorArg(), NUM_OUTPUT_PROCESSOR_WEIGHTS);
|
PARSER.declareInt(optionalConstructorArg(), NUM_OUTPUT_PROCESSOR_WEIGHTS);
|
||||||
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSIFICATION_WEIGHTS);
|
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSIFICATION_WEIGHTS);
|
||||||
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES);
|
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES);
|
||||||
|
@ -62,20 +62,20 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
||||||
|
|
||||||
private final List<TreeSizeInfo> treeSizeInfos;
|
private final List<TreeSizeInfo> treeSizeInfos;
|
||||||
private final int numOperations;
|
private final int numOperations;
|
||||||
private final int[] inputFieldNameLengths;
|
private final int[] featureNameLengths;
|
||||||
private final int numOutputProcessorWeights;
|
private final int numOutputProcessorWeights;
|
||||||
private final int numClassificationWeights;
|
private final int numClassificationWeights;
|
||||||
private final int numClasses;
|
private final int numClasses;
|
||||||
|
|
||||||
public EnsembleSizeInfo(List<TreeSizeInfo> treeSizeInfos,
|
public EnsembleSizeInfo(List<TreeSizeInfo> treeSizeInfos,
|
||||||
int numOperations,
|
int numOperations,
|
||||||
List<Integer> inputFieldNameLengths,
|
List<Integer> featureNameLengths,
|
||||||
int numOutputProcessorWeights,
|
int numOutputProcessorWeights,
|
||||||
int numClassificationWeights,
|
int numClassificationWeights,
|
||||||
int numClasses) {
|
int numClasses) {
|
||||||
this.treeSizeInfos = treeSizeInfos;
|
this.treeSizeInfos = treeSizeInfos;
|
||||||
this.numOperations = numOperations;
|
this.numOperations = numOperations;
|
||||||
this.inputFieldNameLengths = inputFieldNameLengths.stream().mapToInt(Integer::intValue).toArray();
|
this.featureNameLengths = featureNameLengths.stream().mapToInt(Integer::intValue).toArray();
|
||||||
this.numOutputProcessorWeights = numOutputProcessorWeights;
|
this.numOutputProcessorWeights = numOutputProcessorWeights;
|
||||||
this.numClassificationWeights = numClassificationWeights;
|
this.numClassificationWeights = numClassificationWeights;
|
||||||
this.numClasses = numClasses;
|
this.numClasses = numClasses;
|
||||||
|
@ -90,7 +90,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
||||||
long size = EnsembleInferenceModel.SHALLOW_SIZE;
|
long size = EnsembleInferenceModel.SHALLOW_SIZE;
|
||||||
treeSizeInfos.forEach(t -> t.setNumClasses(numClasses).ramBytesUsed());
|
treeSizeInfos.forEach(t -> t.setNumClasses(numClasses).ramBytesUsed());
|
||||||
size += sizeOfCollection(treeSizeInfos);
|
size += sizeOfCollection(treeSizeInfos);
|
||||||
size += sizeOfStringCollection(inputFieldNameLengths);
|
size += sizeOfStringCollection(featureNameLengths);
|
||||||
size += LogisticRegression.SHALLOW_SIZE + sizeOfDoubleArray(numOutputProcessorWeights);
|
size += LogisticRegression.SHALLOW_SIZE + sizeOfDoubleArray(numOutputProcessorWeights);
|
||||||
size += sizeOfDoubleArray(numClassificationWeights);
|
size += sizeOfDoubleArray(numClassificationWeights);
|
||||||
return alignObjectSize(size);
|
return alignObjectSize(size);
|
||||||
|
@ -102,7 +102,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
||||||
builder.field(TREE_SIZES.getPreferredName(), treeSizeInfos);
|
builder.field(TREE_SIZES.getPreferredName(), treeSizeInfos);
|
||||||
builder.field(NUM_OPERATIONS.getPreferredName(), numOperations);
|
builder.field(NUM_OPERATIONS.getPreferredName(), numOperations);
|
||||||
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
|
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
|
||||||
builder.field(INPUT_FIELD_NAME_LENGHTS.getPreferredName(), inputFieldNameLengths);
|
builder.field(FEATURE_NAME_LENGTHS.getPreferredName(), featureNameLengths);
|
||||||
builder.field(NUM_CLASSIFICATION_WEIGHTS.getPreferredName(), numClassificationWeights);
|
builder.field(NUM_CLASSIFICATION_WEIGHTS.getPreferredName(), numClassificationWeights);
|
||||||
builder.field(NUM_OUTPUT_PROCESSOR_WEIGHTS.getPreferredName(), numOutputProcessorWeights);
|
builder.field(NUM_OUTPUT_PROCESSOR_WEIGHTS.getPreferredName(), numOutputProcessorWeights);
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
|
@ -119,13 +119,13 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
||||||
numClassificationWeights == that.numClassificationWeights &&
|
numClassificationWeights == that.numClassificationWeights &&
|
||||||
numClasses == that.numClasses &&
|
numClasses == that.numClasses &&
|
||||||
Objects.equals(treeSizeInfos, that.treeSizeInfos) &&
|
Objects.equals(treeSizeInfos, that.treeSizeInfos) &&
|
||||||
Arrays.equals(inputFieldNameLengths, that.inputFieldNameLengths);
|
Arrays.equals(featureNameLengths, that.featureNameLengths);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
int result = Objects.hash(treeSizeInfos, numOperations, numOutputProcessorWeights, numClassificationWeights, numClasses);
|
int result = Objects.hash(treeSizeInfos, numOperations, numOutputProcessorWeights, numClassificationWeights, numClasses);
|
||||||
result = 31 * result + Arrays.hashCode(inputFieldNameLengths);
|
result = 31 * result + Arrays.hashCode(featureNameLengths);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -74,7 +74,7 @@ public class ModelSizeInfoTests extends AbstractXContentTestCase<ModelSizeInfo>
|
||||||
" {\"num_nodes\": 3, \"num_leaves\": 4},\n" +
|
" {\"num_nodes\": 3, \"num_leaves\": 4},\n" +
|
||||||
" {\"num_leaves\": 1}\n" +
|
" {\"num_leaves\": 1}\n" +
|
||||||
" ],\n" +
|
" ],\n" +
|
||||||
" \"input_field_name_lengths\": [\n" +
|
" \"feature_name_lengths\": [\n" +
|
||||||
" 14,\n" +
|
" 14,\n" +
|
||||||
" 10,\n" +
|
" 10,\n" +
|
||||||
" 11\n" +
|
" 11\n" +
|
||||||
|
|
Loading…
Reference in New Issue