[7.x][ML] Remove top level importance from classification inference results (#62486) (#62964)

As we have decided top level importance for classification is not useful,
it has been removed from the results from the training job. This commit
also removes them from inference.

Backport of #62486
This commit is contained in:
Dimitris Athanasiou 2020-09-29 10:58:48 +03:00 committed by GitHub
parent cc33df87d3
commit 7f6c1ff5b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 650 additions and 265 deletions

View File

@ -47,7 +47,7 @@ public class FeatureImportance implements ToXContentObject {
static { static {
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME)); PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); PARSER.declareDouble(optionalConstructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
PARSER.declareObjectArray(optionalConstructorArg(), PARSER.declareObjectArray(optionalConstructorArg(),
(p, c) -> ClassImportance.fromXContent(p), (p, c) -> ClassImportance.fromXContent(p),
new ParseField(FeatureImportance.CLASSES)); new ParseField(FeatureImportance.CLASSES));
@ -58,10 +58,10 @@ public class FeatureImportance implements ToXContentObject {
} }
private final List<ClassImportance> classImportance; private final List<ClassImportance> classImportance;
private final double importance; private final Double importance;
private final String featureName; private final String featureName;
public FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) { public FeatureImportance(String featureName, Double importance, List<ClassImportance> classImportance) {
this.featureName = Objects.requireNonNull(featureName); this.featureName = Objects.requireNonNull(featureName);
this.importance = importance; this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance); this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
@ -71,7 +71,7 @@ public class FeatureImportance implements ToXContentObject {
return classImportance; return classImportance;
} }
public double getImportance() { public Double getImportance() {
return importance; return importance;
} }
@ -83,7 +83,9 @@ public class FeatureImportance implements ToXContentObject {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(FEATURE_NAME, featureName); builder.field(FEATURE_NAME, featureName);
if (importance != null) {
builder.field(IMPORTANCE, importance); builder.field(IMPORTANCE, importance);
}
if (classImportance != null && classImportance.isEmpty() == false) { if (classImportance != null && classImportance.isEmpty() == false) {
builder.field(CLASSES, classImportance); builder.field(CLASSES, classImportance);
} }

View File

@ -32,7 +32,7 @@ public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImpo
protected FeatureImportance createTestInstance() { protected FeatureImportance createTestInstance() {
return new FeatureImportance( return new FeatureImportance(
randomAlphaOfLength(10), randomAlphaOfLength(10),
randomDoubleBetween(-10.0, 10.0, false), randomBoolean() ? null : randomDoubleBetween(-10.0, 10.0, false),
randomBoolean() ? null : randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)) Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(2, 10)) .limit(randomLongBetween(2, 10))

View File

@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Map;
abstract class AbstractFeatureImportance implements Writeable, ToXContentObject {
public abstract String getFeatureName();
public abstract Map<String, Object> toMap();
@Override
public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder.map(toMap());
}
}

View File

@ -5,7 +5,6 @@
*/ */
package org.elasticsearch.xpack.core.ml.inference.results; package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
@ -26,157 +25,101 @@ import java.util.stream.Collectors;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class FeatureImportance implements Writeable, ToXContentObject { public class ClassificationFeatureImportance extends AbstractFeatureImportance {
private final List<ClassImportance> classImportance; private final List<ClassImportance> classImportance;
private final double importance;
private final String featureName; private final String featureName;
static final String IMPORTANCE = "importance";
static final String FEATURE_NAME = "feature_name"; static final String FEATURE_NAME = "feature_name";
static final String CLASSES = "classes"; static final String CLASSES = "classes";
public static FeatureImportance forRegression(String featureName, double importance) {
return new FeatureImportance(featureName, importance, null);
}
public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName,
importance,
classImportance);
}
public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName,
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
classImportance);
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER = private static final ConstructingObjectParser<ClassificationFeatureImportance, Void> PARSER =
new ConstructingObjectParser<>("feature_importance", new ConstructingObjectParser<>("classification_feature_importance",
a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2]) a -> new ClassificationFeatureImportance((String) a[0], (List<ClassImportance>) a[1])
); );
static { static {
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME)); PARSER.declareString(constructorArg(), new ParseField(ClassificationFeatureImportance.FEATURE_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
PARSER.declareObjectArray(optionalConstructorArg(), PARSER.declareObjectArray(optionalConstructorArg(),
(p, c) -> ClassImportance.fromXContent(p), (p, c) -> ClassImportance.fromXContent(p),
new ParseField(FeatureImportance.CLASSES)); new ParseField(ClassificationFeatureImportance.CLASSES));
} }
public static FeatureImportance fromXContent(XContentParser parser) { public static ClassificationFeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) { public ClassificationFeatureImportance(String featureName, List<ClassImportance> classImportance) {
this.featureName = Objects.requireNonNull(featureName); this.featureName = Objects.requireNonNull(featureName);
this.importance = importance; this.classImportance = classImportance == null ? Collections.emptyList() : Collections.unmodifiableList(classImportance);
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
} }
public FeatureImportance(StreamInput in) throws IOException { public ClassificationFeatureImportance(StreamInput in) throws IOException {
this.featureName = in.readString(); this.featureName = in.readString();
this.importance = in.readDouble();
if (in.readBoolean()) {
if (in.getVersion().before(Version.V_7_10_0)) {
Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
this.classImportance = ClassImportance.fromMap(classImportance);
} else {
this.classImportance = in.readList(ClassImportance::new); this.classImportance = in.readList(ClassImportance::new);
} }
} else {
this.classImportance = null;
}
}
public List<ClassImportance> getClassImportance() { public List<ClassImportance> getClassImportance() {
return classImportance; return classImportance;
} }
public double getImportance() { @Override
return importance;
}
public String getFeatureName() { public String getFeatureName() {
return featureName; return featureName;
} }
@Override public double getTotalImportance() {
public void writeTo(StreamOutput out) throws IOException { if (classImportance.size() == 2) {
out.writeString(this.featureName); // Binary classification. We can return the first class importance here
out.writeDouble(this.importance); return Math.abs(classImportance.get(0).getImportance());
out.writeBoolean(this.classImportance != null);
if (this.classImportance != null) {
if (out.getVersion().before(Version.V_7_10_0)) {
out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
} else {
out.writeList(this.classImportance);
}
} }
return classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum();
} }
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(featureName);
out.writeList(classImportance);
}
@Override
public Map<String, Object> toMap() { public Map<String, Object> toMap() {
Map<String, Object> map = new LinkedHashMap<>(); Map<String, Object> map = new LinkedHashMap<>();
map.put(FEATURE_NAME, featureName); map.put(FEATURE_NAME, featureName);
map.put(IMPORTANCE, importance); if (classImportance.isEmpty() == false) {
if (classImportance != null) {
map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList())); map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList()));
} }
return map; return map;
} }
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FEATURE_NAME, featureName);
builder.field(IMPORTANCE, importance);
if (classImportance != null && classImportance.isEmpty() == false) {
builder.field(CLASSES, classImportance);
}
builder.endObject();
return builder;
}
@Override @Override
public boolean equals(Object object) { public boolean equals(Object object) {
if (object == this) { return true; } if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; } if (object == null || getClass() != object.getClass()) { return false; }
FeatureImportance that = (FeatureImportance) object; ClassificationFeatureImportance that = (ClassificationFeatureImportance) object;
return Objects.equals(featureName, that.featureName) return Objects.equals(featureName, that.featureName)
&& Objects.equals(importance, that.importance)
&& Objects.equals(classImportance, that.classImportance); && Objects.equals(classImportance, that.classImportance);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(featureName, importance, classImportance); return Objects.hash(featureName, classImportance);
} }
public static class ClassImportance implements Writeable, ToXContentObject { public static class ClassImportance implements Writeable, ToXContentObject {
static final String CLASS_NAME = "class_name"; static final String CLASS_NAME = "class_name";
static final String IMPORTANCE = "importance";
private static final ConstructingObjectParser<ClassImportance, Void> PARSER = private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
new ConstructingObjectParser<>("feature_importance_class_importance", new ConstructingObjectParser<>("classification_feature_importance_class_importance",
a -> new ClassImportance((String) a[0], (Double) a[1]) a -> new ClassImportance(a[0], (Double) a[1])
); );
static { static {
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME)); PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); PARSER.declareDouble(constructorArg(), new ParseField(IMPORTANCE));
}
private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
return new ClassImportance(entry.getKey(), entry.getValue());
}
private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
}
private static Map<String, Double> toMap(List<ClassImportance> importances) {
return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
} }
public static ClassImportance fromXContent(XContentParser parser) { public static ClassImportance fromXContent(XContentParser parser) {
@ -219,11 +162,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); return builder.map(toMap());
builder.field(CLASS_NAME, className);
builder.field(IMPORTANCE, importance);
builder.endObject();
return builder;
} }
@Override @Override

