This updates the feature_importance mapping change from elastic/ml-cpp#1387
This commit is contained in:
parent
f2f1552e2c
commit
7c3bfb9437
|
@ -314,7 +314,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
||||||
Map<String, Object> additionalProperties = new HashMap<>();
|
Map<String, Object> additionalProperties = new HashMap<>();
|
||||||
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
|
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.classificationFeatureImportanceMapping());
|
||||||
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
|
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
|
||||||
if ((dependentVariableMapping instanceof Map) == false) {
|
if ((dependentVariableMapping instanceof Map) == false) {
|
||||||
return additionalProperties;
|
return additionalProperties;
|
||||||
|
|
|
@ -18,22 +18,46 @@ import java.util.Map;
|
||||||
|
|
||||||
final class MapUtils {
|
final class MapUtils {
|
||||||
|
|
||||||
private static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING;
|
private static Map<String, Object> createFeatureImportanceMapping(Map<String, Object> featureImportanceMappingProperties){
|
||||||
static {
|
|
||||||
Map<String, Object> featureImportanceMappingProperties = new HashMap<>();
|
|
||||||
featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
|
featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
|
||||||
featureImportanceMappingProperties.put("importance",
|
|
||||||
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
|
|
||||||
Map<String, Object> featureImportanceMapping = new HashMap<>();
|
Map<String, Object> featureImportanceMapping = new HashMap<>();
|
||||||
// TODO sorted indices don't support nested types
|
// TODO sorted indices don't support nested types
|
||||||
//featureImportanceMapping.put("dynamic", true);
|
//featureImportanceMapping.put("dynamic", true);
|
||||||
//featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
|
//featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
|
||||||
featureImportanceMapping.put("properties", featureImportanceMappingProperties);
|
featureImportanceMapping.put("properties", featureImportanceMappingProperties);
|
||||||
FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(featureImportanceMapping);
|
return featureImportanceMapping;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Map<String, Object> featureImportanceMapping() {
|
private static final Map<String, Object> CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING;
|
||||||
return FEATURE_IMPORTANCE_MAPPING;
|
static {
|
||||||
|
Map<String, Object> classImportancePropertiesMapping = new HashMap<>();
|
||||||
|
// TODO sorted indices don't support nested types
|
||||||
|
//classImportancePropertiesMapping.put("dynamic", true);
|
||||||
|
//classImportancePropertiesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
|
||||||
|
classImportancePropertiesMapping.put("class_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
|
||||||
|
classImportancePropertiesMapping.put("importance",
|
||||||
|
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
|
||||||
|
Map<String, Object> featureImportancePropertiesMapping = new HashMap<>();
|
||||||
|
featureImportancePropertiesMapping.put("classes", Collections.singletonMap("properties", classImportancePropertiesMapping));
|
||||||
|
CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING =
|
||||||
|
Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final Map<String, Object> REGRESSION_FEATURE_IMPORTANCE_MAPPING;
|
||||||
|
static {
|
||||||
|
Map<String, Object> featureImportancePropertiesMapping = new HashMap<>();
|
||||||
|
featureImportancePropertiesMapping.put("importance",
|
||||||
|
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
|
||||||
|
REGRESSION_FEATURE_IMPORTANCE_MAPPING =
|
||||||
|
Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Map<String, Object> regressionFeatureImportanceMapping() {
|
||||||
|
return REGRESSION_FEATURE_IMPORTANCE_MAPPING;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Map<String, Object> classificationFeatureImportanceMapping() {
|
||||||
|
return CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING;
|
||||||
}
|
}
|
||||||
|
|
||||||
private MapUtils() {}
|
private MapUtils() {}
|
||||||
|
|
|
@ -247,7 +247,7 @@ public class Regression implements DataFrameAnalysis {
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
||||||
Map<String, Object> additionalProperties = new HashMap<>();
|
Map<String, Object> additionalProperties = new HashMap<>();
|
||||||
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
|
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.regressionFeatureImportanceMapping());
|
||||||
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
|
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
|
||||||
// high (over 10M) values of dependent variable.
|
// high (over 10M) values of dependent variable.
|
||||||
additionalProperties.put(resultsFieldName + "." + predictionFieldName,
|
additionalProperties.put(resultsFieldName + "." + predictionFieldName,
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||||
|
|
||||||
|
import org.elasticsearch.Version;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
@ -16,65 +17,74 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||||
|
|
||||||
public class FeatureImportance implements Writeable, ToXContentObject {
|
public class FeatureImportance implements Writeable, ToXContentObject {
|
||||||
|
|
||||||
private final Map<String, Double> classImportance;
|
private final List<ClassImportance> classImportance;
|
||||||
private final double importance;
|
private final double importance;
|
||||||
private final String featureName;
|
private final String featureName;
|
||||||
static final String IMPORTANCE = "importance";
|
static final String IMPORTANCE = "importance";
|
||||||
static final String FEATURE_NAME = "feature_name";
|
static final String FEATURE_NAME = "feature_name";
|
||||||
static final String CLASS_IMPORTANCE = "class_importance";
|
static final String CLASSES = "classes";
|
||||||
|
|
||||||
public static FeatureImportance forRegression(String featureName, double importance) {
|
public static FeatureImportance forRegression(String featureName, double importance) {
|
||||||
return new FeatureImportance(featureName, importance, null);
|
return new FeatureImportance(featureName, importance, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
|
public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
|
||||||
return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
|
return new FeatureImportance(featureName,
|
||||||
|
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
|
||||||
|
classImportance);
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
|
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
|
||||||
new ConstructingObjectParser<>("feature_importance",
|
new ConstructingObjectParser<>("feature_importance",
|
||||||
a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
|
a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2])
|
||||||
);
|
);
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
|
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
|
||||||
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||||
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
|
PARSER.declareObjectArray(optionalConstructorArg(),
|
||||||
new ParseField(FeatureImportance.CLASS_IMPORTANCE));
|
(p, c) -> ClassImportance.fromXContent(p),
|
||||||
|
new ParseField(FeatureImportance.CLASSES));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static FeatureImportance fromXContent(XContentParser parser) {
|
public static FeatureImportance fromXContent(XContentParser parser) {
|
||||||
return PARSER.apply(parser, null);
|
return PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
|
FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
|
||||||
this.featureName = Objects.requireNonNull(featureName);
|
this.featureName = Objects.requireNonNull(featureName);
|
||||||
this.importance = importance;
|
this.importance = importance;
|
||||||
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
|
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
|
||||||
}
|
}
|
||||||
|
|
||||||
public FeatureImportance(StreamInput in) throws IOException {
|
public FeatureImportance(StreamInput in) throws IOException {
|
||||||
this.featureName = in.readString();
|
this.featureName = in.readString();
|
||||||
this.importance = in.readDouble();
|
this.importance = in.readDouble();
|
||||||
if (in.readBoolean()) {
|
if (in.readBoolean()) {
|
||||||
this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
|
if (in.getVersion().before(Version.V_7_10_0)) {
|
||||||
|
Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
|
||||||
|
this.classImportance = ClassImportance.fromMap(classImportance);
|
||||||
|
} else {
|
||||||
|
this.classImportance = in.readList(ClassImportance::new);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
this.classImportance = null;
|
this.classImportance = null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<String, Double> getClassImportance() {
|
public List<ClassImportance> getClassImportance() {
|
||||||
return classImportance;
|
return classImportance;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,7 +102,11 @@ public class FeatureImportance implements Writeable, ToXContentObject {
|
||||||
out.writeDouble(this.importance);
|
out.writeDouble(this.importance);
|
||||||
out.writeBoolean(this.classImportance != null);
|
out.writeBoolean(this.classImportance != null);
|
||||||
if (this.classImportance != null) {
|
if (this.classImportance != null) {
|
||||||
out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
|
if (out.getVersion().before(Version.V_7_10_0)) {
|
||||||
|
out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
|
||||||
|
} else {
|
||||||
|
out.writeList(this.classImportance);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +115,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
|
||||||
map.put(FEATURE_NAME, featureName);
|
map.put(FEATURE_NAME, featureName);
|
||||||
map.put(IMPORTANCE, importance);
|
map.put(IMPORTANCE, importance);
|
||||||
if (classImportance != null) {
|
if (classImportance != null) {
|
||||||
classImportance.forEach(map::put);
|
map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList()));
|
||||||
}
|
}
|
||||||
return map;
|
return map;
|
||||||
}
|
}
|
||||||
|
@ -112,11 +126,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
|
||||||
builder.field(FEATURE_NAME, featureName);
|
builder.field(FEATURE_NAME, featureName);
|
||||||
builder.field(IMPORTANCE, importance);
|
builder.field(IMPORTANCE, importance);
|
||||||
if (classImportance != null && classImportance.isEmpty() == false) {
|
if (classImportance != null && classImportance.isEmpty() == false) {
|
||||||
builder.startObject(CLASS_IMPORTANCE);
|
builder.field(CLASSES, classImportance);
|
||||||
for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
|
|
||||||
builder.field(entry.getKey(), entry.getValue());
|
|
||||||
}
|
|
||||||
builder.endObject();
|
|
||||||
}
|
}
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
|
@ -136,4 +146,92 @@ public class FeatureImportance implements Writeable, ToXContentObject {
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(featureName, importance, classImportance);
|
return Objects.hash(featureName, importance, classImportance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static class ClassImportance implements Writeable, ToXContentObject {
|
||||||
|
|
||||||
|
static final String CLASS_NAME = "class_name";
|
||||||
|
|
||||||
|
private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
|
||||||
|
new ConstructingObjectParser<>("feature_importance_class_importance",
|
||||||
|
a -> new ClassImportance((String) a[0], (Double) a[1])
|
||||||
|
);
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
|
||||||
|
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
|
||||||
|
return new ClassImportance(entry.getKey(), entry.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
|
||||||
|
return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Map<String, Double> toMap(List<ClassImportance> importances) {
|
||||||
|
return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ClassImportance fromXContent(XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final String className;
|
||||||
|
private final double importance;
|
||||||
|
|
||||||
|
public ClassImportance(String className, double importance) {
|
||||||
|
this.className = className;
|
||||||
|
this.importance = importance;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ClassImportance(StreamInput in) throws IOException {
|
||||||
|
this.className = in.readString();
|
||||||
|
this.importance = in.readDouble();
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getClassName() {
|
||||||
|
return className;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getImportance() {
|
||||||
|
return importance;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, Object> toMap() {
|
||||||
|
Map<String, Object> map = new LinkedHashMap<>();
|
||||||
|
map.put(CLASS_NAME, className);
|
||||||
|
map.put(IMPORTANCE, importance);
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
out.writeString(className);
|
||||||
|
out.writeDouble(importance);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(CLASS_NAME, className);
|
||||||
|
builder.field(IMPORTANCE, importance);
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
ClassImportance that = (ClassImportance) o;
|
||||||
|
return Double.compare(that.importance, importance) == 0 &&
|
||||||
|
Objects.equals(className, that.className);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(className, importance);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@ import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
@ -139,11 +138,13 @@ public final class InferenceHelpers {
|
||||||
if (v.length == 1) {
|
if (v.length == 1) {
|
||||||
importances.add(FeatureImportance.forRegression(k, v[0]));
|
importances.add(FeatureImportance.forRegression(k, v[0]));
|
||||||
} else {
|
} else {
|
||||||
Map<String, Double> classImportance = new LinkedHashMap<>(v.length, 1.0f);
|
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
|
||||||
// If the classificationLabels exist, their length must match leaf_value length
|
// If the classificationLabels exist, their length must match leaf_value length
|
||||||
assert classificationLabels == null || classificationLabels.size() == v.length;
|
assert classificationLabels == null || classificationLabels.size() == v.length;
|
||||||
for (int i = 0; i < v.length; i++) {
|
for (int i = 0; i < v.length; i++) {
|
||||||
classImportance.put(classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), v[i]);
|
classImportance.add(new FeatureImportance.ClassImportance(
|
||||||
|
classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i),
|
||||||
|
v[i]));
|
||||||
}
|
}
|
||||||
importances.add(FeatureImportance.forClassification(k, classImportance));
|
importances.add(FeatureImportance.forClassification(k, classImportance));
|
||||||
}
|
}
|
||||||
|
|
|
@ -261,12 +261,12 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
|
|
||||||
public void testGetExplicitlyMappedFields() {
|
public void testGetExplicitlyMappedFields() {
|
||||||
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"),
|
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"),
|
||||||
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
|
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
|
||||||
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"),
|
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"),
|
||||||
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
|
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
|
||||||
assertThat(
|
assertThat(
|
||||||
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
|
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
|
||||||
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
|
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
|
||||||
Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
|
Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
|
||||||
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
|
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
|
||||||
"results");
|
"results");
|
||||||
|
@ -274,7 +274,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
allOf(
|
allOf(
|
||||||
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
|
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
|
||||||
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
|
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
|
||||||
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
|
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()));
|
||||||
|
|
||||||
explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
|
explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
|
||||||
new HashMap<String, Object>() {{
|
new HashMap<String, Object>() {{
|
||||||
|
@ -289,7 +289,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
allOf(
|
allOf(
|
||||||
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
|
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
|
||||||
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
|
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
|
||||||
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
|
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()));
|
||||||
|
|
||||||
assertThat(
|
assertThat(
|
||||||
new Classification("foo").getExplicitlyMappedFields(
|
new Classification("foo").getExplicitlyMappedFields(
|
||||||
|
@ -298,7 +298,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
put("path", "missing");
|
put("path", "missing");
|
||||||
}}),
|
}}),
|
||||||
"results"),
|
"results"),
|
||||||
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
|
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
|
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
|
||||||
|
|
|
@ -206,7 +206,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
||||||
public void testGetExplicitlyMappedFields() {
|
public void testGetExplicitlyMappedFields() {
|
||||||
Map<String, Object> explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results");
|
Map<String, Object> explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results");
|
||||||
assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
|
assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
|
||||||
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
|
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.regressionFeatureImportanceMapping()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetStateDocId() {
|
public void testGetStateDocId() {
|
||||||
|
|
|
@ -152,8 +152,15 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
||||||
FeatureImportance importance = importanceList.get(i);
|
FeatureImportance importance = importanceList.get(i);
|
||||||
assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
|
assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
|
||||||
assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
|
assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
List<Map<String, Object>> classImportances = (List<Map<String, Object>>)objectMap.get("classes");
|
||||||
if (importance.getClassImportance() != null) {
|
if (importance.getClassImportance() != null) {
|
||||||
importance.getClassImportance().forEach((k, v) -> assertThat(objectMap.get(k), equalTo(v)));
|
for (int j = 0; j < importance.getClassImportance().size(); j++) {
|
||||||
|
Map<String, Object> classMap = classImportances.get(j);
|
||||||
|
FeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j);
|
||||||
|
assertThat(classMap.get("class_name"), equalTo(classImportance.getClassName()));
|
||||||
|
assertThat(classMap.get("importance"), equalTo(classImportance.getImportance()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -205,7 +212,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
||||||
expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
|
expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
|
||||||
assertEquals(expected, stringRep);
|
assertEquals(expected, stringRep);
|
||||||
|
|
||||||
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
|
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList());
|
||||||
TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
|
TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
|
||||||
result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
|
result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
|
||||||
Collections.singletonList(fi), config,
|
Collections.singletonList(fi), config,
|
||||||
|
|
|
@ -10,7 +10,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.function.Function;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
@ -29,7 +28,8 @@ public class FeatureImportanceTests extends AbstractSerializingTestCase<FeatureI
|
||||||
randomAlphaOfLength(10),
|
randomAlphaOfLength(10),
|
||||||
Stream.generate(() -> randomAlphaOfLength(10))
|
Stream.generate(() -> randomAlphaOfLength(10))
|
||||||
.limit(randomLongBetween(2, 10))
|
.limit(randomLongBetween(2, 10))
|
||||||
.collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
|
.map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
|
||||||
|
.collect(Collectors.toList()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -92,7 +92,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
|
||||||
String expected = "{\"" + resultsField + "\":1.0}";
|
String expected = "{\"" + resultsField + "\":1.0}";
|
||||||
assertEquals(expected, stringRep);
|
assertEquals(expected, stringRep);
|
||||||
|
|
||||||
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
|
FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList());
|
||||||
result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi));
|
result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi));
|
||||||
stringRep = Strings.toString(result);
|
stringRep = Strings.toString(result);
|
||||||
expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";
|
expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";
|
||||||
|
|
Loading…
Reference in New Issue