As we have decided top level importance for classification is not useful, it has been removed from the results from the training job. This commit also removes them from inference. Backport of #62486
This commit is contained in:
parent
cc33df87d3
commit
7f6c1ff5b4
|
@ -47,7 +47,7 @@ public class FeatureImportance implements ToXContentObject {
|
|||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
|
||||
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||
PARSER.declareDouble(optionalConstructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||
PARSER.declareObjectArray(optionalConstructorArg(),
|
||||
(p, c) -> ClassImportance.fromXContent(p),
|
||||
new ParseField(FeatureImportance.CLASSES));
|
||||
|
@ -58,10 +58,10 @@ public class FeatureImportance implements ToXContentObject {
|
|||
}
|
||||
|
||||
private final List<ClassImportance> classImportance;
|
||||
private final double importance;
|
||||
private final Double importance;
|
||||
private final String featureName;
|
||||
|
||||
public FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
|
||||
public FeatureImportance(String featureName, Double importance, List<ClassImportance> classImportance) {
|
||||
this.featureName = Objects.requireNonNull(featureName);
|
||||
this.importance = importance;
|
||||
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
|
||||
|
@ -71,7 +71,7 @@ public class FeatureImportance implements ToXContentObject {
|
|||
return classImportance;
|
||||
}
|
||||
|
||||
public double getImportance() {
|
||||
public Double getImportance() {
|
||||
return importance;
|
||||
}
|
||||
|
||||
|
@ -83,7 +83,9 @@ public class FeatureImportance implements ToXContentObject {
|
|||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FEATURE_NAME, featureName);
|
||||
builder.field(IMPORTANCE, importance);
|
||||
if (importance != null) {
|
||||
builder.field(IMPORTANCE, importance);
|
||||
}
|
||||
if (classImportance != null && classImportance.isEmpty() == false) {
|
||||
builder.field(CLASSES, classImportance);
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImpo
|
|||
protected FeatureImportance createTestInstance() {
|
||||
return new FeatureImportance(
|
||||
randomAlphaOfLength(10),
|
||||
randomDoubleBetween(-10.0, 10.0, false),
|
||||
randomBoolean() ? null : randomDoubleBetween(-10.0, 10.0, false),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(() -> randomAlphaOfLength(10))
|
||||
.limit(randomLongBetween(2, 10))
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
|
||||
abstract class AbstractFeatureImportance implements Writeable, ToXContentObject {
|
||||
|
||||
public abstract String getFeatureName();
|
||||
|
||||
public abstract Map<String, Object> toMap();
|
||||
|
||||
@Override
|
||||
public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
return builder.map(toMap());
|
||||
}
|
||||
}
|
|
@ -5,7 +5,6 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
|
@ -26,157 +25,101 @@ import java.util.stream.Collectors;
|
|||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
public class FeatureImportance implements Writeable, ToXContentObject {
|
||||
public class ClassificationFeatureImportance extends AbstractFeatureImportance {
|
||||
|
||||
private final List<ClassImportance> classImportance;
|
||||
private final double importance;
|
||||
private final String featureName;
|
||||
static final String IMPORTANCE = "importance";
|
||||
|
||||
static final String FEATURE_NAME = "feature_name";
|
||||
static final String CLASSES = "classes";
|
||||
|
||||
public static FeatureImportance forRegression(String featureName, double importance) {
|
||||
return new FeatureImportance(featureName, importance, null);
|
||||
}
|
||||
|
||||
public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
|
||||
return new FeatureImportance(featureName,
|
||||
importance,
|
||||
classImportance);
|
||||
}
|
||||
|
||||
public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
|
||||
return new FeatureImportance(featureName,
|
||||
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
|
||||
classImportance);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
|
||||
new ConstructingObjectParser<>("feature_importance",
|
||||
a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2])
|
||||
private static final ConstructingObjectParser<ClassificationFeatureImportance, Void> PARSER =
|
||||
new ConstructingObjectParser<>("classification_feature_importance",
|
||||
a -> new ClassificationFeatureImportance((String) a[0], (List<ClassImportance>) a[1])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
|
||||
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||
PARSER.declareString(constructorArg(), new ParseField(ClassificationFeatureImportance.FEATURE_NAME));
|
||||
PARSER.declareObjectArray(optionalConstructorArg(),
|
||||
(p, c) -> ClassImportance.fromXContent(p),
|
||||
new ParseField(FeatureImportance.CLASSES));
|
||||
new ParseField(ClassificationFeatureImportance.CLASSES));
|
||||
}
|
||||
|
||||
public static FeatureImportance fromXContent(XContentParser parser) {
|
||||
public static ClassificationFeatureImportance fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
|
||||
public ClassificationFeatureImportance(String featureName, List<ClassImportance> classImportance) {
|
||||
this.featureName = Objects.requireNonNull(featureName);
|
||||
this.importance = importance;
|
||||
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
|
||||
this.classImportance = classImportance == null ? Collections.emptyList() : Collections.unmodifiableList(classImportance);
|
||||
}
|
||||
|
||||
public FeatureImportance(StreamInput in) throws IOException {
|
||||
public ClassificationFeatureImportance(StreamInput in) throws IOException {
|
||||
this.featureName = in.readString();
|
||||
this.importance = in.readDouble();
|
||||
if (in.readBoolean()) {
|
||||
if (in.getVersion().before(Version.V_7_10_0)) {
|
||||
Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
|
||||
this.classImportance = ClassImportance.fromMap(classImportance);
|
||||
} else {
|
||||
this.classImportance = in.readList(ClassImportance::new);
|
||||
}
|
||||
} else {
|
||||
this.classImportance = null;
|
||||
}
|
||||
this.classImportance = in.readList(ClassImportance::new);
|
||||
}
|
||||
|
||||
public List<ClassImportance> getClassImportance() {
|
||||
return classImportance;
|
||||
}
|
||||
|
||||
public double getImportance() {
|
||||
return importance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFeatureName() {
|
||||
return featureName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(this.featureName);
|
||||
out.writeDouble(this.importance);
|
||||
out.writeBoolean(this.classImportance != null);
|
||||
if (this.classImportance != null) {
|
||||
if (out.getVersion().before(Version.V_7_10_0)) {
|
||||
out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
|
||||
} else {
|
||||
out.writeList(this.classImportance);
|
||||
}
|
||||
public double getTotalImportance() {
|
||||
if (classImportance.size() == 2) {
|
||||
// Binary classification. We can return the first class importance here
|
||||
return Math.abs(classImportance.get(0).getImportance());
|
||||
}
|
||||
return classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(featureName);
|
||||
out.writeList(classImportance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> toMap() {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
map.put(FEATURE_NAME, featureName);
|
||||
map.put(IMPORTANCE, importance);
|
||||
if (classImportance != null) {
|
||||
if (classImportance.isEmpty() == false) {
|
||||
map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList()));
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FEATURE_NAME, featureName);
|
||||
builder.field(IMPORTANCE, importance);
|
||||
if (classImportance != null && classImportance.isEmpty() == false) {
|
||||
builder.field(CLASSES, classImportance);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
FeatureImportance that = (FeatureImportance) object;
|
||||
ClassificationFeatureImportance that = (ClassificationFeatureImportance) object;
|
||||
return Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(importance, that.importance)
|
||||
&& Objects.equals(classImportance, that.classImportance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(featureName, importance, classImportance);
|
||||
return Objects.hash(featureName, classImportance);
|
||||
}
|
||||
|
||||
public static class ClassImportance implements Writeable, ToXContentObject {
|
||||
|
||||
static final String CLASS_NAME = "class_name";
|
||||
static final String IMPORTANCE = "importance";
|
||||
|
||||
private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
|
||||
new ConstructingObjectParser<>("feature_importance_class_importance",
|
||||
a -> new ClassImportance((String) a[0], (Double) a[1])
|
||||
new ConstructingObjectParser<>("classification_feature_importance_class_importance",
|
||||
a -> new ClassImportance(a[0], (Double) a[1])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
|
||||
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||
}
|
||||
|
||||
private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
|
||||
return new ClassImportance(entry.getKey(), entry.getValue());
|
||||
}
|
||||
|
||||
private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
|
||||
return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Map<String, Double> toMap(List<ClassImportance> importances) {
|
||||
return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
|
||||
PARSER.declareDouble(constructorArg(), new ParseField(IMPORTANCE));
|
||||
}
|
||||
|
||||
public static ClassImportance fromXContent(XContentParser parser) {
|
||||
|
@ -219,11 +162,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
|
|||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASS_NAME, className);
|
||||
builder.field(IMPORTANCE, importance);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
return builder.map(toMap());
|
||||
}
|
||||
|
||||
@Override
|
|
@ -15,9 +15,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldTyp
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
@ -34,12 +34,13 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
private final Double predictionProbability;
|
||||
private final Double predictionScore;
|
||||
private final List<TopClassEntry> topClasses;
|
||||
private final List<ClassificationFeatureImportance> featureImportance;
|
||||
private final PredictionFieldType predictionFieldType;
|
||||
|
||||
public ClassificationInferenceResults(double value,
|
||||
String classificationLabel,
|
||||
List<TopClassEntry> topClasses,
|
||||
List<FeatureImportance> featureImportance,
|
||||
List<ClassificationFeatureImportance> featureImportance,
|
||||
InferenceConfig config,
|
||||
Double predictionProbability,
|
||||
Double predictionScore) {
|
||||
|
@ -55,13 +56,11 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
private ClassificationInferenceResults(double value,
|
||||
String classificationLabel,
|
||||
List<TopClassEntry> topClasses,
|
||||
List<FeatureImportance> featureImportance,
|
||||
List<ClassificationFeatureImportance> featureImportance,
|
||||
ClassificationConfig classificationConfig,
|
||||
Double predictionProbability,
|
||||
Double predictionScore) {
|
||||
super(value,
|
||||
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
|
||||
classificationConfig.getNumTopFeatureImportanceValues()));
|
||||
super(value);
|
||||
this.classificationLabel = classificationLabel;
|
||||
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
|
||||
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
|
||||
|
@ -69,10 +68,32 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
this.predictionFieldType = classificationConfig.getPredictionFieldType();
|
||||
this.predictionProbability = predictionProbability;
|
||||
this.predictionScore = predictionScore;
|
||||
this.featureImportance = takeTopFeatureImportances(featureImportance, classificationConfig.getNumTopFeatureImportanceValues());
|
||||
}
|
||||
|
||||
static List<ClassificationFeatureImportance> takeTopFeatureImportances(List<ClassificationFeatureImportance> featureImportances,
|
||||
int numTopFeatures) {
|
||||
if (featureImportances == null || featureImportances.isEmpty()) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return featureImportances.stream()
|
||||
.sorted((l, r)-> Double.compare(r.getTotalImportance(), l.getTotalImportance()))
|
||||
.limit(numTopFeatures)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public ClassificationInferenceResults(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
this.featureImportance = in.readList(ClassificationFeatureImportance::new);
|
||||
} else if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
this.featureImportance = in.readList(LegacyFeatureImportance::new)
|
||||
.stream()
|
||||
.map(LegacyFeatureImportance::forClassification)
|
||||
.collect(Collectors.toList());
|
||||
} else {
|
||||
this.featureImportance = Collections.emptyList();
|
||||
}
|
||||
this.classificationLabel = in.readOptionalString();
|
||||
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
|
||||
this.topNumClassesField = in.readString();
|
||||
|
@ -103,9 +124,18 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
return predictionFieldType;
|
||||
}
|
||||
|
||||
public List<ClassificationFeatureImportance> getFeatureImportance() {
|
||||
return featureImportance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeList(featureImportance);
|
||||
} else if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
out.writeList(featureImportance.stream().map(LegacyFeatureImportance::fromClassification).collect(Collectors.toList()));
|
||||
}
|
||||
out.writeOptionalString(classificationLabel);
|
||||
out.writeCollection(topClasses);
|
||||
out.writeString(topNumClassesField);
|
||||
|
@ -132,7 +162,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
&& Objects.equals(predictionFieldType, that.predictionFieldType)
|
||||
&& Objects.equals(predictionProbability, that.predictionProbability)
|
||||
&& Objects.equals(predictionScore, that.predictionScore)
|
||||
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
|
||||
&& Objects.equals(featureImportance, that.featureImportance);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -144,7 +174,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
topNumClassesField,
|
||||
predictionProbability,
|
||||
predictionScore,
|
||||
getFeatureImportance(),
|
||||
featureImportance,
|
||||
predictionFieldType);
|
||||
}
|
||||
|
||||
|
@ -179,8 +209,9 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
if (predictionScore != null) {
|
||||
map.put(PREDICTION_SCORE, predictionScore);
|
||||
}
|
||||
if (getFeatureImportance().isEmpty() == false) {
|
||||
map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList()));
|
||||
if (featureImportance.isEmpty() == false) {
|
||||
map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(ClassificationFeatureImportance::toMap)
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
@ -202,8 +233,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
if (predictionScore != null) {
|
||||
builder.field(PREDICTION_SCORE, predictionScore);
|
||||
}
|
||||
if (getFeatureImportance().size() > 0) {
|
||||
builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
|
||||
if (featureImportance.isEmpty() == false) {
|
||||
builder.field(FEATURE_IMPORTANCE, featureImportance);
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* This class captures serialization of feature importance for
|
||||
* classification and regression prior to version 7.10.
|
||||
*/
|
||||
public class LegacyFeatureImportance implements Writeable {
|
||||
|
||||
public static LegacyFeatureImportance fromClassification(ClassificationFeatureImportance classificationFeatureImportance) {
|
||||
return new LegacyFeatureImportance(
|
||||
classificationFeatureImportance.getFeatureName(),
|
||||
classificationFeatureImportance.getTotalImportance(),
|
||||
classificationFeatureImportance.getClassImportance().stream().map(classImportance -> new ClassImportance(
|
||||
classImportance.getClassName(), classImportance.getImportance())).collect(Collectors.toList())
|
||||
);
|
||||
}
|
||||
|
||||
public static LegacyFeatureImportance fromRegression(RegressionFeatureImportance regressionFeatureImportance) {
|
||||
return new LegacyFeatureImportance(
|
||||
regressionFeatureImportance.getFeatureName(),
|
||||
regressionFeatureImportance.getImportance(),
|
||||
null
|
||||
);
|
||||
}
|
||||
|
||||
private final List<ClassImportance> classImportance;
|
||||
private final double importance;
|
||||
private final String featureName;
|
||||
|
||||
LegacyFeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
|
||||
this.featureName = Objects.requireNonNull(featureName);
|
||||
this.importance = importance;
|
||||
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
|
||||
}
|
||||
|
||||
public LegacyFeatureImportance(StreamInput in) throws IOException {
|
||||
this.featureName = in.readString();
|
||||
this.importance = in.readDouble();
|
||||
if (in.readBoolean()) {
|
||||
if (in.getVersion().before(Version.V_7_10_0)) {
|
||||
Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
|
||||
this.classImportance = ClassImportance.fromMap(classImportance);
|
||||
} else {
|
||||
this.classImportance = in.readList(ClassImportance::new);
|
||||
}
|
||||
} else {
|
||||
this.classImportance = null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(featureName);
|
||||
out.writeDouble(importance);
|
||||
out.writeBoolean(classImportance != null);
|
||||
if (classImportance != null) {
|
||||
if (out.getVersion().before(Version.V_7_10_0)) {
|
||||
out.writeMap(ClassImportance.toMap(classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
|
||||
} else {
|
||||
out.writeList(classImportance);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
LegacyFeatureImportance that = (LegacyFeatureImportance) object;
|
||||
return Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(importance, that.importance)
|
||||
&& Objects.equals(classImportance, that.classImportance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(featureName, importance, classImportance);
|
||||
}
|
||||
|
||||
public RegressionFeatureImportance forRegression() {
|
||||
assert classImportance == null;
|
||||
return new RegressionFeatureImportance(featureName, importance);
|
||||
}
|
||||
|
||||
public ClassificationFeatureImportance forClassification() {
|
||||
assert classImportance != null;
|
||||
return new ClassificationFeatureImportance(featureName, classImportance.stream().map(
|
||||
aClassImportance -> new ClassificationFeatureImportance.ClassImportance(
|
||||
aClassImportance.className, aClassImportance.importance)).collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
public static class ClassImportance implements Writeable {
|
||||
|
||||
private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
|
||||
return new ClassImportance(entry.getKey(), entry.getValue());
|
||||
}
|
||||
|
||||
private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
|
||||
return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Map<String, Double> toMap(List<ClassImportance> importances) {
|
||||
return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
|
||||
}
|
||||
|
||||
private final Object className;
|
||||
private final double importance;
|
||||
|
||||
public ClassImportance(Object className, double importance) {
|
||||
this.className = className;
|
||||
this.importance = importance;
|
||||
}
|
||||
|
||||
public ClassImportance(StreamInput in) throws IOException {
|
||||
this.className = in.readGenericValue();
|
||||
this.importance = in.readDouble();
|
||||
}
|
||||
|
||||
double getImportance() {
|
||||
return importance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeGenericValue(className);
|
||||
out.writeDouble(importance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ClassImportance that = (ClassImportance) o;
|
||||
return Double.compare(that.importance, importance) == 0 &&
|
||||
Objects.equals(className, that.className);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(className, importance);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
public class RegressionFeatureImportance extends AbstractFeatureImportance {
|
||||
|
||||
private final double importance;
|
||||
private final String featureName;
|
||||
static final String IMPORTANCE = "importance";
|
||||
static final String FEATURE_NAME = "feature_name";
|
||||
|
||||
private static final ConstructingObjectParser<RegressionFeatureImportance, Void> PARSER =
|
||||
new ConstructingObjectParser<>("regression_feature_importance",
|
||||
a -> new RegressionFeatureImportance((String) a[0], (Double) a[1])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), new ParseField(RegressionFeatureImportance.FEATURE_NAME));
|
||||
PARSER.declareDouble(constructorArg(), new ParseField(RegressionFeatureImportance.IMPORTANCE));
|
||||
}
|
||||
|
||||
public static RegressionFeatureImportance fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public RegressionFeatureImportance(String featureName, double importance) {
|
||||
this.featureName = Objects.requireNonNull(featureName);
|
||||
this.importance = importance;
|
||||
}
|
||||
|
||||
public RegressionFeatureImportance(StreamInput in) throws IOException {
|
||||
this.featureName = in.readString();
|
||||
this.importance = in.readDouble();
|
||||
}
|
||||
|
||||
public double getImportance() {
|
||||
return importance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFeatureName() {
|
||||
return featureName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(featureName);
|
||||
out.writeDouble(importance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> toMap() {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
map.put(FEATURE_NAME, featureName);
|
||||
map.put(IMPORTANCE, importance);
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
RegressionFeatureImportance that = (RegressionFeatureImportance) object;
|
||||
return Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(importance, that.importance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(featureName, importance);
|
||||
}
|
||||
}
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
@ -24,14 +25,19 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
|
|||
public static final String NAME = "regression";
|
||||
|
||||
private final String resultsField;
|
||||
private final List<RegressionFeatureImportance> featureImportance;
|
||||
|
||||
public RegressionInferenceResults(double value, InferenceConfig config) {
|
||||
this(value, config, Collections.emptyList());
|
||||
}
|
||||
|
||||
public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
|
||||
this(value, ((RegressionConfig)config).getResultsField(),
|
||||
((RegressionConfig)config).getNumTopFeatureImportanceValues(), featureImportance);
|
||||
public RegressionInferenceResults(double value, InferenceConfig config, List<RegressionFeatureImportance> featureImportance) {
|
||||
this(
|
||||
value,
|
||||
((RegressionConfig)config).getResultsField(),
|
||||
((RegressionConfig)config).getNumTopFeatureImportanceValues(),
|
||||
featureImportance
|
||||
);
|
||||
}
|
||||
|
||||
public RegressionInferenceResults(double value, String resultsField) {
|
||||
|
@ -39,28 +45,58 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
|
|||
}
|
||||
|
||||
public RegressionInferenceResults(double value, String resultsField,
|
||||
List<FeatureImportance> featureImportance) {
|
||||
List<RegressionFeatureImportance> featureImportance) {
|
||||
this(value, resultsField, featureImportance.size(), featureImportance);
|
||||
}
|
||||
|
||||
public RegressionInferenceResults(double value, String resultsField, int topNFeatures,
|
||||
List<FeatureImportance> featureImportance) {
|
||||
super(value,
|
||||
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance, topNFeatures));
|
||||
List<RegressionFeatureImportance> featureImportance) {
|
||||
super(value);
|
||||
this.resultsField = resultsField;
|
||||
this.featureImportance = takeTopFeatureImportances(featureImportance, topNFeatures);
|
||||
}
|
||||
|
||||
static List<RegressionFeatureImportance> takeTopFeatureImportances(List<RegressionFeatureImportance> featureImportances,
|
||||
int numTopFeatures) {
|
||||
if (featureImportances == null || featureImportances.isEmpty()) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return featureImportances.stream()
|
||||
.sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
|
||||
.limit(numTopFeatures)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public RegressionInferenceResults(StreamInput in) throws IOException {
|
||||
super(in);
|
||||
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
this.featureImportance = in.readList(RegressionFeatureImportance::new);
|
||||
} else if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
this.featureImportance = in.readList(LegacyFeatureImportance::new)
|
||||
.stream()
|
||||
.map(LegacyFeatureImportance::forRegression)
|
||||
.collect(Collectors.toList());
|
||||
} else {
|
||||
this.featureImportance = Collections.emptyList();
|
||||
}
|
||||
this.resultsField = in.readString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
super.writeTo(out);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeList(featureImportance);
|
||||
} else if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
out.writeList(featureImportance.stream().map(LegacyFeatureImportance::fromRegression).collect(Collectors.toList()));
|
||||
}
|
||||
out.writeString(resultsField);
|
||||
}
|
||||
|
||||
public List<RegressionFeatureImportance> getFeatureImportance() {
|
||||
return featureImportance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
|
@ -68,12 +104,12 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
|
|||
RegressionInferenceResults that = (RegressionInferenceResults) object;
|
||||
return Objects.equals(value(), that.value())
|
||||
&& Objects.equals(this.resultsField, that.resultsField)
|
||||
&& Objects.equals(this.getFeatureImportance(), that.getFeatureImportance());
|
||||
&& Objects.equals(this.featureImportance, that.featureImportance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(value(), resultsField, getFeatureImportance());
|
||||
return Objects.hash(value(), resultsField, featureImportance);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -85,8 +121,8 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
|
|||
public Map<String, Object> asMap() {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
map.put(resultsField, value());
|
||||
if (getFeatureImportance().isEmpty() == false) {
|
||||
map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList()));
|
||||
if (featureImportance.isEmpty() == false) {
|
||||
map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(RegressionFeatureImportance::toMap).collect(Collectors.toList()));
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
@ -94,8 +130,8 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
|
|||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(resultsField, value());
|
||||
if (getFeatureImportance().size() > 0) {
|
||||
builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
|
||||
if (featureImportance.isEmpty() == false) {
|
||||
builder.field(FEATURE_IMPORTANCE, featureImportance);
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
|
|
@ -5,53 +5,30 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public abstract class SingleValueInferenceResults implements InferenceResults {
|
||||
|
||||
public static final String FEATURE_IMPORTANCE = "feature_importance";
|
||||
|
||||
private final double value;
|
||||
private final List<FeatureImportance> featureImportance;
|
||||
|
||||
static List<FeatureImportance> takeTopFeatureImportances(List<FeatureImportance> unsortedFeatureImportances, int numTopFeatures) {
|
||||
if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) {
|
||||
return unsortedFeatureImportances;
|
||||
}
|
||||
return unsortedFeatureImportances.stream()
|
||||
.sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
|
||||
.limit(numTopFeatures)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
SingleValueInferenceResults(StreamInput in) throws IOException {
|
||||
value = in.readDouble();
|
||||
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
this.featureImportance = in.readList(FeatureImportance::new);
|
||||
} else {
|
||||
this.featureImportance = Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
SingleValueInferenceResults(double value, List<FeatureImportance> featureImportance) {
|
||||
SingleValueInferenceResults(double value) {
|
||||
this.value = value;
|
||||
this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance;
|
||||
}
|
||||
|
||||
public Double value() {
|
||||
return value;
|
||||
}
|
||||
|
||||
public List<FeatureImportance> getFeatureImportance() {
|
||||
return featureImportance;
|
||||
}
|
||||
|
||||
public String valueAsString() {
|
||||
return String.valueOf(value);
|
||||
|
@ -60,9 +37,6 @@ public abstract class SingleValueInferenceResults implements InferenceResults {
|
|||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(value);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||
out.writeList(this.featureImportance);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -7,7 +7,8 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
|
|||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
|
@ -130,17 +131,18 @@ public final class InferenceHelpers {
|
|||
return originalFeatureImportance;
|
||||
}
|
||||
|
||||
public static List<FeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
|
||||
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
|
||||
featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0])));
|
||||
public static List<RegressionFeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
|
||||
List<RegressionFeatureImportance> importances = new ArrayList<>(featureImportance.size());
|
||||
featureImportance.forEach((k, v) -> importances.add(new RegressionFeatureImportance(k, v[0])));
|
||||
return importances;
|
||||
}
|
||||
|
||||
public static List<FeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance,
|
||||
final int predictedValue,
|
||||
@Nullable List<String> classificationLabels,
|
||||
@Nullable PredictionFieldType predictionFieldType) {
|
||||
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
|
||||
public static List<ClassificationFeatureImportance> transformFeatureImportanceClassification(
|
||||
Map<String, double[]> featureImportance,
|
||||
final int predictedValue,
|
||||
@Nullable List<String> classificationLabels,
|
||||
@Nullable PredictionFieldType predictionFieldType) {
|
||||
List<ClassificationFeatureImportance> importances = new ArrayList<>(featureImportance.size());
|
||||
final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
|
||||
featureImportance.forEach((k, v) -> {
|
||||
// This indicates logistic regression (binary classification)
|
||||
|
@ -152,27 +154,26 @@ public final class InferenceHelpers {
|
|||
final int otherClass = 1 - predictedValue;
|
||||
String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
|
||||
String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
|
||||
importances.add(FeatureImportance.forBinaryClassification(k,
|
||||
v[0],
|
||||
importances.add(new ClassificationFeatureImportance(k,
|
||||
Arrays.asList(
|
||||
new FeatureImportance.ClassImportance(
|
||||
new ClassificationFeatureImportance.ClassImportance(
|
||||
fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
|
||||
v[0]),
|
||||
new FeatureImportance.ClassImportance(
|
||||
new ClassificationFeatureImportance.ClassImportance(
|
||||
fieldType.transformPredictedValue((double)otherClass, otherLabel),
|
||||
-v[0])
|
||||
)));
|
||||
} else {
|
||||
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
|
||||
List<ClassificationFeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
|
||||
// If the classificationLabels exist, their length must match leaf_value length
|
||||
assert classificationLabels == null || classificationLabels.size() == v.length;
|
||||
for (int i = 0; i < v.length; i++) {
|
||||
String label = classificationLabels == null ? null : classificationLabels.get(i);
|
||||
classImportance.add(new FeatureImportance.ClassImportance(
|
||||
classImportance.add(new ClassificationFeatureImportance.ClassImportance(
|
||||
fieldType.transformPredictedValue((double)i, label),
|
||||
v[i]));
|
||||
}
|
||||
importances.add(FeatureImportance.forClassification(k, classImportance));
|
||||
importances.add(new ClassificationFeatureImportance(k, classImportance));
|
||||
}
|
||||
});
|
||||
return importances;
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class ClassificationFeatureImportanceTests extends AbstractSerializingTestCase<ClassificationFeatureImportance> {
|
||||
|
||||
@Override
|
||||
protected ClassificationFeatureImportance doParseInstance(XContentParser parser) throws IOException {
|
||||
return ClassificationFeatureImportance.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<ClassificationFeatureImportance> instanceReader() {
|
||||
return ClassificationFeatureImportance::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClassificationFeatureImportance createTestInstance() {
|
||||
return createRandomInstance();
|
||||
}
|
||||
|
||||
public static ClassificationFeatureImportance createRandomInstance() {
|
||||
return new ClassificationFeatureImportance(
|
||||
randomAlphaOfLength(10),
|
||||
Stream.generate(() -> randomAlphaOfLength(10))
|
||||
.limit(randomLongBetween(2, 10))
|
||||
.map(name -> new ClassificationFeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
public void testGetTotalImportance_GivenBinary() {
|
||||
ClassificationFeatureImportance featureImportance = new ClassificationFeatureImportance(
|
||||
"binary",
|
||||
Arrays.asList(
|
||||
new ClassificationFeatureImportance.ClassImportance("a", 0.15),
|
||||
new ClassificationFeatureImportance.ClassImportance("not-a", -0.15)
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(featureImportance.getTotalImportance(), equalTo(0.15));
|
||||
}
|
||||
|
||||
public void testGetTotalImportance_GivenMulticlass() {
|
||||
ClassificationFeatureImportance featureImportance = new ClassificationFeatureImportance(
|
||||
"multiclass",
|
||||
Arrays.asList(
|
||||
new ClassificationFeatureImportance.ClassImportance("a", 0.15),
|
||||
new ClassificationFeatureImportance.ClassImportance("b", -0.05),
|
||||
new ClassificationFeatureImportance.ClassImportance("c", 0.30)
|
||||
)
|
||||
);
|
||||
|
||||
assertThat(featureImportance.getTotalImportance(), closeTo(0.50, 0.00000001));
|
||||
}
|
||||
}
|
|
@ -18,7 +18,6 @@ import java.util.Collections;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
@ -29,10 +28,6 @@ import static org.hamcrest.Matchers.hasSize;
|
|||
public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase<ClassificationInferenceResults> {
|
||||
|
||||
public static ClassificationInferenceResults createRandomResults() {
|
||||
Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
|
||||
FeatureImportanceTests::randomClassification :
|
||||
FeatureImportanceTests::randomRegression;
|
||||
|
||||
ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig();
|
||||
Double value = randomDouble();
|
||||
if (config.getPredictionFieldType() == PredictionFieldType.BOOLEAN) {
|
||||
|
@ -47,7 +42,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
.limit(randomIntBetween(0, 10))
|
||||
.collect(Collectors.toList()),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(featureImportanceCtor)
|
||||
Stream.generate(ClassificationFeatureImportanceTests::createRandomInstance)
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList()),
|
||||
config,
|
||||
|
@ -123,11 +118,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
}
|
||||
|
||||
public void testWriteResultsWithImportance() {
|
||||
Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
|
||||
FeatureImportanceTests::randomClassification :
|
||||
FeatureImportanceTests::randomRegression;
|
||||
|
||||
List<FeatureImportance> importanceList = Stream.generate(featureImportanceCtor)
|
||||
List<ClassificationFeatureImportance> importanceList = Stream.generate(ClassificationFeatureImportanceTests::createRandomInstance)
|
||||
.limit(5)
|
||||
.collect(Collectors.toList());
|
||||
ClassificationInferenceResults result = new ClassificationInferenceResults(0.0,
|
||||
|
@ -146,18 +137,17 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
"result_field.feature_importance",
|
||||
List.class);
|
||||
assertThat(writtenImportance, hasSize(3));
|
||||
importanceList.sort((l, r) -> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
|
||||
importanceList.sort((l, r) -> Double.compare(Math.abs(r.getTotalImportance()), Math.abs(l.getTotalImportance())));
|
||||
for (int i = 0; i < 3; i++) {
|
||||
Map<String, Object> objectMap = writtenImportance.get(i);
|
||||
FeatureImportance importance = importanceList.get(i);
|
||||
ClassificationFeatureImportance importance = importanceList.get(i);
|
||||
assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
|
||||
assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Map<String, Object>> classImportances = (List<Map<String, Object>>)objectMap.get("classes");
|
||||
if (importance.getClassImportance() != null) {
|
||||
for (int j = 0; j < importance.getClassImportance().size(); j++) {
|
||||
Map<String, Object> classMap = classImportances.get(j);
|
||||
FeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j);
|
||||
ClassificationFeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j);
|
||||
assertThat(classMap.get("class_name"), equalTo(classImportance.getClassName()));
|
||||
assertThat(classMap.get("importance"), equalTo(classImportance.getImportance()));
|
||||
}
|
||||
|
@ -212,7 +202,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList());
|
||||
ClassificationFeatureImportance fi = new ClassificationFeatureImportance("foo", Collections.emptyList());
|
||||
TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
|
||||
result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
|
||||
Collections.singletonList(fi), config,
|
||||
|
|
|
@ -1,49 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class FeatureImportanceTests extends AbstractSerializingTestCase<FeatureImportance> {
|
||||
|
||||
public static FeatureImportance createRandomInstance() {
|
||||
return randomBoolean() ? randomClassification() : randomRegression();
|
||||
}
|
||||
|
||||
static FeatureImportance randomRegression() {
|
||||
return FeatureImportance.forRegression(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false));
|
||||
}
|
||||
|
||||
static FeatureImportance randomClassification() {
|
||||
return FeatureImportance.forClassification(
|
||||
randomAlphaOfLength(10),
|
||||
Stream.generate(() -> randomAlphaOfLength(10))
|
||||
.limit(randomLongBetween(2, 10))
|
||||
.map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FeatureImportance createTestInstance() {
|
||||
return createRandomInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<FeatureImportance> instanceReader() {
|
||||
return FeatureImportance::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FeatureImportance doParseInstance(XContentParser parser) throws IOException {
|
||||
return FeatureImportance.fromXContent(parser);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class LegacyFeatureImportanceTests extends AbstractWireSerializingTestCase<LegacyFeatureImportance> {
|
||||
|
||||
public static LegacyFeatureImportance createRandomInstance() {
|
||||
return createRandomInstance(randomBoolean());
|
||||
}
|
||||
|
||||
public static LegacyFeatureImportance createRandomInstance(boolean hasClasses) {
|
||||
double importance = randomDouble();
|
||||
List<LegacyFeatureImportance.ClassImportance> classImportances = null;
|
||||
if (hasClasses) {
|
||||
classImportances = Stream.generate(() -> randomAlphaOfLength(10))
|
||||
.limit(randomLongBetween(2, 10))
|
||||
.map(featureName -> new LegacyFeatureImportance.ClassImportance(featureName, randomDouble()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
importance = classImportances.stream().mapToDouble(LegacyFeatureImportance.ClassImportance::getImportance).map(Math::abs).sum();
|
||||
}
|
||||
return new LegacyFeatureImportance(randomAlphaOfLength(10), importance, classImportances);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected LegacyFeatureImportance createTestInstance() {
|
||||
return createRandomInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<LegacyFeatureImportance> instanceReader() {
|
||||
return LegacyFeatureImportance::new;
|
||||
}
|
||||
|
||||
public void testClassificationConversion() {
|
||||
{
|
||||
ClassificationFeatureImportance classificationFeatureImportance = ClassificationFeatureImportanceTests.createRandomInstance();
|
||||
LegacyFeatureImportance legacyFeatureImportance = LegacyFeatureImportance.fromClassification(classificationFeatureImportance);
|
||||
ClassificationFeatureImportance convertedFeatureImportance = legacyFeatureImportance.forClassification();
|
||||
assertThat(convertedFeatureImportance, equalTo(classificationFeatureImportance));
|
||||
}
|
||||
{
|
||||
LegacyFeatureImportance legacyFeatureImportance = createRandomInstance(true);
|
||||
ClassificationFeatureImportance classificationFeatureImportance = legacyFeatureImportance.forClassification();
|
||||
LegacyFeatureImportance convertedFeatureImportance = LegacyFeatureImportance.fromClassification(
|
||||
classificationFeatureImportance);
|
||||
assertThat(convertedFeatureImportance, equalTo(legacyFeatureImportance));
|
||||
}
|
||||
}
|
||||
|
||||
public void testRegressionConversion() {
|
||||
{
|
||||
RegressionFeatureImportance regressionFeatureImportance = RegressionFeatureImportanceTests.createRandomInstance();
|
||||
LegacyFeatureImportance legacyFeatureImportance = LegacyFeatureImportance.fromRegression(regressionFeatureImportance);
|
||||
RegressionFeatureImportance convertedFeatureImportance = legacyFeatureImportance.forRegression();
|
||||
assertThat(convertedFeatureImportance, equalTo(regressionFeatureImportance));
|
||||
}
|
||||
{
|
||||
LegacyFeatureImportance legacyFeatureImportance = createRandomInstance(false);
|
||||
RegressionFeatureImportance regressionFeatureImportance = legacyFeatureImportance.forRegression();
|
||||
LegacyFeatureImportance convertedFeatureImportance = LegacyFeatureImportance.fromRegression(regressionFeatureImportance);
|
||||
assertThat(convertedFeatureImportance, equalTo(legacyFeatureImportance));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class RegressionFeatureImportanceTests extends AbstractSerializingTestCase<RegressionFeatureImportance> {
|
||||
|
||||
@Override
|
||||
protected RegressionFeatureImportance doParseInstance(XContentParser parser) throws IOException {
|
||||
return RegressionFeatureImportance.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<RegressionFeatureImportance> instanceReader() {
|
||||
return RegressionFeatureImportance::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RegressionFeatureImportance createTestInstance() {
|
||||
return createRandomInstance();
|
||||
}
|
||||
|
||||
public static RegressionFeatureImportance createRandomInstance() {
|
||||
return new RegressionFeatureImportance(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false));
|
||||
}
|
||||
}
|
|
@ -29,8 +29,8 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
|
|||
public static RegressionInferenceResults createRandomResults() {
|
||||
return new RegressionInferenceResults(randomDouble(),
|
||||
RegressionConfigTests.randomRegressionConfig(),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(FeatureImportanceTests::randomRegression)
|
||||
randomBoolean() ? Collections.emptyList() :
|
||||
Stream.generate(RegressionFeatureImportanceTests::createRandomInstance)
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
|
|||
}
|
||||
|
||||
public void testWriteResultsWithImportance() {
|
||||
List<FeatureImportance> importanceList = Stream.generate(FeatureImportanceTests::randomRegression)
|
||||
List<RegressionFeatureImportance> importanceList = Stream.generate(RegressionFeatureImportanceTests::createRandomInstance)
|
||||
.limit(5)
|
||||
.collect(Collectors.toList());
|
||||
RegressionInferenceResults result = new RegressionInferenceResults(0.3,
|
||||
|
@ -68,7 +68,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
|
|||
importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
|
||||
for (int i = 0; i < 3; i++) {
|
||||
Map<String, Object> objectMap = writtenImportance.get(i);
|
||||
FeatureImportance importance = importanceList.get(i);
|
||||
RegressionFeatureImportance importance = importanceList.get(i);
|
||||
assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
|
||||
assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
|
||||
assertThat(objectMap.size(), equalTo(2));
|
||||
|
@ -92,7 +92,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
|
|||
String expected = "{\"" + resultsField + "\":1.0}";
|
||||
assertEquals(expected, stringRep);
|
||||
|
||||
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList());
|
||||
RegressionFeatureImportance fi = new RegressionFeatureImportance("foo", 1.0);
|
||||
result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi));
|
||||
stringRep = Strings.toString(result);
|
||||
expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";
|
||||
|
|
|
@ -16,8 +16,8 @@ import org.elasticsearch.common.xcontent.XContentType;
|
|||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -134,9 +134,9 @@ public class InferenceDefinitionTests extends ESTestCase {
|
|||
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
|
||||
assertThat(results.valueAsString(), equalTo("second"));
|
||||
assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
|
||||
assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
|
||||
assertThat(results.getFeatureImportance().get(0).getTotalImportance(), closeTo(0.944, 0.001));
|
||||
assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1"));
|
||||
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
|
||||
assertThat(results.getFeatureImportance().get(1).getTotalImportance(), closeTo(0.199, 0.001));
|
||||
}
|
||||
|
||||
public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws IOException {
|
||||
|
@ -155,20 +155,20 @@ public class InferenceDefinitionTests extends ESTestCase {
|
|||
|
||||
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
|
||||
assertThat(results.valueAsString(), equalTo("second"));
|
||||
FeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
|
||||
ClassificationFeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
|
||||
assertThat(featureImportance1.getFeatureName(), equalTo("col2"));
|
||||
assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001));
|
||||
for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
|
||||
assertThat(featureImportance1.getTotalImportance(), closeTo(0.944, 0.001));
|
||||
for (ClassificationFeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
|
||||
if (classImportance.getClassName().equals("second")) {
|
||||
assertThat(classImportance.getImportance(), closeTo(0.944, 0.001));
|
||||
} else {
|
||||
assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001));
|
||||
}
|
||||
}
|
||||
FeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
|
||||
ClassificationFeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
|
||||
assertThat(featureImportance2.getFeatureName(), equalTo("col1_male"));
|
||||
assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001));
|
||||
for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
|
||||
assertThat(featureImportance2.getTotalImportance(), closeTo(0.199, 0.001));
|
||||
for (ClassificationFeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
|
||||
if (classImportance.getClassName().equals("second")) {
|
||||
assertThat(classImportance.getImportance(), closeTo(0.199, 0.001));
|
||||
} else {
|
||||
|
|
|
@ -16,10 +16,11 @@ import org.elasticsearch.search.aggregations.InvalidAggregationPathException;
|
|||
import org.elasticsearch.search.aggregations.ParsedAggregation;
|
||||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
|
@ -115,7 +116,7 @@ public class InternalInferenceAggregationTests extends InternalAggregationTestCa
|
|||
} else if (result instanceof RegressionInferenceResults) {
|
||||
RegressionInferenceResults regression = (RegressionInferenceResults) result;
|
||||
assertEquals(regression.value(), parsed.getValue());
|
||||
List<FeatureImportance> featureImportance = regression.getFeatureImportance();
|
||||
List<RegressionFeatureImportance> featureImportance = regression.getFeatureImportance();
|
||||
if (featureImportance.isEmpty()) {
|
||||
featureImportance = null;
|
||||
}
|
||||
|
@ -124,7 +125,7 @@ public class InternalInferenceAggregationTests extends InternalAggregationTestCa
|
|||
ClassificationInferenceResults classification = (ClassificationInferenceResults) result;
|
||||
assertEquals(classification.predictedValue(), parsed.getValue());
|
||||
|
||||
List<FeatureImportance> featureImportance = classification.getFeatureImportance();
|
||||
List<ClassificationFeatureImportance> featureImportance = classification.getFeatureImportance();
|
||||
if (featureImportance.isEmpty()) {
|
||||
featureImportance = null;
|
||||
}
|
||||
|
|
|
@ -13,7 +13,6 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
|
|||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.ParsedAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
|
@ -21,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_PROBABILITY;
|
||||
|
@ -45,7 +45,7 @@ public class ParsedInference extends ParsedAggregation {
|
|||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
|
||||
new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
|
||||
args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
|
||||
args -> new ParsedInference(args[0], (List<Map<String, Object>>) args[1],
|
||||
(List<TopClassEntry>) args[2], (String) args[3], (Double) args[4], (Double) args[5]));
|
||||
|
||||
static {
|
||||
|
@ -65,7 +65,7 @@ public class ParsedInference extends ParsedAggregation {
|
|||
}
|
||||
return o;
|
||||
}, CommonFields.VALUE, ObjectParser.ValueType.VALUE);
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p),
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> p.map(),
|
||||
new ParseField(SingleValueInferenceResults.FEATURE_IMPORTANCE));
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p),
|
||||
new ParseField(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD));
|
||||
|
@ -82,14 +82,14 @@ public class ParsedInference extends ParsedAggregation {
|
|||
}
|
||||
|
||||
private final Object value;
|
||||
private final List<FeatureImportance> featureImportance;
|
||||
private final List<Map<String, Object>> featureImportance;
|
||||
private final List<TopClassEntry> topClasses;
|
||||
private final String warning;
|
||||
private final Double predictionProbability;
|
||||
private final Double predictionScore;
|
||||
|
||||
ParsedInference(Object value,
|
||||
List<FeatureImportance> featureImportance,
|
||||
List<Map<String, Object>> featureImportance,
|
||||
List<TopClassEntry> topClasses,
|
||||
String warning,
|
||||
Double predictionProbability,
|
||||
|
@ -106,7 +106,7 @@ public class ParsedInference extends ParsedAggregation {
|
|||
return value;
|
||||
}
|
||||
|
||||
public List<FeatureImportance> getFeatureImportance() {
|
||||
public List<Map<String, Object>> getFeatureImportance() {
|
||||
return featureImportance;
|
||||
}
|
||||
|
||||
|
|
|
@ -9,8 +9,9 @@ import org.elasticsearch.client.Client;
|
|||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
|
@ -136,9 +137,11 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
classes.add(new TopClassEntry("foo", 0.6, 0.6));
|
||||
classes.add(new TopClassEntry("bar", 0.4, 0.4));
|
||||
|
||||
List<FeatureImportance> featureInfluence = new ArrayList<>();
|
||||
featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
|
||||
featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0));
|
||||
List<ClassificationFeatureImportance> featureInfluence = new ArrayList<>();
|
||||
featureInfluence.add(new ClassificationFeatureImportance("feature_1",
|
||||
Collections.singletonList(new ClassificationFeatureImportance.ClassImportance("class_a", 1.13))));
|
||||
featureInfluence.add(new ClassificationFeatureImportance("feature_2",
|
||||
Collections.singletonList(new ClassificationFeatureImportance.ClassImportance("class_b", -42.0))));
|
||||
|
||||
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
|
||||
Collections.singletonList(new ClassificationInferenceResults(1.0,
|
||||
|
@ -153,10 +156,12 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
|
||||
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
|
||||
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
|
||||
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.importance", Double.class), equalTo(-42.0));
|
||||
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2"));
|
||||
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.importance", Double.class), equalTo(1.13));
|
||||
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.classes.0.class_name", String.class), equalTo("class_b"));
|
||||
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.classes.0.importance", Double.class), equalTo(-42.0));
|
||||
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1"));
|
||||
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.classes.0.class_name", String.class), equalTo("class_a"));
|
||||
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.classes.0.importance", Double.class), equalTo(1.13));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
@ -234,9 +239,9 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
Map<String, Object> ingestMetadata = new HashMap<>();
|
||||
IngestDocument document = new IngestDocument(source, ingestMetadata);
|
||||
|
||||
List<FeatureImportance> featureInfluence = new ArrayList<>();
|
||||
featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
|
||||
featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0));
|
||||
List<RegressionFeatureImportance> featureInfluence = new ArrayList<>();
|
||||
featureInfluence.add(new RegressionFeatureImportance("feature_1", 1.13));
|
||||
featureInfluence.add(new RegressionFeatureImportance("feature_2", -42.0));
|
||||
|
||||
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
|
||||
Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);
|
||||
|
|
Loading…
Reference in New Issue