View File

@ -15,9 +15,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldTyp
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.Map;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -34,12 +34,13 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
private final Double predictionProbability; private final Double predictionProbability;
private final Double predictionScore; private final Double predictionScore;
private final List<TopClassEntry> topClasses; private final List<TopClassEntry> topClasses;
private final List<ClassificationFeatureImportance> featureImportance;
private final PredictionFieldType predictionFieldType; private final PredictionFieldType predictionFieldType;
public ClassificationInferenceResults(double value, public ClassificationInferenceResults(double value,
String classificationLabel, String classificationLabel,
List<TopClassEntry> topClasses, List<TopClassEntry> topClasses,
List<FeatureImportance> featureImportance, List<ClassificationFeatureImportance> featureImportance,
InferenceConfig config, InferenceConfig config,
Double predictionProbability, Double predictionProbability,
Double predictionScore) { Double predictionScore) {
@ -55,13 +56,11 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
private ClassificationInferenceResults(double value, private ClassificationInferenceResults(double value,
String classificationLabel, String classificationLabel,
List<TopClassEntry> topClasses, List<TopClassEntry> topClasses,
List<FeatureImportance> featureImportance, List<ClassificationFeatureImportance> featureImportance,
ClassificationConfig classificationConfig, ClassificationConfig classificationConfig,
Double predictionProbability, Double predictionProbability,
Double predictionScore) { Double predictionScore) {
super(value, super(value);
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
classificationConfig.getNumTopFeatureImportanceValues()));
this.classificationLabel = classificationLabel; this.classificationLabel = classificationLabel;
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses); this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.topNumClassesField = classificationConfig.getTopClassesResultsField(); this.topNumClassesField = classificationConfig.getTopClassesResultsField();
@ -69,10 +68,32 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
this.predictionFieldType = classificationConfig.getPredictionFieldType(); this.predictionFieldType = classificationConfig.getPredictionFieldType();
this.predictionProbability = predictionProbability; this.predictionProbability = predictionProbability;
this.predictionScore = predictionScore; this.predictionScore = predictionScore;
this.featureImportance = takeTopFeatureImportances(featureImportance, classificationConfig.getNumTopFeatureImportanceValues());
}
static List<ClassificationFeatureImportance> takeTopFeatureImportances(List<ClassificationFeatureImportance> featureImportances,
int numTopFeatures) {
if (featureImportances == null || featureImportances.isEmpty()) {
return Collections.emptyList();
}
return featureImportances.stream()
.sorted((l, r)-> Double.compare(r.getTotalImportance(), l.getTotalImportance()))
.limit(numTopFeatures)
.collect(Collectors.toList());
} }
public ClassificationInferenceResults(StreamInput in) throws IOException { public ClassificationInferenceResults(StreamInput in) throws IOException {
super(in); super(in);
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.featureImportance = in.readList(ClassificationFeatureImportance::new);
} else if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readList(LegacyFeatureImportance::new)
.stream()
.map(LegacyFeatureImportance::forClassification)
.collect(Collectors.toList());
} else {
this.featureImportance = Collections.emptyList();
}
this.classificationLabel = in.readOptionalString(); this.classificationLabel = in.readOptionalString();
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
this.topNumClassesField = in.readString(); this.topNumClassesField = in.readString();
@ -103,9 +124,18 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
return predictionFieldType; return predictionFieldType;
} }
public List<ClassificationFeatureImportance> getFeatureImportance() {
return featureImportance;
}
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out); super.writeTo(out);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeList(featureImportance);
} else if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeList(featureImportance.stream().map(LegacyFeatureImportance::fromClassification).collect(Collectors.toList()));
}
out.writeOptionalString(classificationLabel); out.writeOptionalString(classificationLabel);
out.writeCollection(topClasses); out.writeCollection(topClasses);
out.writeString(topNumClassesField); out.writeString(topNumClassesField);
@ -132,7 +162,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
&& Objects.equals(predictionFieldType, that.predictionFieldType) && Objects.equals(predictionFieldType, that.predictionFieldType)
&& Objects.equals(predictionProbability, that.predictionProbability) && Objects.equals(predictionProbability, that.predictionProbability)
&& Objects.equals(predictionScore, that.predictionScore) && Objects.equals(predictionScore, that.predictionScore)
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance()); && Objects.equals(featureImportance, that.featureImportance);
} }
@Override @Override
@ -144,7 +174,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
topNumClassesField, topNumClassesField,
predictionProbability, predictionProbability,
predictionScore, predictionScore,
getFeatureImportance(), featureImportance,
predictionFieldType); predictionFieldType);
} }
@ -179,8 +209,9 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
if (predictionScore != null) { if (predictionScore != null) {
map.put(PREDICTION_SCORE, predictionScore); map.put(PREDICTION_SCORE, predictionScore);
} }
if (getFeatureImportance().isEmpty() == false) { if (featureImportance.isEmpty() == false) {
map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList())); map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(ClassificationFeatureImportance::toMap)
.collect(Collectors.toList()));
} }
return map; return map;
} }
@ -202,8 +233,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
if (predictionScore != null) { if (predictionScore != null) {
builder.field(PREDICTION_SCORE, predictionScore); builder.field(PREDICTION_SCORE, predictionScore);
} }
if (getFeatureImportance().size() > 0) { if (featureImportance.isEmpty() == false) {
builder.field(FEATURE_IMPORTANCE, getFeatureImportance()); builder.field(FEATURE_IMPORTANCE, featureImportance);
} }
return builder; return builder;
} }

