[ML] adjusts feature importance format for hlrc (#61150) (#61153)

related to PR https://github.com/elastic/elasticsearch/pull/61104
This commit is contained in:
Benjamin Trent 2020-08-14 11:33:41 -04:00 committed by GitHub
parent 65d0c7bbee
commit 038cc26ac5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 22 deletions

View File

@ -27,8 +27,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
@ -38,36 +37,37 @@ public class FeatureImportance implements ToXContentObject {
public static final String IMPORTANCE = "importance"; public static final String IMPORTANCE = "importance";
public static final String FEATURE_NAME = "feature_name"; public static final String FEATURE_NAME = "feature_name";
public static final String CLASS_IMPORTANCE = "class_importance"; public static final String CLASSES = "classes";
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER = private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
new ConstructingObjectParser<>("feature_importance", true, new ConstructingObjectParser<>("feature_importance", true,
a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2]) a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2])
); );
static { static {
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME)); PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue), PARSER.declareObjectArray(optionalConstructorArg(),
new ParseField(FeatureImportance.CLASS_IMPORTANCE)); (p, c) -> ClassImportance.fromXContent(p),
new ParseField(FeatureImportance.CLASSES));
} }
public static FeatureImportance fromXContent(XContentParser parser) { public static FeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
private final Map<String, Double> classImportance; private final List<ClassImportance> classImportance;
private final double importance; private final double importance;
private final String featureName; private final String featureName;
public FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) { public FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
this.featureName = Objects.requireNonNull(featureName); this.featureName = Objects.requireNonNull(featureName);
this.importance = importance; this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance); this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
} }
public Map<String, Double> getClassImportance() { public List<ClassImportance> getClassImportance() {
return classImportance; return classImportance;
} }
@ -85,11 +85,7 @@ public class FeatureImportance implements ToXContentObject {
builder.field(FEATURE_NAME, featureName); builder.field(FEATURE_NAME, featureName);
builder.field(IMPORTANCE, importance); builder.field(IMPORTANCE, importance);
if (classImportance != null && classImportance.isEmpty() == false) { if (classImportance != null && classImportance.isEmpty() == false) {
builder.startObject(CLASS_IMPORTANCE); builder.field(CLASSES, classImportance);
for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
} }
builder.endObject(); builder.endObject();
return builder; return builder;
@ -109,4 +105,63 @@ public class FeatureImportance implements ToXContentObject {
public int hashCode() { public int hashCode() {
return Objects.hash(featureName, importance, classImportance); return Objects.hash(featureName, importance, classImportance);
} }
/**
 * One entry of the {@code classes} array of a {@link FeatureImportance}: a class name paired
 * with an importance value. Instances are immutable.
 */
public static class ClassImportance implements ToXContentObject {

    static final String CLASS_NAME = "class_name";

    // Second constructor argument is 'true' so that unknown fields are ignored while parsing.
    private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
        new ConstructingObjectParser<>("feature_importance_class_importance",
            true,
            a -> new ClassImportance((String) a[0], (Double) a[1])
        );

    static {
        // Both fields are required constructor arguments.
        PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
        PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
    }

    /**
     * Parses a {@link ClassImportance} from the object the parser is positioned on.
     *
     * @param parser the content parser to read from
     * @return the parsed instance
     */
    public static ClassImportance fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private final String className;
    private final double importance;

    /**
     * @param className  name of the class; must not be {@code null}
     * @param importance importance value reported for that class
     * @throws NullPointerException if {@code className} is {@code null}
     */
    public ClassImportance(String className, double importance) {
        // Fail fast on null, mirroring how the enclosing FeatureImportance treats featureName.
        this.className = Objects.requireNonNull(className);
        this.importance = importance;
    }

    /**
     * @return the class name, never {@code null}
     */
    public String getClassName() {
        return className;
    }

    /**
     * @return the importance value for this class
     */
    public double getImportance() {
        return importance;
    }

    /**
     * Writes this entry as an object with the {@code class_name} and {@code importance} fields.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(CLASS_NAME, className);
        builder.field(FeatureImportance.IMPORTANCE, importance);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        ClassImportance that = (ClassImportance) o;
        // Double.compare rather than == so that NaN equals NaN and 0.0 differs from -0.0,
        // keeping equals consistent with the boxed hashing done in hashCode().
        return Double.compare(that.importance, importance) == 0
            && Objects.equals(className, that.className);
    }

    @Override
    public int hashCode() {
        return Objects.hash(className, importance);
    }
}
} }

View File

@ -23,8 +23,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException; import java.io.IOException;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -38,7 +36,8 @@ public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImpo
randomBoolean() ? null : randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)) Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(2, 10)) .limit(randomLongBetween(2, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false)))); .map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
.collect(Collectors.toList()));
} }
@ -52,8 +51,4 @@ public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImpo
return true; return true;
} }
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> field.equals(FeatureImportance.CLASS_IMPORTANCE);
}
} }