This adds a new `baseline` field to the feature importance values. This field contains the baseline importance for a given feature and class.
This commit is contained in:
parent
fbf552d24c
commit
95242eccee
|
@ -40,6 +40,7 @@ public class TotalFeatureImportance implements ToXContentObject {
|
|||
public static final ParseField IMPORTANCE = new ParseField("importance");
|
||||
public static final ParseField CLASSES = new ParseField("classes");
|
||||
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
|
||||
public static final ParseField BASELINE = new ParseField("baseline");
|
||||
public static final ParseField MIN = new ParseField("min");
|
||||
public static final ParseField MAX = new ParseField("max");
|
||||
|
||||
|
@ -102,22 +103,25 @@ public class TotalFeatureImportance implements ToXContentObject {
|
|||
|
||||
public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||
true,
|
||||
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
|
||||
a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
|
||||
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
|
||||
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
|
||||
}
|
||||
|
||||
private final double meanMagnitude;
|
||||
private final double min;
|
||||
private final double max;
|
||||
private final Double baseline;
|
||||
|
||||
public Importance(double meanMagnitude, double min, double max) {
|
||||
public Importance(double meanMagnitude, double min, double max, Double baseline) {
|
||||
this.meanMagnitude = meanMagnitude;
|
||||
this.min = min;
|
||||
this.max = max;
|
||||
this.baseline = baseline;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -127,12 +131,13 @@ public class TotalFeatureImportance implements ToXContentObject {
|
|||
Importance that = (Importance) o;
|
||||
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
|
||||
Double.compare(that.min, min) == 0 &&
|
||||
Double.compare(that.max, max) == 0;
|
||||
Double.compare(that.max, max) == 0 &&
|
||||
Objects.equals(that.baseline, baseline);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(meanMagnitude, min, max);
|
||||
return Objects.hash(meanMagnitude, min, max, baseline);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -141,6 +146,9 @@ public class TotalFeatureImportance implements ToXContentObject {
|
|||
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
|
||||
builder.field(MIN.getPreferredName(), min);
|
||||
builder.field(MAX.getPreferredName(), max);
|
||||
if (baseline != null) {
|
||||
builder.field(BASELINE.getPreferredName(), baseline);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
|
|
@ -50,7 +50,11 @@ public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalF
|
|||
}
|
||||
|
||||
private static TotalFeatureImportance.Importance randomImportance() {
|
||||
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
|
||||
return new TotalFeatureImportance.Importance(
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomBoolean() ? null : randomDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -35,6 +35,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
|||
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
|
||||
public static final ParseField MIN = new ParseField("min");
|
||||
public static final ParseField MAX = new ParseField("max");
|
||||
public static final ParseField BASELINE = new ParseField("baseline");
|
||||
|
||||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
||||
public static final ConstructingObjectParser<TotalFeatureImportance, Void> LENIENT_PARSER = createParser(true);
|
||||
|
@ -124,27 +125,31 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
|||
private static ConstructingObjectParser<Importance, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ConstructingObjectParser<Importance, Void> parser = new ConstructingObjectParser<>(NAME,
|
||||
ignoreUnknownFields,
|
||||
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
|
||||
a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3]));
|
||||
parser.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
|
||||
parser.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
|
||||
parser.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
|
||||
return parser;
|
||||
}
|
||||
|
||||
private final double meanMagnitude;
|
||||
private final double min;
|
||||
private final double max;
|
||||
private final Double baseline;
|
||||
|
||||
public Importance(double meanMagnitude, double min, double max) {
|
||||
public Importance(double meanMagnitude, double min, double max, Double baseline) {
|
||||
this.meanMagnitude = meanMagnitude;
|
||||
this.min = min;
|
||||
this.max = max;
|
||||
this.baseline = baseline;
|
||||
}
|
||||
|
||||
public Importance(StreamInput in) throws IOException {
|
||||
this.meanMagnitude = in.readDouble();
|
||||
this.min = in.readDouble();
|
||||
this.max = in.readDouble();
|
||||
this.baseline = in.readOptionalDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -154,12 +159,13 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
|||
Importance that = (Importance) o;
|
||||
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
|
||||
Double.compare(that.min, min) == 0 &&
|
||||
Double.compare(that.max, max) == 0;
|
||||
Double.compare(that.max, max) == 0 &&
|
||||
Objects.equals(that.baseline, baseline);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(meanMagnitude, min, max);
|
||||
return Objects.hash(meanMagnitude, min, max, baseline);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -167,6 +173,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
|||
out.writeDouble(meanMagnitude);
|
||||
out.writeDouble(min);
|
||||
out.writeDouble(max);
|
||||
out.writeOptionalDouble(baseline);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -179,6 +186,9 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
|||
map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
|
||||
map.put(MIN.getPreferredName(), min);
|
||||
map.put(MAX.getPreferredName(), max);
|
||||
if (baseline != null) {
|
||||
map.put(BASELINE.getPreferredName(), baseline);
|
||||
}
|
||||
return map;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,6 +85,9 @@
|
|||
},
|
||||
"mean_magnitude": {
|
||||
"type": "double"
|
||||
},
|
||||
"baseline": {
|
||||
"type": "double"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -105,6 +108,9 @@
|
|||
},
|
||||
"mean_magnitude": {
|
||||
"type": "double"
|
||||
},
|
||||
"baseline": {
|
||||
"type": "double"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
@ -41,7 +41,11 @@ public class TotalFeatureImportanceTests extends AbstractBWCSerializationTestCas
|
|||
}
|
||||
|
||||
private static TotalFeatureImportance.Importance randomImportance() {
|
||||
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
|
||||
return new TotalFeatureImportance.Importance(
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomDouble(),
|
||||
randomBoolean() ? null : randomDouble());
|
||||
}
|
||||
|
||||
@Before
|
||||
|
|
Loading…
Reference in New Issue