View File

@ -0,0 +1,160 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* This class captures serialization of feature importance for
* classification and regression prior to version 7.10.
*/
public class LegacyFeatureImportance implements Writeable {
public static LegacyFeatureImportance fromClassification(ClassificationFeatureImportance classificationFeatureImportance) {
return new LegacyFeatureImportance(
classificationFeatureImportance.getFeatureName(),
classificationFeatureImportance.getTotalImportance(),
classificationFeatureImportance.getClassImportance().stream().map(classImportance -> new ClassImportance(
classImportance.getClassName(), classImportance.getImportance())).collect(Collectors.toList())
);
}
public static LegacyFeatureImportance fromRegression(RegressionFeatureImportance regressionFeatureImportance) {
return new LegacyFeatureImportance(
regressionFeatureImportance.getFeatureName(),
regressionFeatureImportance.getImportance(),
null
);
}
private final List<ClassImportance> classImportance;
private final double importance;
private final String featureName;
LegacyFeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
this.featureName = Objects.requireNonNull(featureName);
this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
}
public LegacyFeatureImportance(StreamInput in) throws IOException {
this.featureName = in.readString();
this.importance = in.readDouble();
if (in.readBoolean()) {
if (in.getVersion().before(Version.V_7_10_0)) {
Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
this.classImportance = ClassImportance.fromMap(classImportance);
} else {
this.classImportance = in.readList(ClassImportance::new);
}
} else {
this.classImportance = null;
}
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(featureName);
out.writeDouble(importance);
out.writeBoolean(classImportance != null);
if (classImportance != null) {
if (out.getVersion().before(Version.V_7_10_0)) {
out.writeMap(ClassImportance.toMap(classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
} else {
out.writeList(classImportance);
}
}
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
LegacyFeatureImportance that = (LegacyFeatureImportance) object;
return Objects.equals(featureName, that.featureName)
&& Objects.equals(importance, that.importance)
&& Objects.equals(classImportance, that.classImportance);
}
@Override
public int hashCode() {
return Objects.hash(featureName, importance, classImportance);
}
public RegressionFeatureImportance forRegression() {
assert classImportance == null;
return new RegressionFeatureImportance(featureName, importance);
}
public ClassificationFeatureImportance forClassification() {
assert classImportance != null;
return new ClassificationFeatureImportance(featureName, classImportance.stream().map(
aClassImportance -> new ClassificationFeatureImportance.ClassImportance(
aClassImportance.className, aClassImportance.importance)).collect(Collectors.toList()));
}
public static class ClassImportance implements Writeable {
private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
return new ClassImportance(entry.getKey(), entry.getValue());
}
private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
}
private static Map<String, Double> toMap(List<ClassImportance> importances) {
return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
}
private final Object className;
private final double importance;
public ClassImportance(Object className, double importance) {
this.className = className;
this.importance = importance;
}
public ClassImportance(StreamInput in) throws IOException {
this.className = in.readGenericValue();
this.importance = in.readDouble();
}
double getImportance() {
return importance;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeGenericValue(className);
out.writeDouble(importance);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassImportance that = (ClassImportance) o;
return Double.compare(that.importance, importance) == 0 &&
Objects.equals(className, that.className);
}
@Override
public int hashCode() {
return Objects.hash(className, importance);
}
}
}

View File

