[ML] fixing classification feature importance parsing (#63003) (#63015)

For classification feature importance, the class name can be any of several types:
- string
- boolean
- numerical

The xcontent parsing on both the server side and the HLRC side should support these types, and the tests on both sides should cover them.
This commit is contained in:
Benjamin Trent 2020-09-29 10:54:35 -04:00 committed by GitHub
parent c9be9963a8
commit 0b3af242d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 64 additions and 10 deletions

View File

@ -21,8 +21,10 @@ package org.elasticsearch.client.ml.inference.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
@ -115,11 +117,20 @@ public class FeatureImportance implements ToXContentObject {
private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
new ConstructingObjectParser<>("feature_importance_class_importance",
true,
a -> new ClassImportance((String) a[0], (Double) a[1])
a -> new ClassImportance(a[0], (Double) a[1])
);
static {
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
// Parse the class name as whichever scalar the current token holds, so the
// constructor argument may be a String, a Number, or a Boolean.
PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.numberValue();
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
return p.booleanValue();
}
// Any token other than string, number, or boolean is rejected.
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
// NOTE(review): ValueType.VALUE presumably restricts the field to scalar tokens — confirm against ObjectParser.
}, new ParseField(CLASS_NAME), ObjectParser.ValueType.VALUE);
PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
}
@ -127,15 +138,15 @@ public class FeatureImportance implements ToXContentObject {
return PARSER.apply(parser, null);
}
private final String className;
private final Object className;
private final double importance;
public ClassImportance(String className, double importance) {
public ClassImportance(Object className, double importance) {
this.className = className;
this.importance = importance;
}
public String getClassName() {
public Object getClassName() {
return className;
}

View File

@ -21,20 +21,28 @@ package org.elasticsearch.client.ml.inference.results;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImportance> {
@Override
@SuppressWarnings("unchecked")
protected FeatureImportance createTestInstance() {
Supplier<Object> classNameGenerator = randomFrom(
() -> randomAlphaOfLength(10),
ESTestCase::randomBoolean,
() -> randomIntBetween(0, 10)
);
return new FeatureImportance(
randomAlphaOfLength(10),
randomBoolean() ? null : randomDoubleBetween(-10.0, 10.0, false),
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
Stream.generate(classNameGenerator)
.limit(randomLongBetween(2, 10))
.map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
.collect(Collectors.toList()));

View File

@ -20,8 +20,10 @@ package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -29,13 +31,19 @@ import java.util.stream.Stream;
public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalFeatureImportance> {
@SuppressWarnings("unchecked")
public static TotalFeatureImportance randomInstance() {
Supplier<Object> classNameGenerator = randomFrom(
() -> randomAlphaOfLength(10),
ESTestCase::randomBoolean,
() -> randomIntBetween(0, 10)
);
return new TotalFeatureImportance(
randomAlphaOfLength(10),
randomBoolean() ? null : randomImportance(),
randomBoolean() ?
null :
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(classNameGenerator.get(), randomImportance()))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList())
);

View File

@ -10,8 +10,10 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
@ -118,7 +120,16 @@ public class ClassificationFeatureImportance extends AbstractFeatureImportance {
);
static {
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
// Parse the class name as whichever scalar the current token holds: text for a
// string token, numberValue() for a number token, booleanValue() for a boolean token.
PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.numberValue();
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
return p.booleanValue();
}
// Any token other than string, number, or boolean is rejected.
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
// NOTE(review): ValueType.VALUE presumably restricts the field to scalar tokens — confirm against ObjectParser.
}, new ParseField(CLASS_NAME), ObjectParser.ValueType.VALUE);
PARSER.declareDouble(constructorArg(), new ParseField(IMPORTANCE));
}

View File

@ -8,9 +8,11 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -34,10 +36,16 @@ public class ClassificationFeatureImportanceTests extends AbstractSerializingTes
return createRandomInstance();
}
@SuppressWarnings("unchecked")
public static ClassificationFeatureImportance createRandomInstance() {
Supplier<Object> classNameGenerator = randomFrom(
() -> randomAlphaOfLength(10),
ESTestCase::randomBoolean,
() -> randomIntBetween(0, 10)
);
return new ClassificationFeatureImportance(
randomAlphaOfLength(10),
Stream.generate(() -> randomAlphaOfLength(10))
Stream.generate(classNameGenerator)
.limit(randomLongBetween(2, 10))
.map(name -> new ClassificationFeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
.collect(Collectors.toList()));

View File

@ -8,10 +8,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -20,13 +22,19 @@ public class TotalFeatureImportanceTests extends AbstractBWCSerializationTestCas
private boolean lenient;
@SuppressWarnings("unchecked")
public static TotalFeatureImportance randomInstance() {
Supplier<Object> classNameGenerator = randomFrom(
() -> randomAlphaOfLength(10),
ESTestCase::randomBoolean,
() -> randomIntBetween(0, 10)
);
return new TotalFeatureImportance(
randomAlphaOfLength(10),
randomBoolean() ? null : randomImportance(),
randomBoolean() ?
null :
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(classNameGenerator.get(), randomImportance()))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList())
);