This commit addresses two issues: - per class feature importance is now written out for binary classification (logistic regression) - The `class_name` in per class feature importance now matches what is written in the `top_classes` array. backport of https://github.com/elastic/elasticsearch/pull/61597
This commit is contained in:
parent
2858e1efc4
commit
8b33d8813a
|
@ -39,6 +39,12 @@ public class FeatureImportance implements Writeable, ToXContentObject {
|
||||||
return new FeatureImportance(featureName, importance, null);
|
return new FeatureImportance(featureName, importance, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
|
||||||
|
return new FeatureImportance(featureName,
|
||||||
|
importance,
|
||||||
|
classImportance);
|
||||||
|
}
|
||||||
|
|
||||||
public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
|
public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
|
||||||
return new FeatureImportance(featureName,
|
return new FeatureImportance(featureName,
|
||||||
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
|
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
|
||||||
|
@ -170,27 +176,27 @@ public class FeatureImportance implements Writeable, ToXContentObject {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Map<String, Double> toMap(List<ClassImportance> importances) {
|
private static Map<String, Double> toMap(List<ClassImportance> importances) {
|
||||||
return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance));
|
return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static ClassImportance fromXContent(XContentParser parser) {
|
public static ClassImportance fromXContent(XContentParser parser) {
|
||||||
return PARSER.apply(parser, null);
|
return PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final String className;
|
private final Object className;
|
||||||
private final double importance;
|
private final double importance;
|
||||||
|
|
||||||
public ClassImportance(String className, double importance) {
|
public ClassImportance(Object className, double importance) {
|
||||||
this.className = className;
|
this.className = className;
|
||||||
this.importance = importance;
|
this.importance = importance;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ClassImportance(StreamInput in) throws IOException {
|
public ClassImportance(StreamInput in) throws IOException {
|
||||||
this.className = in.readString();
|
this.className = in.readGenericValue();
|
||||||
this.importance = in.readDouble();
|
this.importance = in.readDouble();
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getClassName() {
|
public Object getClassName() {
|
||||||
return className;
|
return className;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,7 +213,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
out.writeString(className);
|
out.writeGenericValue(className);
|
||||||
out.writeDouble(importance);
|
out.writeDouble(importance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
@ -129,21 +130,46 @@ public final class InferenceHelpers {
|
||||||
return originalFeatureImportance;
|
return originalFeatureImportance;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
|
public static List<FeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
|
||||||
@Nullable List<String> classificationLabels) {
|
|
||||||
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
|
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
|
||||||
|
featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0])));
|
||||||
|
return importances;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<FeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance,
|
||||||
|
final int predictedValue,
|
||||||
|
@Nullable List<String> classificationLabels,
|
||||||
|
@Nullable PredictionFieldType predictionFieldType) {
|
||||||
|
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
|
||||||
|
final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
|
||||||
featureImportance.forEach((k, v) -> {
|
featureImportance.forEach((k, v) -> {
|
||||||
// This indicates regression, or logistic regression
|
// This indicates logistic regression (binary classification)
|
||||||
// If the length > 1, we assume multi-class classification.
|
// If the length > 1, we assume multi-class classification.
|
||||||
if (v.length == 1) {
|
if (v.length == 1) {
|
||||||
importances.add(FeatureImportance.forRegression(k, v[0]));
|
assert predictedValue == 1 || predictedValue == 0;
|
||||||
|
// If predicted value is `1`, then the other class is `0`
|
||||||
|
// If predicted value is `0`, then the other class is `1`
|
||||||
|
final int otherClass = 1 - predictedValue;
|
||||||
|
String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
|
||||||
|
String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
|
||||||
|
importances.add(FeatureImportance.forBinaryClassification(k,
|
||||||
|
v[0],
|
||||||
|
Arrays.asList(
|
||||||
|
new FeatureImportance.ClassImportance(
|
||||||
|
fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
|
||||||
|
v[0]),
|
||||||
|
new FeatureImportance.ClassImportance(
|
||||||
|
fieldType.transformPredictedValue((double)otherClass, otherLabel),
|
||||||
|
-v[0])
|
||||||
|
)));
|
||||||
} else {
|
} else {
|
||||||
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
|
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
|
||||||
// If the classificationLabels exist, their length must match leaf_value length
|
// If the classificationLabels exist, their length must match leaf_value length
|
||||||
assert classificationLabels == null || classificationLabels.size() == v.length;
|
assert classificationLabels == null || classificationLabels.size() == v.length;
|
||||||
for (int i = 0; i < v.length; i++) {
|
for (int i = 0; i < v.length; i++) {
|
||||||
|
String label = classificationLabels == null ? null : classificationLabels.get(i);
|
||||||
classImportance.add(new FeatureImportance.ClassImportance(
|
classImportance.add(new FeatureImportance.ClassImportance(
|
||||||
classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i),
|
fieldType.transformPredictedValue((double)i, label),
|
||||||
v[i]));
|
v[i]));
|
||||||
}
|
}
|
||||||
importances.add(FeatureImportance.forClassification(k, classImportance));
|
importances.add(FeatureImportance.forClassification(k, classImportance));
|
||||||
|
|
|
@ -43,7 +43,8 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
|
||||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
|
||||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
|
||||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
|
||||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceClassification;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceRegression;
|
||||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
|
||||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
|
||||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
|
||||||
|
@ -154,14 +155,7 @@ public class EnsembleInferenceModel implements InferenceModel {
|
||||||
RawInferenceResults inferenceResult = (RawInferenceResults) result;
|
RawInferenceResults inferenceResult = (RawInferenceResults) result;
|
||||||
inferenceResults[i++] = inferenceResult.getValue();
|
inferenceResults[i++] = inferenceResult.getValue();
|
||||||
if (config.requestingImportance()) {
|
if (config.requestingImportance()) {
|
||||||
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
|
addFeatureImportance(featureInfluence, inferenceResult);
|
||||||
assert modelFeatureImportance.length == featureInfluence.length;
|
|
||||||
for (int j = 0; j < modelFeatureImportance.length; j++) {
|
|
||||||
if (featureInfluence[j] == null) {
|
|
||||||
featureInfluence[j] = new double[modelFeatureImportance[j].length];
|
|
||||||
}
|
|
||||||
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
double[] processed = outputAggregator.processValues(inferenceResults);
|
double[] processed = outputAggregator.processValues(inferenceResults);
|
||||||
|
@ -176,6 +170,12 @@ public class EnsembleInferenceModel implements InferenceModel {
|
||||||
InferenceResults result = model.infer(features, subModelInferenceConfig);
|
InferenceResults result = model.infer(features, subModelInferenceConfig);
|
||||||
assert result instanceof RawInferenceResults;
|
assert result instanceof RawInferenceResults;
|
||||||
RawInferenceResults inferenceResult = (RawInferenceResults) result;
|
RawInferenceResults inferenceResult = (RawInferenceResults) result;
|
||||||
|
addFeatureImportance(featureInfluence, inferenceResult);
|
||||||
|
}
|
||||||
|
return featureInfluence;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) {
|
||||||
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
|
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
|
||||||
assert modelFeatureImportance.length == featureInfluence.length;
|
assert modelFeatureImportance.length == featureInfluence.length;
|
||||||
for (int j = 0; j < modelFeatureImportance.length; j++) {
|
for (int j = 0; j < modelFeatureImportance.length; j++) {
|
||||||
|
@ -185,8 +185,6 @@ public class EnsembleInferenceModel implements InferenceModel {
|
||||||
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
|
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return featureInfluence;
|
|
||||||
}
|
|
||||||
|
|
||||||
private InferenceResults buildResults(double[] processedInferences,
|
private InferenceResults buildResults(double[] processedInferences,
|
||||||
double[][] featureImportance,
|
double[][] featureImportance,
|
||||||
|
@ -208,7 +206,7 @@ public class EnsembleInferenceModel implements InferenceModel {
|
||||||
case REGRESSION:
|
case REGRESSION:
|
||||||
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
|
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
|
||||||
config,
|
config,
|
||||||
transformFeatureImportance(decodedFeatureImportance, null));
|
transformFeatureImportanceRegression(decodedFeatureImportance));
|
||||||
case CLASSIFICATION:
|
case CLASSIFICATION:
|
||||||
ClassificationConfig classificationConfig = (ClassificationConfig) config;
|
ClassificationConfig classificationConfig = (ClassificationConfig) config;
|
||||||
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
|
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
|
||||||
|
@ -220,10 +218,13 @@ public class EnsembleInferenceModel implements InferenceModel {
|
||||||
classificationConfig.getNumTopClasses(),
|
classificationConfig.getNumTopClasses(),
|
||||||
classificationConfig.getPredictionFieldType());
|
classificationConfig.getPredictionFieldType());
|
||||||
final InferenceHelpers.TopClassificationValue value = topClasses.v1();
|
final InferenceHelpers.TopClassificationValue value = topClasses.v1();
|
||||||
return new ClassificationInferenceResults((double)value.getValue(),
|
return new ClassificationInferenceResults(value.getValue(),
|
||||||
classificationLabel(topClasses.v1().getValue(), classificationLabels),
|
classificationLabel(topClasses.v1().getValue(), classificationLabels),
|
||||||
topClasses.v2(),
|
topClasses.v2(),
|
||||||
transformFeatureImportance(decodedFeatureImportance, classificationLabels),
|
transformFeatureImportanceClassification(decodedFeatureImportance,
|
||||||
|
value.getValue(),
|
||||||
|
classificationLabels,
|
||||||
|
classificationConfig.getPredictionFieldType()),
|
||||||
config,
|
config,
|
||||||
value.getProbability(),
|
value.getProbability(),
|
||||||
value.getScore());
|
value.getScore());
|
||||||
|
|
|
@ -188,14 +188,17 @@ public class TreeInferenceModel implements InferenceModel {
|
||||||
return new ClassificationInferenceResults(classificationValue.getValue(),
|
return new ClassificationInferenceResults(classificationValue.getValue(),
|
||||||
classificationLabel(classificationValue.getValue(), classificationLabels),
|
classificationLabel(classificationValue.getValue(), classificationLabels),
|
||||||
topClasses.v2(),
|
topClasses.v2(),
|
||||||
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels),
|
InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
|
||||||
|
classificationValue.getValue(),
|
||||||
|
classificationLabels,
|
||||||
|
classificationConfig.getPredictionFieldType()),
|
||||||
config,
|
config,
|
||||||
classificationValue.getProbability(),
|
classificationValue.getProbability(),
|
||||||
classificationValue.getScore());
|
classificationValue.getScore());
|
||||||
case REGRESSION:
|
case REGRESSION:
|
||||||
return new RegressionInferenceResults(value[0],
|
return new RegressionInferenceResults(value[0],
|
||||||
config,
|
config,
|
||||||
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, null));
|
InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance));
|
||||||
default:
|
default:
|
||||||
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
|
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,8 +12,10 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -185,8 +187,17 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
||||||
private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) {
|
private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) {
|
||||||
ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME,
|
ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME,
|
||||||
ignoreUnknownFields,
|
ignoreUnknownFields,
|
||||||
a -> new ClassImportance((String)a[0], (Importance)a[1]));
|
a -> new ClassImportance(a[0], (Importance)a[1]));
|
||||||
parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME);
|
parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
||||||
|
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
|
||||||
|
return p.text();
|
||||||
|
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
|
||||||
|
return p.numberValue();
|
||||||
|
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
|
||||||
|
return p.booleanValue();
|
||||||
|
}
|
||||||
|
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
|
||||||
|
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
|
||||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||||
ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER,
|
ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER,
|
||||||
IMPORTANCE);
|
IMPORTANCE);
|
||||||
|
@ -197,22 +208,22 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
||||||
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
|
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public final String className;
|
public final Object className;
|
||||||
public final Importance importance;
|
public final Importance importance;
|
||||||
|
|
||||||
public ClassImportance(StreamInput in) throws IOException {
|
public ClassImportance(StreamInput in) throws IOException {
|
||||||
this.className = in.readString();
|
this.className = in.readGenericValue();
|
||||||
this.importance = new Importance(in);
|
this.importance = new Importance(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
ClassImportance(String className, Importance importance) {
|
ClassImportance(Object className, Importance importance) {
|
||||||
this.className = className;
|
this.className = className;
|
||||||
this.importance = importance;
|
this.importance = importance;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
out.writeString(className);
|
out.writeGenericValue(className);
|
||||||
importance.writeTo(out);
|
importance.writeTo(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
|
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -154,10 +155,26 @@ public class InferenceDefinitionTests extends ESTestCase {
|
||||||
|
|
||||||
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
|
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
|
||||||
assertThat(results.valueAsString(), equalTo("second"));
|
assertThat(results.valueAsString(), equalTo("second"));
|
||||||
assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
|
FeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
|
||||||
assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
|
assertThat(featureImportance1.getFeatureName(), equalTo("col2"));
|
||||||
assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male"));
|
assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001));
|
||||||
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
|
for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
|
||||||
|
if (classImportance.getClassName().equals("second")) {
|
||||||
|
assertThat(classImportance.getImportance(), closeTo(0.944, 0.001));
|
||||||
|
} else {
|
||||||
|
assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
FeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
|
||||||
|
assertThat(featureImportance2.getFeatureName(), equalTo("col1_male"));
|
||||||
|
assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001));
|
||||||
|
for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
|
||||||
|
if (classImportance.getClassName().equals("second")) {
|
||||||
|
assertThat(classImportance.getImportance(), closeTo(0.199, 0.001));
|
||||||
|
} else {
|
||||||
|
assertThat(classImportance.getImportance(), closeTo(-0.199, 0.001));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static String getClassificationDefinition(boolean customPreprocessor) {
|
public static String getClassificationDefinition(boolean customPreprocessor) {
|
||||||
|
|
Loading…
Reference in New Issue