[ML] adding `baseline` field to total_feature_importance objects (#63098) (#63125)

This adds a new `baseline` field to the feature importance values. 

This field contains the baseline importance for a given feature and class.
This commit is contained in:
Benjamin Trent 2020-10-01 09:48:07 -04:00 committed by GitHub
parent fbf552d24c
commit 95242eccee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 42 additions and 10 deletions

View File

@ -40,6 +40,7 @@ public class TotalFeatureImportance implements ToXContentObject {
public static final ParseField IMPORTANCE = new ParseField("importance");
public static final ParseField CLASSES = new ParseField("classes");
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
public static final ParseField BASELINE = new ParseField("baseline");
public static final ParseField MIN = new ParseField("min");
public static final ParseField MAX = new ParseField("max");
@ -102,22 +103,25 @@ public class TotalFeatureImportance implements ToXContentObject {
public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3]));
static {
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
}
private final double meanMagnitude;
private final double min;
private final double max;
private final Double baseline;
public Importance(double meanMagnitude, double min, double max) {
public Importance(double meanMagnitude, double min, double max, Double baseline) {
this.meanMagnitude = meanMagnitude;
this.min = min;
this.max = max;
this.baseline = baseline;
}
@Override
@ -127,12 +131,13 @@ public class TotalFeatureImportance implements ToXContentObject {
Importance that = (Importance) o;
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
Double.compare(that.min, min) == 0 &&
Double.compare(that.max, max) == 0;
Double.compare(that.max, max) == 0 &&
Objects.equals(that.baseline, baseline);
}
@Override
public int hashCode() {
return Objects.hash(meanMagnitude, min, max);
return Objects.hash(meanMagnitude, min, max, baseline);
}
@Override
@ -141,6 +146,9 @@ public class TotalFeatureImportance implements ToXContentObject {
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
builder.field(MIN.getPreferredName(), min);
builder.field(MAX.getPreferredName(), max);
if (baseline != null) {
builder.field(BASELINE.getPreferredName(), baseline);
}
builder.endObject();
return builder;
}

View File

@ -50,7 +50,11 @@ public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalF
}
private static TotalFeatureImportance.Importance randomImportance() {
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
return new TotalFeatureImportance.Importance(
randomDouble(),
randomDouble(),
randomDouble(),
randomBoolean() ? null : randomDouble());
}
@Override

View File

@ -35,6 +35,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
public static final ParseField MIN = new ParseField("min");
public static final ParseField MAX = new ParseField("max");
public static final ParseField BASELINE = new ParseField("baseline");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ConstructingObjectParser<TotalFeatureImportance, Void> LENIENT_PARSER = createParser(true);
@ -124,27 +125,31 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
private static ConstructingObjectParser<Importance, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<Importance, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3]));
parser.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
parser.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
parser.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
return parser;
}
private final double meanMagnitude;
private final double min;
private final double max;
private final Double baseline;
public Importance(double meanMagnitude, double min, double max) {
public Importance(double meanMagnitude, double min, double max, Double baseline) {
this.meanMagnitude = meanMagnitude;
this.min = min;
this.max = max;
this.baseline = baseline;
}
public Importance(StreamInput in) throws IOException {
this.meanMagnitude = in.readDouble();
this.min = in.readDouble();
this.max = in.readDouble();
this.baseline = in.readOptionalDouble();
}
@Override
@ -154,12 +159,13 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
Importance that = (Importance) o;
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
Double.compare(that.min, min) == 0 &&
Double.compare(that.max, max) == 0;
Double.compare(that.max, max) == 0 &&
Objects.equals(that.baseline, baseline);
}
@Override
public int hashCode() {
return Objects.hash(meanMagnitude, min, max);
return Objects.hash(meanMagnitude, min, max, baseline);
}
@Override
@ -167,6 +173,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
out.writeDouble(meanMagnitude);
out.writeDouble(min);
out.writeDouble(max);
out.writeOptionalDouble(baseline);
}
@Override
@ -179,6 +186,9 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
map.put(MIN.getPreferredName(), min);
map.put(MAX.getPreferredName(), max);
if (baseline != null) {
map.put(BASELINE.getPreferredName(), baseline);
}
return map;
}
}

View File

@ -85,6 +85,9 @@
},
"mean_magnitude": {
"type": "double"
},
"baseline": {
"type": "double"
}
}
},
@ -105,6 +108,9 @@
},
"mean_magnitude": {
"type": "double"
},
"baseline": {
"type": "double"
}
}
},

View File

@ -41,7 +41,11 @@ public class TotalFeatureImportanceTests extends AbstractBWCSerializationTestCas
}
private static TotalFeatureImportance.Importance randomImportance() {
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
return new TotalFeatureImportance.Importance(
randomDouble(),
randomDouble(),
randomDouble(),
randomBoolean() ? null : randomDouble());
}
@Before