This commit is contained in:
parent
41af7f5455
commit
2de242f80e
|
@ -29,7 +29,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
|||
|
||||
public static final ParseField NAME = new ParseField("ensemble_model_size");
|
||||
private static final ParseField TREE_SIZES = new ParseField("tree_sizes");
|
||||
private static final ParseField INPUT_FIELD_NAME_LENGHTS = new ParseField("input_field_name_lengths");
|
||||
private static final ParseField FEATURE_NAME_LENGTHS = new ParseField("feature_name_lengths");
|
||||
private static final ParseField NUM_OUTPUT_PROCESSOR_WEIGHTS = new ParseField("num_output_processor_weights");
|
||||
private static final ParseField NUM_CLASSIFICATION_WEIGHTS = new ParseField("num_classification_weights");
|
||||
private static final ParseField NUM_OPERATIONS = new ParseField("num_operations");
|
||||
|
@ -49,7 +49,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
|||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), TreeSizeInfo.PARSER::apply, TREE_SIZES);
|
||||
PARSER.declareInt(constructorArg(), NUM_OPERATIONS);
|
||||
PARSER.declareIntArray(constructorArg(), INPUT_FIELD_NAME_LENGHTS);
|
||||
PARSER.declareIntArray(constructorArg(), FEATURE_NAME_LENGTHS);
|
||||
PARSER.declareInt(optionalConstructorArg(), NUM_OUTPUT_PROCESSOR_WEIGHTS);
|
||||
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSIFICATION_WEIGHTS);
|
||||
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES);
|
||||
|
@ -62,20 +62,20 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
|||
|
||||
private final List<TreeSizeInfo> treeSizeInfos;
|
||||
private final int numOperations;
|
||||
private final int[] inputFieldNameLengths;
|
||||
private final int[] featureNameLengths;
|
||||
private final int numOutputProcessorWeights;
|
||||
private final int numClassificationWeights;
|
||||
private final int numClasses;
|
||||
|
||||
public EnsembleSizeInfo(List<TreeSizeInfo> treeSizeInfos,
|
||||
int numOperations,
|
||||
List<Integer> inputFieldNameLengths,
|
||||
List<Integer> featureNameLengths,
|
||||
int numOutputProcessorWeights,
|
||||
int numClassificationWeights,
|
||||
int numClasses) {
|
||||
this.treeSizeInfos = treeSizeInfos;
|
||||
this.numOperations = numOperations;
|
||||
this.inputFieldNameLengths = inputFieldNameLengths.stream().mapToInt(Integer::intValue).toArray();
|
||||
this.featureNameLengths = featureNameLengths.stream().mapToInt(Integer::intValue).toArray();
|
||||
this.numOutputProcessorWeights = numOutputProcessorWeights;
|
||||
this.numClassificationWeights = numClassificationWeights;
|
||||
this.numClasses = numClasses;
|
||||
|
@ -90,7 +90,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
|||
long size = EnsembleInferenceModel.SHALLOW_SIZE;
|
||||
treeSizeInfos.forEach(t -> t.setNumClasses(numClasses).ramBytesUsed());
|
||||
size += sizeOfCollection(treeSizeInfos);
|
||||
size += sizeOfStringCollection(inputFieldNameLengths);
|
||||
size += sizeOfStringCollection(featureNameLengths);
|
||||
size += LogisticRegression.SHALLOW_SIZE + sizeOfDoubleArray(numOutputProcessorWeights);
|
||||
size += sizeOfDoubleArray(numClassificationWeights);
|
||||
return alignObjectSize(size);
|
||||
|
@ -102,7 +102,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
|||
builder.field(TREE_SIZES.getPreferredName(), treeSizeInfos);
|
||||
builder.field(NUM_OPERATIONS.getPreferredName(), numOperations);
|
||||
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
|
||||
builder.field(INPUT_FIELD_NAME_LENGHTS.getPreferredName(), inputFieldNameLengths);
|
||||
builder.field(FEATURE_NAME_LENGTHS.getPreferredName(), featureNameLengths);
|
||||
builder.field(NUM_CLASSIFICATION_WEIGHTS.getPreferredName(), numClassificationWeights);
|
||||
builder.field(NUM_OUTPUT_PROCESSOR_WEIGHTS.getPreferredName(), numOutputProcessorWeights);
|
||||
builder.endObject();
|
||||
|
@ -119,13 +119,13 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
|
|||
numClassificationWeights == that.numClassificationWeights &&
|
||||
numClasses == that.numClasses &&
|
||||
Objects.equals(treeSizeInfos, that.treeSizeInfos) &&
|
||||
Arrays.equals(inputFieldNameLengths, that.inputFieldNameLengths);
|
||||
Arrays.equals(featureNameLengths, that.featureNameLengths);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = Objects.hash(treeSizeInfos, numOperations, numOutputProcessorWeights, numClassificationWeights, numClasses);
|
||||
result = 31 * result + Arrays.hashCode(inputFieldNameLengths);
|
||||
result = 31 * result + Arrays.hashCode(featureNameLengths);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ public class ModelSizeInfoTests extends AbstractXContentTestCase<ModelSizeInfo>
|
|||
" {\"num_nodes\": 3, \"num_leaves\": 4},\n" +
|
||||
" {\"num_leaves\": 1}\n" +
|
||||
" ],\n" +
|
||||
" \"input_field_name_lengths\": [\n" +
|
||||
" \"feature_name_lengths\": [\n" +
|
||||
" 14,\n" +
|
||||
" 10,\n" +
|
||||
" 11\n" +
|
||||
|
|
Loading…
Reference in New Issue