@ -0,0 +1,88 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
public class RegressionFeatureImportance extends AbstractFeatureImportance {
private final double importance;
private final String featureName;
static final String IMPORTANCE = "importance";
static final String FEATURE_NAME = "feature_name";
private static final ConstructingObjectParser<RegressionFeatureImportance, Void> PARSER =
new ConstructingObjectParser<>("regression_feature_importance",
a -> new RegressionFeatureImportance((String) a[0], (Double) a[1])
);
static {
PARSER.declareString(constructorArg(), new ParseField(RegressionFeatureImportance.FEATURE_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(RegressionFeatureImportance.IMPORTANCE));
}
public static RegressionFeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public RegressionFeatureImportance(String featureName, double importance) {
this.featureName = Objects.requireNonNull(featureName);
this.importance = importance;
}
public RegressionFeatureImportance(StreamInput in) throws IOException {
this.featureName = in.readString();
this.importance = in.readDouble();
}
public double getImportance() {
return importance;
}
@Override
public String getFeatureName() {
return featureName;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(featureName);
out.writeDouble(importance);
}
@Override
public Map<String, Object> toMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(FEATURE_NAME, featureName);
map.put(IMPORTANCE, importance);
return map;
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RegressionFeatureImportance that = (RegressionFeatureImportance) object;
return Objects.equals(featureName, that.featureName)
&& Objects.equals(importance, that.importance);
}
@Override
public int hashCode() {
return Objects.hash(featureName, importance);
}
}

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.xpack.core.ml.inference.results; package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
@ -24,14 +25,19 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
public static final String NAME = "regression"; public static final String NAME = "regression";
private final String resultsField; private final String resultsField;
private final List<RegressionFeatureImportance> featureImportance;
public RegressionInferenceResults(double value, InferenceConfig config) { public RegressionInferenceResults(double value, InferenceConfig config) {
this(value, config, Collections.emptyList()); this(value, config, Collections.emptyList());
} }
public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) { public RegressionInferenceResults(double value, InferenceConfig config, List<RegressionFeatureImportance> featureImportance) {
this(value, ((RegressionConfig)config).getResultsField(), this(
((RegressionConfig)config).getNumTopFeatureImportanceValues(), featureImportance); value,
((RegressionConfig)config).getResultsField(),
((RegressionConfig)config).getNumTopFeatureImportanceValues(),
featureImportance
);
} }
public RegressionInferenceResults(double value, String resultsField) { public RegressionInferenceResults(double value, String resultsField) {
@ -39,28 +45,58 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
} }
public RegressionInferenceResults(double value, String resultsField, public RegressionInferenceResults(double value, String resultsField,
List<FeatureImportance> featureImportance) { List<RegressionFeatureImportance> featureImportance) {
this(value, resultsField, featureImportance.size(), featureImportance); this(value, resultsField, featureImportance.size(), featureImportance);
} }
public RegressionInferenceResults(double value, String resultsField, int topNFeatures, public RegressionInferenceResults(double value, String resultsField, int topNFeatures,
List<FeatureImportance> featureImportance) { List<RegressionFeatureImportance> featureImportance) {
super(value, super(value);
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance, topNFeatures));
this.resultsField = resultsField; this.resultsField = resultsField;
this.featureImportance = takeTopFeatureImportances(featureImportance, topNFeatures);
}
static List<RegressionFeatureImportance> takeTopFeatureImportances(List<RegressionFeatureImportance> featureImportances,
int numTopFeatures) {
if (featureImportances == null || featureImportances.isEmpty()) {
return Collections.emptyList();
}
return featureImportances.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
.limit(numTopFeatures)
.collect(Collectors.toList());
} }
public RegressionInferenceResults(StreamInput in) throws IOException { public RegressionInferenceResults(StreamInput in) throws IOException {
super(in); super(in);
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.featureImportance = in.readList(RegressionFeatureImportance::new);
} else if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readList(LegacyFeatureImportance::new)
.stream()
.map(LegacyFeatureImportance::forRegression)
.collect(Collectors.toList());
} else {
this.featureImportance = Collections.emptyList();
}
this.resultsField = in.readString(); this.resultsField = in.readString();
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out); super.writeTo(out);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeList(featureImportance);
} else if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeList(featureImportance.stream().map(LegacyFeatureImportance::fromRegression).collect(Collectors.toList()));
}
out.writeString(resultsField); out.writeString(resultsField);
} }
public List<RegressionFeatureImportance> getFeatureImportance() {
return featureImportance;
}
@Override @Override
public boolean equals(Object object) { public boolean equals(Object object) {
if (object == this) { return true; } if (object == this) { return true; }
@ -68,12 +104,12 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
RegressionInferenceResults that = (RegressionInferenceResults) object; RegressionInferenceResults that = (RegressionInferenceResults) object;
return Objects.equals(value(), that.value()) return Objects.equals(value(), that.value())
&& Objects.equals(this.resultsField, that.resultsField) && Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.getFeatureImportance(), that.getFeatureImportance()); && Objects.equals(this.featureImportance, that.featureImportance);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(value(), resultsField, getFeatureImportance()); return Objects.hash(value(), resultsField, featureImportance);
} }
@Override @Override
@ -85,8 +121,8 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
public Map<String, Object> asMap() { public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>(); Map<String, Object> map = new LinkedHashMap<>();
map.put(resultsField, value()); map.put(resultsField, value());
if (getFeatureImportance().isEmpty() == false) { if (featureImportance.isEmpty() == false) {
map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList())); map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(RegressionFeatureImportance::toMap).collect(Collectors.toList()));
} }
return map; return map;
} }
@ -94,8 +130,8 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(resultsField, value()); builder.field(resultsField, value());
if (getFeatureImportance().size() > 0) { if (featureImportance.isEmpty() == false) {
builder.field(FEATURE_IMPORTANCE, getFeatureImportance()); builder.field(FEATURE_IMPORTANCE, featureImportance);
} }
return builder; return builder;
} }

View File

@ -5,53 +5,30 @@
*/ */
package org.elasticsearch.xpack.core.ml.inference.results; package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
public abstract class SingleValueInferenceResults implements InferenceResults { public abstract class SingleValueInferenceResults implements InferenceResults {
public static final String FEATURE_IMPORTANCE = "feature_importance"; public static final String FEATURE_IMPORTANCE = "feature_importance";
private final double value; private final double value;
private final List<FeatureImportance> featureImportance;
static List<FeatureImportance> takeTopFeatureImportances(List<FeatureImportance> unsortedFeatureImportances, int numTopFeatures) {
if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) {
return unsortedFeatureImportances;
}
return unsortedFeatureImportances.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
.limit(numTopFeatures)
.collect(Collectors.toList());
}
SingleValueInferenceResults(StreamInput in) throws IOException { SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble(); value = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readList(FeatureImportance::new);
} else {
this.featureImportance = Collections.emptyList();
}
} }
SingleValueInferenceResults(double value, List<FeatureImportance> featureImportance) { SingleValueInferenceResults(double value) {
this.value = value; this.value = value;
this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance;
} }
public Double value() { public Double value() {
return value; return value;
} }
public List<FeatureImportance> getFeatureImportance() {
return featureImportance;
}
public String valueAsString() { public String valueAsString() {
return String.valueOf(value); return String.valueOf(value);
@ -60,9 +37,6 @@ public abstract class SingleValueInferenceResults implements InferenceResults {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value); out.writeDouble(value);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeList(this.featureImportance);
}
} }
} }

View File

@ -7,7 +7,8 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -130,17 +131,18 @@ public final class InferenceHelpers {
return originalFeatureImportance; return originalFeatureImportance;
} }
public static List<FeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) { public static List<RegressionFeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size()); List<RegressionFeatureImportance> importances = new ArrayList<>(featureImportance.size());
featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0]))); featureImportance.forEach((k, v) -> importances.add(new RegressionFeatureImportance(k, v[0])));
return importances; return importances;
} }
public static List<FeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance, public static List<ClassificationFeatureImportance> transformFeatureImportanceClassification(
Map<String, double[]> featureImportance,
final int predictedValue, final int predictedValue,
@Nullable List<String> classificationLabels, @Nullable List<String> classificationLabels,
@Nullable PredictionFieldType predictionFieldType) { @Nullable PredictionFieldType predictionFieldType) {
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size()); List<ClassificationFeatureImportance> importances = new ArrayList<>(featureImportance.size());
final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType; final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
featureImportance.forEach((k, v) -> { featureImportance.forEach((k, v) -> {
// This indicates logistic regression (binary classification) // This indicates logistic regression (binary classification)
@ -152,27 +154,26 @@ public final class InferenceHelpers {
final int otherClass = 1 - predictedValue; final int otherClass = 1 - predictedValue;
String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue); String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass); String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
importances.add(FeatureImportance.forBinaryClassification(k, importances.add(new ClassificationFeatureImportance(k,
v[0],
Arrays.asList( Arrays.asList(
new FeatureImportance.ClassImportance( new ClassificationFeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)predictedValue, predictedLabel), fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
v[0]), v[0]),
new FeatureImportance.ClassImportance( new ClassificationFeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)otherClass, otherLabel), fieldType.transformPredictedValue((double)otherClass, otherLabel),
-v[0]) -v[0])
))); )));
} else { } else {
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length); List<ClassificationFeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
// If the classificationLabels exist, their length must match leaf_value length // If the classificationLabels exist, their length must match leaf_value length
assert classificationLabels == null || classificationLabels.size() == v.length; assert classificationLabels == null || classificationLabels.size() == v.length;
for (int i = 0; i < v.length; i++) { for (int i = 0; i < v.length; i++) {
String label = classificationLabels == null ? null : classificationLabels.get(i); String label = classificationLabels == null ? null : classificationLabels.get(i);
classImportance.add(new FeatureImportance.ClassImportance( classImportance.add(new ClassificationFeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)i, label), fieldType.transformPredictedValue((double)i, label),
v[i])); v[i]));
} }
importances.add(FeatureImportance.forClassification(k, classImportance)); importances.add(new ClassificationFeatureImportance(k, classImportance));
} }
}); });
return importances; return importances;

View File

@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
public class ClassificationFeatureImportanceTests extends AbstractSerializingTestCase<ClassificationFeatureImportance> {
@Override
protected ClassificationFeatureImportance doParseInstance(XContentParser parser) throws IOException {
return ClassificationFeatureImportance.fromXContent(parser);
}
@Override
protected Writeable.Reader<ClassificationFeatureImportance> instanceReader() {
return ClassificationFeatureImportance::new;
}
@Override
protected ClassificationFeatureImportance createTestInstance() {
return createRandomInstance();
}
public static ClassificationFeatureImportance createRandomInstance() {
return new ClassificationFeatureImportance(
randomAlphaOfLength(10),
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(2, 10))
.map(name -> new ClassificationFeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
.collect(Collectors.toList()));
}
public void testGetTotalImportance_GivenBinary() {
ClassificationFeatureImportance featureImportance = new ClassificationFeatureImportance(
"binary",
Arrays.asList(
new ClassificationFeatureImportance.ClassImportance("a", 0.15),
new ClassificationFeatureImportance.ClassImportance("not-a", -0.15)
)
);
assertThat(featureImportance.getTotalImportance(), equalTo(0.15));
}
public void testGetTotalImportance_GivenMulticlass() {
ClassificationFeatureImportance featureImportance = new ClassificationFeatureImportance(
"multiclass",
Arrays.asList(
new ClassificationFeatureImportance.ClassImportance("a", 0.15),
new ClassificationFeatureImportance.ClassImportance("b", -0.05),
new ClassificationFeatureImportance.ClassImportance("c", 0.30)
)
);
assertThat(featureImportance.getTotalImportance(), closeTo(0.50, 0.00000001));
}
}

View File

@ -18,7 +18,6 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -29,10 +28,6 @@ import static org.hamcrest.Matchers.hasSize;
public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase<ClassificationInferenceResults> { public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase<ClassificationInferenceResults> {
public static ClassificationInferenceResults createRandomResults() { public static ClassificationInferenceResults createRandomResults() {
Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
FeatureImportanceTests::randomClassification :
FeatureImportanceTests::randomRegression;
ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig(); ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig();
Double value = randomDouble(); Double value = randomDouble();
if (config.getPredictionFieldType() == PredictionFieldType.BOOLEAN) { if (config.getPredictionFieldType() == PredictionFieldType.BOOLEAN) {
@ -47,7 +42,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
.limit(randomIntBetween(0, 10)) .limit(randomIntBetween(0, 10))
.collect(Collectors.toList()), .collect(Collectors.toList()),
randomBoolean() ? null : randomBoolean() ? null :
Stream.generate(featureImportanceCtor) Stream.generate(ClassificationFeatureImportanceTests::createRandomInstance)
.limit(randomIntBetween(1, 10)) .limit(randomIntBetween(1, 10))
.collect(Collectors.toList()), .collect(Collectors.toList()),
config, config,
@ -123,11 +118,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
} }
public void testWriteResultsWithImportance() { public void testWriteResultsWithImportance() {
Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ? List<ClassificationFeatureImportance> importanceList = Stream.generate(ClassificationFeatureImportanceTests::createRandomInstance)
FeatureImportanceTests::randomClassification :
FeatureImportanceTests::randomRegression;
List<FeatureImportance> importanceList = Stream.generate(featureImportanceCtor)
.limit(5) .limit(5)
.collect(Collectors.toList()); .collect(Collectors.toList());
ClassificationInferenceResults result = new ClassificationInferenceResults(0.0, ClassificationInferenceResults result = new ClassificationInferenceResults(0.0,
@ -146,18 +137,17 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
"result_field.feature_importance", "result_field.feature_importance",
List.class); List.class);
assertThat(writtenImportance, hasSize(3)); assertThat(writtenImportance, hasSize(3));
importanceList.sort((l, r) -> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance()))); importanceList.sort((l, r) -> Double.compare(Math.abs(r.getTotalImportance()), Math.abs(l.getTotalImportance())));
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
Map<String, Object> objectMap = writtenImportance.get(i); Map<String, Object> objectMap = writtenImportance.get(i);
FeatureImportance importance = importanceList.get(i); ClassificationFeatureImportance importance = importanceList.get(i);
assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName())); assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
List<Map<String, Object>> classImportances = (List<Map<String, Object>>)objectMap.get("classes"); List<Map<String, Object>> classImportances = (List<Map<String, Object>>)objectMap.get("classes");
if (importance.getClassImportance() != null) { if (importance.getClassImportance() != null) {
for (int j = 0; j < importance.getClassImportance().size(); j++) { for (int j = 0; j < importance.getClassImportance().size(); j++) {
Map<String, Object> classMap = classImportances.get(j); Map<String, Object> classMap = classImportances.get(j);
FeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j); ClassificationFeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j);
assertThat(classMap.get("class_name"), equalTo(classImportance.getClassName())); assertThat(classMap.get("class_name"), equalTo(classImportance.getClassName()));
assertThat(classMap.get("importance"), equalTo(classImportance.getImportance())); assertThat(classMap.get("importance"), equalTo(classImportance.getImportance()));
} }
@ -212,7 +202,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}"; expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
assertEquals(expected, stringRep); assertEquals(expected, stringRep);
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList()); ClassificationFeatureImportance fi = new ClassificationFeatureImportance("foo", Collections.emptyList());
TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0); TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp), result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
Collections.singletonList(fi), config, Collections.singletonList(fi), config,

View File

@ -1,49 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class FeatureImportanceTests extends AbstractSerializingTestCase<FeatureImportance> {
public static FeatureImportance createRandomInstance() {
return randomBoolean() ? randomClassification() : randomRegression();
}
static FeatureImportance randomRegression() {
return FeatureImportance.forRegression(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false));
}
static FeatureImportance randomClassification() {
return FeatureImportance.forClassification(
randomAlphaOfLength(10),
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(2, 10))
.map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
.collect(Collectors.toList()));
}
@Override
protected FeatureImportance createTestInstance() {
return createRandomInstance();
}
@Override
protected Writeable.Reader<FeatureImportance> instanceReader() {
return FeatureImportance::new;
}
@Override
protected FeatureImportance doParseInstance(XContentParser parser) throws IOException {
return FeatureImportance.fromXContent(parser);
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.equalTo;
public class LegacyFeatureImportanceTests extends AbstractWireSerializingTestCase<LegacyFeatureImportance> {
public static LegacyFeatureImportance createRandomInstance() {
return createRandomInstance(randomBoolean());
}
public static LegacyFeatureImportance createRandomInstance(boolean hasClasses) {
double importance = randomDouble();
List<LegacyFeatureImportance.ClassImportance> classImportances = null;
if (hasClasses) {
classImportances = Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(2, 10))
.map(featureName -> new LegacyFeatureImportance.ClassImportance(featureName, randomDouble()))
.collect(Collectors.toList());
importance = classImportances.stream().mapToDouble(LegacyFeatureImportance.ClassImportance::getImportance).map(Math::abs).sum();
}
return new LegacyFeatureImportance(randomAlphaOfLength(10), importance, classImportances);
}
@Override
protected LegacyFeatureImportance createTestInstance() {
return createRandomInstance();
}
@Override
protected Writeable.Reader<LegacyFeatureImportance> instanceReader() {
return LegacyFeatureImportance::new;
}
public void testClassificationConversion() {
{
ClassificationFeatureImportance classificationFeatureImportance = ClassificationFeatureImportanceTests.createRandomInstance();
LegacyFeatureImportance legacyFeatureImportance = LegacyFeatureImportance.fromClassification(classificationFeatureImportance);
ClassificationFeatureImportance convertedFeatureImportance = legacyFeatureImportance.forClassification();
assertThat(convertedFeatureImportance, equalTo(classificationFeatureImportance));
}
{
LegacyFeatureImportance legacyFeatureImportance = createRandomInstance(true);
ClassificationFeatureImportance classificationFeatureImportance = legacyFeatureImportance.forClassification();
LegacyFeatureImportance convertedFeatureImportance = LegacyFeatureImportance.fromClassification(
classificationFeatureImportance);
assertThat(convertedFeatureImportance, equalTo(legacyFeatureImportance));
}
}
public void testRegressionConversion() {
{
RegressionFeatureImportance regressionFeatureImportance = RegressionFeatureImportanceTests.createRandomInstance();
LegacyFeatureImportance legacyFeatureImportance = LegacyFeatureImportance.fromRegression(regressionFeatureImportance);
RegressionFeatureImportance convertedFeatureImportance = legacyFeatureImportance.forRegression();
assertThat(convertedFeatureImportance, equalTo(regressionFeatureImportance));
}
{
LegacyFeatureImportance legacyFeatureImportance = createRandomInstance(false);
RegressionFeatureImportance regressionFeatureImportance = legacyFeatureImportance.forRegression();
LegacyFeatureImportance convertedFeatureImportance = LegacyFeatureImportance.fromRegression(regressionFeatureImportance);
assertThat(convertedFeatureImportance, equalTo(legacyFeatureImportance));
}
}
}

