Classification feature importance supports various types in the class name: - string - boolean - numerical The xcontent parsing on the server side and the HLRC side should support and test these types.
This commit is contained in:
parent
c9be9963a8
commit
0b3af242d4
|
@ -21,8 +21,10 @@ package org.elasticsearch.client.ml.inference.results;
|
|||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -115,11 +117,20 @@ public class FeatureImportance implements ToXContentObject {
|
|||
private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
|
||||
new ConstructingObjectParser<>("feature_importance_class_importance",
|
||||
true,
|
||||
a -> new ClassImportance((String) a[0], (Double) a[1])
|
||||
a -> new ClassImportance(a[0], (Double) a[1])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
|
||||
PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
||||
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
|
||||
return p.text();
|
||||
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
|
||||
return p.numberValue();
|
||||
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
|
||||
return p.booleanValue();
|
||||
}
|
||||
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
|
||||
}, new ParseField(CLASS_NAME), ObjectParser.ValueType.VALUE);
|
||||
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
||||
}
|
||||
|
||||
|
@ -127,15 +138,15 @@ public class FeatureImportance implements ToXContentObject {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final String className;
|
||||
private final Object className;
|
||||
private final double importance;
|
||||
|
||||
public ClassImportance(String className, double importance) {
|
||||
public ClassImportance(Object className, double importance) {
|
||||
this.className = className;
|
||||
this.importance = importance;
|
||||
}
|
||||
|
||||
public String getClassName() {
|
||||
public Object getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
|
|
|
@ -21,20 +21,28 @@ package org.elasticsearch.client.ml.inference.results;
|
|||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImportance> {
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
protected FeatureImportance createTestInstance() {
|
||||
Supplier<Object> classNameGenerator = randomFrom(
|
||||
() -> randomAlphaOfLength(10),
|
||||
ESTestCase::randomBoolean,
|
||||
() -> randomIntBetween(0, 10)
|
||||
);
|
||||
return new FeatureImportance(
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : randomDoubleBetween(-10.0, 10.0, false),
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(() -> randomAlphaOfLength(10))
|
||||
Stream.generate(classNameGenerator)
|
||||
.limit(randomLongBetween(2, 10))
|
||||
.map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
|
||||
.collect(Collectors.toList()));
|
||||
|
|
|
@ -20,8 +20,10 @@ package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
|
|||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
@ -29,13 +31,19 @@ import java.util.stream.Stream;
|
|||
public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalFeatureImportance> {
|
||||
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static TotalFeatureImportance randomInstance() {
|
||||
Supplier<Object> classNameGenerator = randomFrom(
|
||||
() -> randomAlphaOfLength(10),
|
||||
ESTestCase::randomBoolean,
|
||||
() -> randomIntBetween(0, 10)
|
||||
);
|
||||
return new TotalFeatureImportance(
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : randomImportance(),
|
||||
randomBoolean() ?
|
||||
null :
|
||||
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
|
||||
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(classNameGenerator.get(), randomImportance()))
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList())
|
||||
);
|
||||
|
|
|
@ -10,8 +10,10 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -118,7 +120,16 @@ public class ClassificationFeatureImportance extends AbstractFeatureImportance {
|
|||
);
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
|
||||
PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
||||
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
|
||||
return p.text();
|
||||
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
|
||||
return p.numberValue();
|
||||
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
|
||||
return p.booleanValue();
|
||||
}
|
||||
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
|
||||
}, new ParseField(CLASS_NAME), ObjectParser.ValueType.VALUE);
|
||||
PARSER.declareDouble(constructorArg(), new ParseField(IMPORTANCE));
|
||||
}
|
||||
|
||||
|
|
|
@ -8,9 +8,11 @@ package org.elasticsearch.xpack.core.ml.inference.results;
|
|||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
@ -34,10 +36,16 @@ public class ClassificationFeatureImportanceTests extends AbstractSerializingTes
|
|||
return createRandomInstance();
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static ClassificationFeatureImportance createRandomInstance() {
|
||||
Supplier<Object> classNameGenerator = randomFrom(
|
||||
() -> randomAlphaOfLength(10),
|
||||
ESTestCase::randomBoolean,
|
||||
() -> randomIntBetween(0, 10)
|
||||
);
|
||||
return new ClassificationFeatureImportance(
|
||||
randomAlphaOfLength(10),
|
||||
Stream.generate(() -> randomAlphaOfLength(10))
|
||||
Stream.generate(classNameGenerator)
|
||||
.limit(randomLongBetween(2, 10))
|
||||
.map(name -> new ClassificationFeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
|
||||
.collect(Collectors.toList()));
|
||||
|
|
|
@ -8,10 +8,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
|
|||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
@ -20,13 +22,19 @@ public class TotalFeatureImportanceTests extends AbstractBWCSerializationTestCas
|
|||
|
||||
private boolean lenient;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static TotalFeatureImportance randomInstance() {
|
||||
Supplier<Object> classNameGenerator = randomFrom(
|
||||
() -> randomAlphaOfLength(10),
|
||||
ESTestCase::randomBoolean,
|
||||
() -> randomIntBetween(0, 10)
|
||||
);
|
||||
return new TotalFeatureImportance(
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : randomImportance(),
|
||||
randomBoolean() ?
|
||||
null :
|
||||
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
|
||||
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(classNameGenerator.get(), randomImportance()))
|
||||
.limit(randomIntBetween(1, 10))
|
||||
.collect(Collectors.toList())
|
||||
);
|
||||
|
|
Loading…
Reference in New Issue