From 0b3af242d41a49bc41ba8defed16e3dc7baedb77 Mon Sep 17 00:00:00 2001
From: Benjamin Trent
Date: Tue, 29 Sep 2020 10:54:35 -0400
Subject: [PATCH] [ML] fixing classification feature importance parsing (#63003) (#63015)

Classification feature importance supports various types in the class name:

 - string
 - boolean
 - numerical

The xcontent parsing on the server side and the HLRC side should support and test these types.
---
 .../inference/results/FeatureImportance.java  | 21 ++++++++++++++-----
 .../results/FeatureImportanceTests.java       | 10 ++++++++-
 .../metadata/TotalFeatureImportanceTests.java | 10 ++++++++-
 .../ClassificationFeatureImportance.java      | 13 +++++++++++-
 .../ClassificationFeatureImportanceTests.java | 10 ++++++++-
 .../metadata/TotalFeatureImportanceTests.java | 10 ++++++++-
 6 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java
index 23c2aa168b3..a918a7bb795 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java
@@ -21,8 +21,10 @@ package org.elasticsearch.client.ml.inference.results;
 
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParseException;
 import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
@@ -115,11 +117,20 @@ public class FeatureImportance implements ToXContentObject {
         private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
             new ConstructingObjectParser<>("feature_importance_class_importance",
                 true,
-                a -> new ClassImportance((String) a[0], (Double) a[1])
+                a -> new ClassImportance(a[0], (Double) a[1])
             );
 
         static {
-            PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
+            PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
+                if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                    return p.text();
+                } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
+                    return p.numberValue();
+                } else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
+                    return p.booleanValue();
+                }
+                throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
+            }, new ParseField(CLASS_NAME), ObjectParser.ValueType.VALUE);
             PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
         }
 
@@ -127,15 +138,15 @@ public class FeatureImportance implements ToXContentObject {
             return PARSER.apply(parser, null);
         }
 
-        private final String className;
+        private final Object className;
         private final double importance;
 
-        public ClassImportance(String className, double importance) {
+        public ClassImportance(Object className, double importance) {
             this.className = className;
             this.importance = importance;
         }
 
-        public String getClassName() {
+        public Object getClassName() {
             return className;
         }
 
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java
index dfb3118c4c4..e7580ac276e 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java
@@ -21,20 +21,28 @@ package org.elasticsearch.client.ml.inference.results;
 
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.test.ESTestCase;
 
 import java.io.IOException;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImportance> {
 
     @Override
+    @SuppressWarnings("unchecked")
     protected FeatureImportance createTestInstance() {
+        Supplier<Object> classNameGenerator = randomFrom(
+            () -> randomAlphaOfLength(10),
+            ESTestCase::randomBoolean,
+            () -> randomIntBetween(0, 10)
+        );
         return new FeatureImportance(
             randomAlphaOfLength(10),
             randomBoolean() ? null : randomDoubleBetween(-10.0, 10.0, false),
             randomBoolean() ? null :
-                Stream.generate(() -> randomAlphaOfLength(10))
+                Stream.generate(classNameGenerator)
                     .limit(randomLongBetween(2, 10))
                     .map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
                     .collect(Collectors.toList()));
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java
index eef5c3bae21..adbf9ab052d 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java
@@ -20,8 +20,10 @@ package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
 
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.test.ESTestCase;
 
 import java.io.IOException;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -29,13 +31,19 @@ import java.util.stream.Stream;
 public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalFeatureImportance> {
 
 
+    @SuppressWarnings("unchecked")
     public static TotalFeatureImportance randomInstance() {
+        Supplier<Object> classNameGenerator = randomFrom(
+            () -> randomAlphaOfLength(10),
+            ESTestCase::randomBoolean,
+            () -> randomIntBetween(0, 10)
+        );
         return new TotalFeatureImportance(
             randomAlphaOfLength(10),
             randomBoolean() ? null : randomImportance(),
             randomBoolean() ?
                 null :
-                Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
+                Stream.generate(() -> new TotalFeatureImportance.ClassImportance(classNameGenerator.get(), randomImportance()))
                     .limit(randomIntBetween(1, 10))
                     .collect(Collectors.toList())
             );
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportance.java
index 7eff392dabe..b29e4709eb6 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportance.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportance.java
@@ -10,8 +10,10 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParseException;
 import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
@@ -118,7 +120,16 @@ public class ClassificationFeatureImportance extends AbstractFeatureImportance {
             );
 
         static {
-            PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
+            PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
+                if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                    return p.text();
+                } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
+                    return p.numberValue();
+                } else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
+                    return p.booleanValue();
+                }
+                throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
+            }, new ParseField(CLASS_NAME), ObjectParser.ValueType.VALUE);
             PARSER.declareDouble(constructorArg(), new ParseField(IMPORTANCE));
         }
 
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportanceTests.java
index 6ef314cfe2c..3ad8daed61a 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportanceTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportanceTests.java
@@ -8,9 +8,11 @@ package org.elasticsearch.xpack.core.ml.inference.results;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.test.ESTestCase;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -34,10 +36,16 @@ public class ClassificationFeatureImportanceTests extends AbstractSerializingTes
         return createRandomInstance();
     }
 
+    @SuppressWarnings("unchecked")
     public static ClassificationFeatureImportance createRandomInstance() {
+        Supplier<Object> classNameGenerator = randomFrom(
+            () -> randomAlphaOfLength(10),
+            ESTestCase::randomBoolean,
+            () -> randomIntBetween(0, 10)
+        );
         return new ClassificationFeatureImportance(
             randomAlphaOfLength(10),
-            Stream.generate(() -> randomAlphaOfLength(10))
+            Stream.generate(classNameGenerator)
                 .limit(randomLongBetween(2, 10))
                 .map(name -> new ClassificationFeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
                 .collect(Collectors.toList()));
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java
index fcf4978f525..fa68e71e8cc 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java
@@ -8,10 +8,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -20,13 +22,19 @@ public class TotalFeatureImportanceTests extends AbstractBWCSerializationTestCas
 
     private boolean lenient;
 
+    @SuppressWarnings("unchecked")
     public static TotalFeatureImportance randomInstance() {
+        Supplier<Object> classNameGenerator = randomFrom(
+            () -> randomAlphaOfLength(10),
+            ESTestCase::randomBoolean,
+            () -> randomIntBetween(0, 10)
+        );
         return new TotalFeatureImportance(
             randomAlphaOfLength(10),
             randomBoolean() ? null : randomImportance(),
             randomBoolean() ?
                 null :
-                Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
+                Stream.generate(() -> new TotalFeatureImportance.ClassImportance(classNameGenerator.get(), randomImportance()))
                     .limit(randomIntBetween(1, 10))
                     .collect(Collectors.toList())
             );