View File

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
public class RegressionFeatureImportanceTests extends AbstractSerializingTestCase<RegressionFeatureImportance> {
@Override
protected RegressionFeatureImportance doParseInstance(XContentParser parser) throws IOException {
return RegressionFeatureImportance.fromXContent(parser);
}
@Override
protected Writeable.Reader<RegressionFeatureImportance> instanceReader() {
return RegressionFeatureImportance::new;
}
@Override
protected RegressionFeatureImportance createTestInstance() {
return createRandomInstance();
}
public static RegressionFeatureImportance createRandomInstance() {
return new RegressionFeatureImportance(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false));
}
}

View File

@ -29,8 +29,8 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
public static RegressionInferenceResults createRandomResults() { public static RegressionInferenceResults createRandomResults() {
return new RegressionInferenceResults(randomDouble(), return new RegressionInferenceResults(randomDouble(),
RegressionConfigTests.randomRegressionConfig(), RegressionConfigTests.randomRegressionConfig(),
randomBoolean() ? null : randomBoolean() ? Collections.emptyList() :
Stream.generate(FeatureImportanceTests::randomRegression) Stream.generate(RegressionFeatureImportanceTests::createRandomInstance)
.limit(randomIntBetween(1, 10)) .limit(randomIntBetween(1, 10))
.collect(Collectors.toList())); .collect(Collectors.toList()));
} }
@ -50,7 +50,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
} }
public void testWriteResultsWithImportance() { public void testWriteResultsWithImportance() {
List<FeatureImportance> importanceList = Stream.generate(FeatureImportanceTests::randomRegression) List<RegressionFeatureImportance> importanceList = Stream.generate(RegressionFeatureImportanceTests::createRandomInstance)
.limit(5) .limit(5)
.collect(Collectors.toList()); .collect(Collectors.toList());
RegressionInferenceResults result = new RegressionInferenceResults(0.3, RegressionInferenceResults result = new RegressionInferenceResults(0.3,
@ -68,7 +68,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance()))); importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
Map<String, Object> objectMap = writtenImportance.get(i); Map<String, Object> objectMap = writtenImportance.get(i);
FeatureImportance importance = importanceList.get(i); RegressionFeatureImportance importance = importanceList.get(i);
assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName())); assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
assertThat(objectMap.get("importance"), equalTo(importance.getImportance())); assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
assertThat(objectMap.size(), equalTo(2)); assertThat(objectMap.size(), equalTo(2));
@ -92,7 +92,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
String expected = "{\"" + resultsField + "\":1.0}"; String expected = "{\"" + resultsField + "\":1.0}";
assertEquals(expected, stringRep); assertEquals(expected, stringRep);
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList()); RegressionFeatureImportance fi = new RegressionFeatureImportance("foo", 1.0);
result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi)); result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi));
stringRep = Strings.toString(result); stringRep = Strings.toString(result);
expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}"; expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";

View File

