[ML] rename EnsembleSizeInfo#inputFieldNameLengths to this.featureNameLengths (#58241) (#58253)

This commit is contained in:
Benjamin Trent 2020-06-17 10:08:55 -04:00 committed by GitHub
parent 41af7f5455
commit 2de242f80e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 10 deletions

View File

@ -29,7 +29,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
public static final ParseField NAME = new ParseField("ensemble_model_size");
private static final ParseField TREE_SIZES = new ParseField("tree_sizes");
private static final ParseField INPUT_FIELD_NAME_LENGHTS = new ParseField("input_field_name_lengths");
private static final ParseField FEATURE_NAME_LENGTHS = new ParseField("feature_name_lengths");
private static final ParseField NUM_OUTPUT_PROCESSOR_WEIGHTS = new ParseField("num_output_processor_weights");
private static final ParseField NUM_CLASSIFICATION_WEIGHTS = new ParseField("num_classification_weights");
private static final ParseField NUM_OPERATIONS = new ParseField("num_operations");
@ -49,7 +49,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
static {
PARSER.declareObjectArray(constructorArg(), TreeSizeInfo.PARSER::apply, TREE_SIZES);
PARSER.declareInt(constructorArg(), NUM_OPERATIONS);
PARSER.declareIntArray(constructorArg(), INPUT_FIELD_NAME_LENGHTS);
PARSER.declareIntArray(constructorArg(), FEATURE_NAME_LENGTHS);
PARSER.declareInt(optionalConstructorArg(), NUM_OUTPUT_PROCESSOR_WEIGHTS);
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSIFICATION_WEIGHTS);
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES);
@ -62,20 +62,20 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
private final List<TreeSizeInfo> treeSizeInfos;
private final int numOperations;
private final int[] inputFieldNameLengths;
private final int[] featureNameLengths;
private final int numOutputProcessorWeights;
private final int numClassificationWeights;
private final int numClasses;
public EnsembleSizeInfo(List<TreeSizeInfo> treeSizeInfos,
int numOperations,
List<Integer> inputFieldNameLengths,
List<Integer> featureNameLengths,
int numOutputProcessorWeights,
int numClassificationWeights,
int numClasses) {
this.treeSizeInfos = treeSizeInfos;
this.numOperations = numOperations;
this.inputFieldNameLengths = inputFieldNameLengths.stream().mapToInt(Integer::intValue).toArray();
this.featureNameLengths = featureNameLengths.stream().mapToInt(Integer::intValue).toArray();
this.numOutputProcessorWeights = numOutputProcessorWeights;
this.numClassificationWeights = numClassificationWeights;
this.numClasses = numClasses;
@ -90,7 +90,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
long size = EnsembleInferenceModel.SHALLOW_SIZE;
treeSizeInfos.forEach(t -> t.setNumClasses(numClasses).ramBytesUsed());
size += sizeOfCollection(treeSizeInfos);
size += sizeOfStringCollection(inputFieldNameLengths);
size += sizeOfStringCollection(featureNameLengths);
size += LogisticRegression.SHALLOW_SIZE + sizeOfDoubleArray(numOutputProcessorWeights);
size += sizeOfDoubleArray(numClassificationWeights);
return alignObjectSize(size);
@ -102,7 +102,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
builder.field(TREE_SIZES.getPreferredName(), treeSizeInfos);
builder.field(NUM_OPERATIONS.getPreferredName(), numOperations);
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
builder.field(INPUT_FIELD_NAME_LENGHTS.getPreferredName(), inputFieldNameLengths);
builder.field(FEATURE_NAME_LENGTHS.getPreferredName(), featureNameLengths);
builder.field(NUM_CLASSIFICATION_WEIGHTS.getPreferredName(), numClassificationWeights);
builder.field(NUM_OUTPUT_PROCESSOR_WEIGHTS.getPreferredName(), numOutputProcessorWeights);
builder.endObject();
@ -119,13 +119,13 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
numClassificationWeights == that.numClassificationWeights &&
numClasses == that.numClasses &&
Objects.equals(treeSizeInfos, that.treeSizeInfos) &&
Arrays.equals(inputFieldNameLengths, that.inputFieldNameLengths);
Arrays.equals(featureNameLengths, that.featureNameLengths);
}
@Override
public int hashCode() {
int result = Objects.hash(treeSizeInfos, numOperations, numOutputProcessorWeights, numClassificationWeights, numClasses);
result = 31 * result + Arrays.hashCode(inputFieldNameLengths);
result = 31 * result + Arrays.hashCode(featureNameLengths);
return result;
}

View File

@ -74,7 +74,7 @@ public class ModelSizeInfoTests extends AbstractXContentTestCase<ModelSizeInfo>
" {\"num_nodes\": 3, \"num_leaves\": 4},\n" +
" {\"num_leaves\": 1}\n" +
" ],\n" +
" \"input_field_name_lengths\": [\n" +
" \"feature_name_lengths\": [\n" +
" 14,\n" +
" 10,\n" +
" 11\n" +