[ML] rename EnsembleSizeInfo#inputFieldNameLengths to this.featureNameLengths (#58241) (#58253)

This commit is contained in:
Benjamin Trent 2020-06-17 10:08:55 -04:00 committed by GitHub
parent 41af7f5455
commit 2de242f80e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 10 deletions

View File

@ -29,7 +29,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
public static final ParseField NAME = new ParseField("ensemble_model_size"); public static final ParseField NAME = new ParseField("ensemble_model_size");
private static final ParseField TREE_SIZES = new ParseField("tree_sizes"); private static final ParseField TREE_SIZES = new ParseField("tree_sizes");
private static final ParseField INPUT_FIELD_NAME_LENGHTS = new ParseField("input_field_name_lengths"); private static final ParseField FEATURE_NAME_LENGTHS = new ParseField("feature_name_lengths");
private static final ParseField NUM_OUTPUT_PROCESSOR_WEIGHTS = new ParseField("num_output_processor_weights"); private static final ParseField NUM_OUTPUT_PROCESSOR_WEIGHTS = new ParseField("num_output_processor_weights");
private static final ParseField NUM_CLASSIFICATION_WEIGHTS = new ParseField("num_classification_weights"); private static final ParseField NUM_CLASSIFICATION_WEIGHTS = new ParseField("num_classification_weights");
private static final ParseField NUM_OPERATIONS = new ParseField("num_operations"); private static final ParseField NUM_OPERATIONS = new ParseField("num_operations");
@ -49,7 +49,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
static { static {
PARSER.declareObjectArray(constructorArg(), TreeSizeInfo.PARSER::apply, TREE_SIZES); PARSER.declareObjectArray(constructorArg(), TreeSizeInfo.PARSER::apply, TREE_SIZES);
PARSER.declareInt(constructorArg(), NUM_OPERATIONS); PARSER.declareInt(constructorArg(), NUM_OPERATIONS);
PARSER.declareIntArray(constructorArg(), INPUT_FIELD_NAME_LENGHTS); PARSER.declareIntArray(constructorArg(), FEATURE_NAME_LENGTHS);
PARSER.declareInt(optionalConstructorArg(), NUM_OUTPUT_PROCESSOR_WEIGHTS); PARSER.declareInt(optionalConstructorArg(), NUM_OUTPUT_PROCESSOR_WEIGHTS);
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSIFICATION_WEIGHTS); PARSER.declareInt(optionalConstructorArg(), NUM_CLASSIFICATION_WEIGHTS);
PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES); PARSER.declareInt(optionalConstructorArg(), NUM_CLASSES);
@ -62,20 +62,20 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
private final List<TreeSizeInfo> treeSizeInfos; private final List<TreeSizeInfo> treeSizeInfos;
private final int numOperations; private final int numOperations;
private final int[] inputFieldNameLengths; private final int[] featureNameLengths;
private final int numOutputProcessorWeights; private final int numOutputProcessorWeights;
private final int numClassificationWeights; private final int numClassificationWeights;
private final int numClasses; private final int numClasses;
public EnsembleSizeInfo(List<TreeSizeInfo> treeSizeInfos, public EnsembleSizeInfo(List<TreeSizeInfo> treeSizeInfos,
int numOperations, int numOperations,
List<Integer> inputFieldNameLengths, List<Integer> featureNameLengths,
int numOutputProcessorWeights, int numOutputProcessorWeights,
int numClassificationWeights, int numClassificationWeights,
int numClasses) { int numClasses) {
this.treeSizeInfos = treeSizeInfos; this.treeSizeInfos = treeSizeInfos;
this.numOperations = numOperations; this.numOperations = numOperations;
this.inputFieldNameLengths = inputFieldNameLengths.stream().mapToInt(Integer::intValue).toArray(); this.featureNameLengths = featureNameLengths.stream().mapToInt(Integer::intValue).toArray();
this.numOutputProcessorWeights = numOutputProcessorWeights; this.numOutputProcessorWeights = numOutputProcessorWeights;
this.numClassificationWeights = numClassificationWeights; this.numClassificationWeights = numClassificationWeights;
this.numClasses = numClasses; this.numClasses = numClasses;
@ -90,7 +90,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
long size = EnsembleInferenceModel.SHALLOW_SIZE; long size = EnsembleInferenceModel.SHALLOW_SIZE;
treeSizeInfos.forEach(t -> t.setNumClasses(numClasses).ramBytesUsed()); treeSizeInfos.forEach(t -> t.setNumClasses(numClasses).ramBytesUsed());
size += sizeOfCollection(treeSizeInfos); size += sizeOfCollection(treeSizeInfos);
size += sizeOfStringCollection(inputFieldNameLengths); size += sizeOfStringCollection(featureNameLengths);
size += LogisticRegression.SHALLOW_SIZE + sizeOfDoubleArray(numOutputProcessorWeights); size += LogisticRegression.SHALLOW_SIZE + sizeOfDoubleArray(numOutputProcessorWeights);
size += sizeOfDoubleArray(numClassificationWeights); size += sizeOfDoubleArray(numClassificationWeights);
return alignObjectSize(size); return alignObjectSize(size);
@ -102,7 +102,7 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
builder.field(TREE_SIZES.getPreferredName(), treeSizeInfos); builder.field(TREE_SIZES.getPreferredName(), treeSizeInfos);
builder.field(NUM_OPERATIONS.getPreferredName(), numOperations); builder.field(NUM_OPERATIONS.getPreferredName(), numOperations);
builder.field(NUM_CLASSES.getPreferredName(), numClasses); builder.field(NUM_CLASSES.getPreferredName(), numClasses);
builder.field(INPUT_FIELD_NAME_LENGHTS.getPreferredName(), inputFieldNameLengths); builder.field(FEATURE_NAME_LENGTHS.getPreferredName(), featureNameLengths);
builder.field(NUM_CLASSIFICATION_WEIGHTS.getPreferredName(), numClassificationWeights); builder.field(NUM_CLASSIFICATION_WEIGHTS.getPreferredName(), numClassificationWeights);
builder.field(NUM_OUTPUT_PROCESSOR_WEIGHTS.getPreferredName(), numOutputProcessorWeights); builder.field(NUM_OUTPUT_PROCESSOR_WEIGHTS.getPreferredName(), numOutputProcessorWeights);
builder.endObject(); builder.endObject();
@ -119,13 +119,13 @@ public class EnsembleSizeInfo implements TrainedModelSizeInfo {
numClassificationWeights == that.numClassificationWeights && numClassificationWeights == that.numClassificationWeights &&
numClasses == that.numClasses && numClasses == that.numClasses &&
Objects.equals(treeSizeInfos, that.treeSizeInfos) && Objects.equals(treeSizeInfos, that.treeSizeInfos) &&
Arrays.equals(inputFieldNameLengths, that.inputFieldNameLengths); Arrays.equals(featureNameLengths, that.featureNameLengths);
} }
@Override @Override
public int hashCode() { public int hashCode() {
int result = Objects.hash(treeSizeInfos, numOperations, numOutputProcessorWeights, numClassificationWeights, numClasses); int result = Objects.hash(treeSizeInfos, numOperations, numOutputProcessorWeights, numClassificationWeights, numClasses);
result = 31 * result + Arrays.hashCode(inputFieldNameLengths); result = 31 * result + Arrays.hashCode(featureNameLengths);
return result; return result;
} }

View File

@ -74,7 +74,7 @@ public class ModelSizeInfoTests extends AbstractXContentTestCase<ModelSizeInfo>
" {\"num_nodes\": 3, \"num_leaves\": 4},\n" + " {\"num_nodes\": 3, \"num_leaves\": 4},\n" +
" {\"num_leaves\": 1}\n" + " {\"num_leaves\": 1}\n" +
" ],\n" + " ],\n" +
" \"input_field_name_lengths\": [\n" + " \"feature_name_lengths\": [\n" +
" 14,\n" + " 14,\n" +
" 10,\n" + " 10,\n" +
" 11\n" + " 11\n" +