@ -16,8 +16,8 @@ import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor; import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import java.io.IOException; import java.io.IOException;
@ -134,9 +134,9 @@ public class InferenceDefinitionTests extends ESTestCase {
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config); ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
assertThat(results.valueAsString(), equalTo("second")); assertThat(results.valueAsString(), equalTo("second"));
assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2")); assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001)); assertThat(results.getFeatureImportance().get(0).getTotalImportance(), closeTo(0.944, 0.001));
assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1")); assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1"));
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001)); assertThat(results.getFeatureImportance().get(1).getTotalImportance(), closeTo(0.199, 0.001));
} }
public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws IOException { public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws IOException {
@ -155,20 +155,20 @@ public class InferenceDefinitionTests extends ESTestCase {
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config); ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
assertThat(results.valueAsString(), equalTo("second")); assertThat(results.valueAsString(), equalTo("second"));
FeatureImportance featureImportance1 = results.getFeatureImportance().get(0); ClassificationFeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
assertThat(featureImportance1.getFeatureName(), equalTo("col2")); assertThat(featureImportance1.getFeatureName(), equalTo("col2"));
assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001)); assertThat(featureImportance1.getTotalImportance(), closeTo(0.944, 0.001));
for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) { for (ClassificationFeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
if (classImportance.getClassName().equals("second")) { if (classImportance.getClassName().equals("second")) {
assertThat(classImportance.getImportance(), closeTo(0.944, 0.001)); assertThat(classImportance.getImportance(), closeTo(0.944, 0.001));
} else { } else {
assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001)); assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001));
} }
} }
FeatureImportance featureImportance2 = results.getFeatureImportance().get(1); ClassificationFeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
assertThat(featureImportance2.getFeatureName(), equalTo("col1_male")); assertThat(featureImportance2.getFeatureName(), equalTo("col1_male"));
assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001)); assertThat(featureImportance2.getTotalImportance(), closeTo(0.199, 0.001));
for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) { for (ClassificationFeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
if (classImportance.getClassName().equals("second")) { if (classImportance.getClassName().equals("second")) {
assertThat(classImportance.getImportance(), closeTo(0.199, 0.001)); assertThat(classImportance.getImportance(), closeTo(0.199, 0.001));
} else { } else {

View File

@ -16,10 +16,11 @@ import org.elasticsearch.search.aggregations.InvalidAggregationPathException;
import org.elasticsearch.search.aggregations.ParsedAggregation; import org.elasticsearch.search.aggregations.ParsedAggregation;
import org.elasticsearch.test.InternalAggregationTestCase; import org.elasticsearch.test.InternalAggregationTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@ -115,7 +116,7 @@ public class InternalInferenceAggregationTests extends InternalAggregationTestCa
} else if (result instanceof RegressionInferenceResults) { } else if (result instanceof RegressionInferenceResults) {
RegressionInferenceResults regression = (RegressionInferenceResults) result; RegressionInferenceResults regression = (RegressionInferenceResults) result;
assertEquals(regression.value(), parsed.getValue()); assertEquals(regression.value(), parsed.getValue());
List<FeatureImportance> featureImportance = regression.getFeatureImportance(); List<RegressionFeatureImportance> featureImportance = regression.getFeatureImportance();
if (featureImportance.isEmpty()) { if (featureImportance.isEmpty()) {
featureImportance = null; featureImportance = null;
} }
@ -124,7 +125,7 @@ public class InternalInferenceAggregationTests extends InternalAggregationTestCa
ClassificationInferenceResults classification = (ClassificationInferenceResults) result; ClassificationInferenceResults classification = (ClassificationInferenceResults) result;
assertEquals(classification.predictedValue(), parsed.getValue()); assertEquals(classification.predictedValue(), parsed.getValue());
List<FeatureImportance> featureImportance = classification.getFeatureImportance(); List<ClassificationFeatureImportance> featureImportance = classification.getFeatureImportance();
if (featureImportance.isEmpty()) { if (featureImportance.isEmpty()) {
featureImportance = null; featureImportance = null;
} }

View File

@ -13,7 +13,6 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.ParsedAggregation; import org.elasticsearch.search.aggregations.ParsedAggregation;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@ -21,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Map;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_PROBABILITY; import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_PROBABILITY;
@ -45,7 +45,7 @@ public class ParsedInference extends ParsedAggregation {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ParsedInference, Void> PARSER = private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true, new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1], args -> new ParsedInference(args[0], (List<Map<String, Object>>) args[1],
(List<TopClassEntry>) args[2], (String) args[3], (Double) args[4], (Double) args[5])); (List<TopClassEntry>) args[2], (String) args[3], (Double) args[4], (Double) args[5]));
static { static {
@ -65,7 +65,7 @@ public class ParsedInference extends ParsedAggregation {
} }
return o; return o;
}, CommonFields.VALUE, ObjectParser.ValueType.VALUE); }, CommonFields.VALUE, ObjectParser.ValueType.VALUE);
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p), PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> p.map(),
new ParseField(SingleValueInferenceResults.FEATURE_IMPORTANCE)); new ParseField(SingleValueInferenceResults.FEATURE_IMPORTANCE));
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p), PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p),
new ParseField(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD)); new ParseField(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD));
@ -82,14 +82,14 @@ public class ParsedInference extends ParsedAggregation {
} }
private final Object value; private final Object value;
private final List<FeatureImportance> featureImportance; private final List<Map<String, Object>> featureImportance;
private final List<TopClassEntry> topClasses; private final List<TopClassEntry> topClasses;
private final String warning; private final String warning;
private final Double predictionProbability; private final Double predictionProbability;
private final Double predictionScore; private final Double predictionScore;
ParsedInference(Object value, ParsedInference(Object value,
List<FeatureImportance> featureImportance, List<Map<String, Object>> featureImportance,
List<TopClassEntry> topClasses, List<TopClassEntry> topClasses,
String warning, String warning,
Double predictionProbability, Double predictionProbability,
@ -106,7 +106,7 @@ public class ParsedInference extends ParsedAggregation {
return value; return value;
} }
public List<FeatureImportance> getFeatureImportance() { public List<Map<String, Object>> getFeatureImportance() {
return featureImportance; return featureImportance;
} }

View File

@ -9,8 +9,9 @@ import org.elasticsearch.client.Client;
import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@ -136,9 +137,11 @@ public class InferenceProcessorTests extends ESTestCase {
classes.add(new TopClassEntry("foo", 0.6, 0.6)); classes.add(new TopClassEntry("foo", 0.6, 0.6));
classes.add(new TopClassEntry("bar", 0.4, 0.4)); classes.add(new TopClassEntry("bar", 0.4, 0.4));
List<FeatureImportance> featureInfluence = new ArrayList<>(); List<ClassificationFeatureImportance> featureInfluence = new ArrayList<>();
featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13)); featureInfluence.add(new ClassificationFeatureImportance("feature_1",
featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0)); Collections.singletonList(new ClassificationFeatureImportance.ClassImportance("class_a", 1.13))));
featureInfluence.add(new ClassificationFeatureImportance("feature_2",
Collections.singletonList(new ClassificationFeatureImportance.ClassImportance("class_b", -42.0))));
InternalInferModelAction.Response response = new InternalInferModelAction.Response( InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, Collections.singletonList(new ClassificationInferenceResults(1.0,
@ -153,10 +156,12 @@ public class InferenceProcessorTests extends ESTestCase {
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model")); assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo")); assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.importance", Double.class), equalTo(-42.0));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2")); assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.importance", Double.class), equalTo(1.13)); assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.classes.0.class_name", String.class), equalTo("class_b"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.classes.0.importance", Double.class), equalTo(-42.0));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1")); assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.classes.0.class_name", String.class), equalTo("class_a"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.classes.0.importance", Double.class), equalTo(1.13));
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -234,9 +239,9 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>(); Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata); IngestDocument document = new IngestDocument(source, ingestMetadata);
List<FeatureImportance> featureInfluence = new ArrayList<>(); List<RegressionFeatureImportance> featureInfluence = new ArrayList<>();
featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13)); featureInfluence.add(new RegressionFeatureImportance("feature_1", 1.13));
featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0)); featureInfluence.add(new RegressionFeatureImportance("feature_2", -42.0));
InternalInferModelAction.Response response = new InternalInferModelAction.Response( InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true); Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);