related to PR https://github.com/elastic/elasticsearch/pull/61104
This commit is contained in:
parent
65d0c7bbee
commit
038cc26ac5
|
@ -27,8 +27,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
|
@ -38,36 +37,37 @@ public class FeatureImportance implements ToXContentObject {
|
||||||
|
|
||||||
public static final String IMPORTANCE = "importance";
|
public static final String IMPORTANCE = "importance";
|
||||||
public static final String FEATURE_NAME = "feature_name";
|
public static final String FEATURE_NAME = "feature_name";
|
||||||
public static final String CLASS_IMPORTANCE = "class_importance";
|
public static final String CLASSES = "classes";
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
|
private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
|
||||||
new ConstructingObjectParser<>("feature_importance", true,
|
new ConstructingObjectParser<>("feature_importance", true,
|
||||||
a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
|
a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2])
|
||||||
);
|
);
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
|
PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
|
||||||
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||||
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
|
PARSER.declareObjectArray(optionalConstructorArg(),
|
||||||
new ParseField(FeatureImportance.CLASS_IMPORTANCE));
|
(p, c) -> ClassImportance.fromXContent(p),
|
||||||
|
new ParseField(FeatureImportance.CLASSES));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static FeatureImportance fromXContent(XContentParser parser) {
|
public static FeatureImportance fromXContent(XContentParser parser) {
|
||||||
return PARSER.apply(parser, null);
|
return PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final Map<String, Double> classImportance;
|
private final List<ClassImportance> classImportance;
|
||||||
private final double importance;
|
private final double importance;
|
||||||
private final String featureName;
|
private final String featureName;
|
||||||
|
|
||||||
public FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
|
public FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
|
||||||
this.featureName = Objects.requireNonNull(featureName);
|
this.featureName = Objects.requireNonNull(featureName);
|
||||||
this.importance = importance;
|
this.importance = importance;
|
||||||
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
|
this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<String, Double> getClassImportance() {
|
public List<ClassImportance> getClassImportance() {
|
||||||
return classImportance;
|
return classImportance;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,11 +85,7 @@ public class FeatureImportance implements ToXContentObject {
|
||||||
builder.field(FEATURE_NAME, featureName);
|
builder.field(FEATURE_NAME, featureName);
|
||||||
builder.field(IMPORTANCE, importance);
|
builder.field(IMPORTANCE, importance);
|
||||||
if (classImportance != null && classImportance.isEmpty() == false) {
|
if (classImportance != null && classImportance.isEmpty() == false) {
|
||||||
builder.startObject(CLASS_IMPORTANCE);
|
builder.field(CLASSES, classImportance);
|
||||||
for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
|
|
||||||
builder.field(entry.getKey(), entry.getValue());
|
|
||||||
}
|
|
||||||
builder.endObject();
|
|
||||||
}
|
}
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
|
@ -109,4 +105,63 @@ public class FeatureImportance implements ToXContentObject {
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(featureName, importance, classImportance);
|
return Objects.hash(featureName, importance, classImportance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static class ClassImportance implements ToXContentObject {
|
||||||
|
|
||||||
|
static final String CLASS_NAME = "class_name";
|
||||||
|
|
||||||
|
private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
|
||||||
|
new ConstructingObjectParser<>("feature_importance_class_importance",
|
||||||
|
true,
|
||||||
|
a -> new ClassImportance((String) a[0], (Double) a[1])
|
||||||
|
);
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
|
||||||
|
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ClassImportance fromXContent(XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final String className;
|
||||||
|
private final double importance;
|
||||||
|
|
||||||
|
public ClassImportance(String className, double importance) {
|
||||||
|
this.className = className;
|
||||||
|
this.importance = importance;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getClassName() {
|
||||||
|
return className;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getImportance() {
|
||||||
|
return importance;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(CLASS_NAME, className);
|
||||||
|
builder.field(IMPORTANCE, importance);
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
ClassImportance that = (ClassImportance) o;
|
||||||
|
return Double.compare(that.importance, importance) == 0 &&
|
||||||
|
Objects.equals(className, that.className);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(className, importance);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,8 +23,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.function.Function;
|
|
||||||
import java.util.function.Predicate;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
@ -38,7 +36,8 @@ public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImpo
|
||||||
randomBoolean() ? null :
|
randomBoolean() ? null :
|
||||||
Stream.generate(() -> randomAlphaOfLength(10))
|
Stream.generate(() -> randomAlphaOfLength(10))
|
||||||
.limit(randomLongBetween(2, 10))
|
.limit(randomLongBetween(2, 10))
|
||||||
.collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
|
.map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
|
||||||
|
.collect(Collectors.toList()));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,8 +51,4 @@ public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImpo
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
|
||||||
return field -> field.equals(FeatureImportance.CLASS_IMPORTANCE);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue