[ML] Adds feature importance option to inference processor (#52218) (#52666)

This adds machine learning model feature importance calculations to the inference processor.

The new configuration option matches the analytics parameter name: `num_top_feature_importance_values`
Example:
```
"inference": {
   "field_mappings": {},
   "model_id": "my_model",
   "inference_config": {
      "regression": {
         "num_top_feature_importance_values": 3
      }
   }
}
```
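
For reference, a complete pipeline definition using this option might look like the following; the pipeline name is a placeholder and the processor configuration is the one shown above:
```
PUT _ingest/pipeline/my_inference_pipeline
{
   "processors": [
      {
         "inference": {
            "field_mappings": {},
            "model_id": "my_model",
            "inference_config": {
               "regression": {
                  "num_top_feature_importance_values": 3
               }
            }
         }
      }
   ]
}
```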

This will write to the document as follows:
```
"inference" : {
   "feature_importance" : {
      "FlightTimeMin" : -76.90955548511226,
      "FlightDelayType" : 114.13514762158526,
      "DistanceMiles" : 13.731580450792187
   },
   "predicted_value" : 108.33165831875137,
   "model_id" : "my_model"
}
```

This is done by calculating [SHAP values](https://arxiv.org/abs/1802.03888).
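
For context, the exact Shapley value of feature `i` for a model `f` over the feature set `F` is the standard definition from the linked paper (shown here for reference, not code from this change):
```
\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(|F| - |S| - 1)!}{|F|!}
         \left[ f_{S \cup \{i\}}(x_{S \cup \{i\}}) - f_S(x_S) \right]
```
The paper describes how this sum can be computed exactly and efficiently for tree ensembles, which is what the ported TreeSHAP algorithm exploits.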

It requires that models have `number_samples` populated for each tree node; this is not available for models created before 7.7.
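
For reference, these per-node sample counts supply the conditional split probabilities used by the TreeSHAP recursion; roughly, each child of a split is weighted by the fraction of training samples that flowed through it (a sketch of the weighting, not the exact code):
```
P(\text{child} \mid \text{parent}) = \frac{\text{number\_samples}(\text{child})}{\text{number\_samples}(\text{parent})}
```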

Additionally, if the inference config requests feature importance and not all nodes have been upgraded yet, creating the pipeline is not allowed. This safeguards against problems in a mixed-version environment where only some ingest nodes have been upgraded.

NOTE: the algorithm is a Java port of the one laid out in ml-cpp: https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc

Usability is blocked by: https://github.com/elastic/ml-cpp/pull/991
Benjamin Trent 2020-02-21 18:42:31 -05:00 committed by GitHub
parent f06d692706
commit afd90647c9
29 changed files with 980 additions and 104 deletions

View File

@ -44,6 +44,12 @@ include::common-options.asciidoc[]
Specifies the field to which the inference prediction is written. Defaults to
`predicted_value`.
`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.
[discrete]
[[inference-processor-classification-opt]]
@ -63,6 +69,12 @@ Specifies the number of top class predictions to return. Defaults to 0.
Specifies the field to which the top classes are written. Defaults to
`top_classes`.
`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.
[discrete]
[[inference-processor-config-example]]

View File

@ -32,6 +32,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -73,6 +74,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private Map<String, String> decoderMap;
private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
@ -115,13 +117,35 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
return preProcessors;
}
private void preProcess(Map<String, Object> fields) {
void preProcess(Map<String, Object> fields) {
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
}
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields);
return trainedModel.infer(fields, config);
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
throw ExceptionsHelper.badRequestException(
"Feature importance is not supported for the configured model of type [{}]",
trainedModel.getName());
}
return trainedModel.infer(fields,
config,
config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
}
private Map<String, String> getDecoderMap() {
if (decoderMap != null) {
return decoderMap;
}
synchronized (this) {
if (decoderMap != null) {
return decoderMap;
}
this.decoderMap = preProcessors.stream()
.map(PreProcessor::reverseLookup)
.collect(HashMap::new, Map::putAll, Map::putAll);
return decoderMap;
}
}
@Override

View File

@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembeddi
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -235,6 +236,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
fields.put(destField, concatEmbeddings(processedFeatures));
}
@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(destField, fieldName);
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;

View File

@ -97,6 +97,11 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
return featureName;
}
@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(featureName, field);
}
@Override
public String getName() {
return NAME.getPreferredName();

View File

@ -18,8 +18,10 @@ import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* PreProcessor for one hot encoding a set of categorical values for a given field.
@ -80,6 +82,11 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
return hotMap;
}
@Override
public Map<String, String> reverseLookup() {
return hotMap.entrySet().stream().collect(Collectors.toMap(HashMap.Entry::getValue, (entry) -> field));
}
@Override
public String getName() {
return NAME.getPreferredName();

View File

@ -24,4 +24,9 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
* @param fields The fields and their values to process
*/
void process(Map<String, Object> fields);
/**
* @return Reverse lookup map to match resulting features to their original feature name
*/
Map<String, String> reverseLookup();
}

View File

@ -108,6 +108,11 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
return featureName;
}
@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(featureName, field);
}
@Override
public String getName() {
return NAME.getPreferredName();

View File

@ -35,9 +35,25 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
String classificationLabel,
List<TopClassEntry> topClasses,
InferenceConfig config) {
super(value);
assert config instanceof ClassificationConfig;
ClassificationConfig classificationConfig = (ClassificationConfig)config;
this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
}
public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
InferenceConfig config) {
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
}
private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
ClassificationConfig classificationConfig) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
classificationConfig.getNumTopFeatureImportanceValues()));
this.classificationLabel = classificationLabel;
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
@ -74,16 +90,17 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
ClassificationInferenceResults that = (ClassificationInferenceResults) object;
return Objects.equals(value(), that.value()) &&
Objects.equals(classificationLabel, that.classificationLabel) &&
Objects.equals(resultsField, that.resultsField) &&
Objects.equals(topNumClassesField, that.topNumClassesField) &&
Objects.equals(topClasses, that.topClasses);
return Objects.equals(value(), that.value())
&& Objects.equals(classificationLabel, that.classificationLabel)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(topNumClassesField, that.topNumClassesField)
&& Objects.equals(topClasses, that.topClasses)
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}
@Override
public int hashCode() {
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField);
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance());
}
@Override
@ -100,6 +117,9 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
document.setFieldValue(parentResultField + "." + topNumClassesField,
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
}
@Override

View File

@ -10,18 +10,19 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.ingest.IngestDocument;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
public class RawInferenceResults extends SingleValueInferenceResults {
public static final String NAME = "raw";
public RawInferenceResults(double value) {
super(value);
public RawInferenceResults(double value, Map<String, Double> featureImportance) {
super(value, featureImportance);
}
public RawInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
super(in);
}
@Override
@ -34,12 +35,13 @@ public class RawInferenceResults extends SingleValueInferenceResults {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value());
return Objects.equals(value(), that.value())
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}
@Override
public int hashCode() {
return Objects.hash(value());
return Objects.hash(value(), getFeatureImportance());
}
@Override

View File

@ -13,6 +13,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
public class RegressionInferenceResults extends SingleValueInferenceResults {
@ -22,14 +24,22 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField;
public RegressionInferenceResults(double value, InferenceConfig config) {
super(value);
assert config instanceof RegressionConfig;
RegressionConfig regressionConfig = (RegressionConfig)config;
this(value, (RegressionConfig) config, Collections.emptyMap());
}
public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
this(value, (RegressionConfig)config, featureImportance);
}
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
this.resultsField = regressionConfig.getResultsField();
}
public RegressionInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
super(in);
this.resultsField = in.readString();
}
@ -44,12 +54,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RegressionInferenceResults that = (RegressionInferenceResults) object;
return Objects.equals(value(), that.value()) && Objects.equals(this.resultsField, that.resultsField);
return Objects.equals(value(), that.value())
&& Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.getFeatureImportance(), that.getFeatureImportance());
}
@Override
public int hashCode() {
return Objects.hash(value(), resultsField);
return Objects.hash(value(), resultsField, getFeatureImportance());
}
@Override
@ -57,6 +69,9 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, value());
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
}
@Override

View File

@ -5,27 +5,51 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
public abstract class SingleValueInferenceResults implements InferenceResults {
private final double value;
private final Map<String, Double> featureImportance;
static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
return unsortedFeatureImportances.entrySet()
.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
.limit(numTopFeatures)
.collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
}
SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
} else {
this.featureImportance = Collections.emptyMap();
}
}
SingleValueInferenceResults(double value) {
SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
this.value = value;
this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
}
public Double value() {
return value;
}
public Map<String, Double> getFeatureImportance() {
return featureImportance;
}
public String valueAsString() {
return String.valueOf(value);
}
@ -33,6 +57,9 @@ public abstract class SingleValueInferenceResults implements InferenceResults {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
}
}
}

View File

@ -31,33 +31,39 @@ public class ClassificationConfig implements InferenceConfig {
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD);
public static ClassificationConfig EMPTY_PARAMS =
new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null);
private final int numTopClasses;
private final String topClassesResultsField;
private final String resultsField;
private final int numTopFeatureImportanceValues;
public static ClassificationConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName());
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField);
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, featureImportance);
}
private static final ConstructingObjectParser<ClassificationConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig(
(Integer) args[0], (String) args[1], (String) args[2]));
(Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3]));
static {
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
}
public static ClassificationConfig fromXContent(XContentParser parser) {
@ -65,19 +71,33 @@ public class ClassificationConfig implements InferenceConfig {
}
public ClassificationConfig(Integer numTopClasses) {
this(numTopClasses, null, null);
this(numTopClasses, null, null, null);
}
public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField) {
this(numTopClasses, resultsField, topClassesResultsField, 0);
}
public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULTS_FIELD : topClassesResultsField;
this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
if (featureImportance != null && featureImportance < 0) {
throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() +
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = featureImportance == null ? 0 : featureImportance;
}
public ClassificationConfig(StreamInput in) throws IOException {
this.numTopClasses = in.readInt();
this.topClassesResultsField = in.readString();
this.resultsField = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.numTopFeatureImportanceValues = in.readVInt();
} else {
this.numTopFeatureImportanceValues = 0;
}
}
public int getNumTopClasses() {
@ -92,11 +112,23 @@ public class ClassificationConfig implements InferenceConfig {
return resultsField;
}
public int getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}
@Override
public boolean requestingImportance() {
return numTopFeatureImportanceValues > 0;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(numTopClasses);
out.writeString(topClassesResultsField);
out.writeString(resultsField);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeVInt(numTopFeatureImportanceValues);
}
}
@Override
@ -104,14 +136,15 @@ public class ClassificationConfig implements InferenceConfig {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassificationConfig that = (ClassificationConfig) o;
return Objects.equals(numTopClasses, that.numTopClasses) &&
Objects.equals(topClassesResultsField, that.topClassesResultsField) &&
Objects.equals(resultsField, that.resultsField);
return Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(topClassesResultsField, that.topClassesResultsField)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
}
@Override
public int hashCode() {
return Objects.hash(numTopClasses, topClassesResultsField, resultsField);
return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues);
}
@Override
@ -122,6 +155,9 @@ public class ClassificationConfig implements InferenceConfig {
}
builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField);
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
if (numTopFeatureImportanceValues > 0) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.endObject();
return builder;
}
@ -143,7 +179,7 @@ public class ClassificationConfig implements InferenceConfig {
@Override
public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION;
return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
}
}

View File

@ -18,4 +18,8 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable {
* All nodes in the cluster must be at least this version
*/
Version getMinimalSupportedVersion();
default boolean requestingImportance() {
return false;
}
}

View File

@ -13,7 +13,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@ -98,4 +100,19 @@ public final class InferenceHelpers {
}
return null;
}
public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, Double> featureImportances) {
if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
return featureImportances;
}
Map<String, Double> originalFeatureImportance = new HashMap<>();
featureImportances.forEach((feature, importance) -> {
String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
});
return originalFeatureImportance;
}
}

View File

@ -16,9 +16,12 @@ import java.io.IOException;
*/
public class NullInferenceConfig implements InferenceConfig {
public static final NullInferenceConfig INSTANCE = new NullInferenceConfig();
private final boolean requestingFeatureImportance;
private NullInferenceConfig() { }
public NullInferenceConfig(boolean requestingFeatureImportance) {
this.requestingFeatureImportance = requestingFeatureImportance;
}
@Override
public boolean isTargetTypeSupported(TargetType targetType) {
@ -37,6 +40,7 @@ public class NullInferenceConfig implements InferenceConfig {
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new UnsupportedOperationException("Unable to serialize NullInferenceConfig objects");
}
@Override
@ -46,6 +50,11 @@ public class NullInferenceConfig implements InferenceConfig {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder;
throw new UnsupportedOperationException("Unable to write xcontent from NullInferenceConfig objects");
}
@Override
public boolean requestingImportance() {
return requestingFeatureImportance;
}
}

View File

@ -26,24 +26,27 @@ public class RegressionConfig implements InferenceConfig {
public static final ParseField NAME = new ParseField("regression");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD);
public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null);
public static RegressionConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
}
return new RegressionConfig(resultsField);
return new RegressionConfig(resultsField, featureImportance);
}
private static final ConstructingObjectParser<RegressionConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0]));
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0], (Integer)args[1]));
static {
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
}
public static RegressionConfig fromXContent(XContentParser parser) {
@ -51,19 +54,43 @@ public class RegressionConfig implements InferenceConfig {
}
private final String resultsField;
private final int numTopFeatureImportanceValues;
public RegressionConfig(String resultsField) {
this(resultsField, 0);
}
public RegressionConfig(String resultsField, Integer numTopFeatureImportanceValues) {
this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) {
throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() +
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues;
}
public RegressionConfig(StreamInput in) throws IOException {
this.resultsField = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.numTopFeatureImportanceValues = in.readVInt();
} else {
this.numTopFeatureImportanceValues = 0;
}
}
public int getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}
public String getResultsField() {
return resultsField;
}
@Override
public boolean requestingImportance() {
return numTopFeatureImportanceValues > 0;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -72,6 +99,9 @@ public class RegressionConfig implements InferenceConfig {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(resultsField);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeVInt(numTopFeatureImportanceValues);
}
}
@Override
@ -83,6 +113,9 @@ public class RegressionConfig implements InferenceConfig {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
if (numTopFeatureImportanceValues > 0) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.endObject();
return builder;
}
@ -92,12 +125,13 @@ public class RegressionConfig implements InferenceConfig {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RegressionConfig that = (RegressionConfig)o;
return Objects.equals(this.resultsField, that.resultsField);
return Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
}
@Override
public int hashCode() {
return Objects.hash(resultsField);
return Objects.hash(resultsField, numTopFeatureImportanceValues);
}
@Override
@ -107,7 +141,7 @@ public class RegressionConfig implements InferenceConfig {
@Override
public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION;
return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
}
}

View File

@ -0,0 +1,162 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
/**
* Ported from https://github.com/elastic/ml-cpp/blob/master/include/maths/CTreeShapFeatureImportance.h Path struct
*/
public class ShapPath {
private static final double DBL_EPSILON = Double.MIN_VALUE;
private final PathElement[] pathElements;
private final double[] scale;
private final int elementAndScaleOffset;
public ShapPath(ShapPath parentPath, int nextIndex) {
this.elementAndScaleOffset = parentPath.elementAndScaleOffset + nextIndex;
this.pathElements = parentPath.pathElements;
this.scale = parentPath.scale;
for (int i = 0; i < nextIndex; i++) {
pathElements[elementAndScaleOffset + i].featureIndex = parentPath.getElement(i).featureIndex;
pathElements[elementAndScaleOffset + i].fractionZeros = parentPath.getElement(i).fractionZeros;
pathElements[elementAndScaleOffset + i].fractionOnes = parentPath.getElement(i).fractionOnes;
scale[elementAndScaleOffset + i] = parentPath.getScale(i);
}
}
public ShapPath(PathElement[] elements, double[] scale) {
this.pathElements = elements;
this.scale = scale;
this.elementAndScaleOffset = 0;
}
// Update binomial coefficients to be able to compute Equation (2) from the paper. In particular,
// we have in the line path.scale[i + 1] += fractionOne * path.scale[i] * (i + 1.0) / (pathDepth +
// 1.0) that if we're on the "one" path, i.e. if the last feature selects this path if we include that
// feature in S (then fractionOne is 1), and we need to consider all the additional ways we now have of
// constructing each S of each given cardinality i + 1. Each of these come by adding the last feature
// to sets of size i and we **also** need to scale by the difference in binomial coefficients as both M
// increases by one and i increases by one. So we get additive term 1{last feature selects path if in S}
// * scale(i) * (i+1)! (M+1-(i+1)-1)!/(M+1)! / (i! (M-i-1)!/ M!), whence += scale(i) * (i+1) / (M+1).
public int extend(double fractionZero, double fractionOne, int featureIndex, int nextIndex) {
setValues(nextIndex, fractionOne, fractionZero, featureIndex);
setScale(nextIndex, nextIndex == 0 ? 1.0 : 0.0);
double stepDown = fractionOne / (double)(nextIndex + 1);
double stepUp = fractionZero / (double)(nextIndex + 1);
double countDown = nextIndex * stepDown;
double countUp = stepUp;
for (int i = (nextIndex - 1); i >= 0; --i, countDown -= stepDown, countUp += stepUp) {
setScale(i + 1, getScale(i + 1) + getScale(i) * countDown);
setScale(i, getScale(i) * countUp);
}
return nextIndex + 1;
}
public double sumUnwoundPath(int pathIndex, int nextIndex) {
double total = 0.0;
int pathDepth = nextIndex - 1;
double nextFractionOne = getScale(pathDepth);
double fractionOne = fractionOnes(pathIndex);
double fractionZero = fractionZeros(pathIndex);
if (fractionOne != 0) {
double pD = pathDepth + 1;
double stepUp = fractionZero / pD;
double stepDown = fractionOne / pD;
double countUp = stepUp;
double countDown = (pD - 1.0) * stepDown;
for (int i = pathDepth - 1; i >= 0; --i, countUp += stepUp, countDown -= stepDown) {
double tmp = nextFractionOne / countDown;
nextFractionOne = getScale(i) - tmp * countUp;
total += tmp;
}
} else {
double pD = pathDepth;
for(int i = 0; i < pathDepth; i++) {
total += getScale(i) / pD--;
}
total *= (pathDepth + 1) / (fractionZero + DBL_EPSILON);
}
return total;
}
public int unwind(int pathIndex, int nextIndex) {
int pathDepth = nextIndex - 1;
double nextFractionOne = getScale(pathDepth);
double fractionOne = fractionOnes(pathIndex);
double fractionZero = fractionZeros(pathIndex);
if (fractionOne != 0) {
double stepUp = fractionZero / (double)(pathDepth + 1);
double stepDown = fractionOne / (double)nextIndex;
double countUp = 0.0;
double countDown = nextIndex * stepDown;
for (int i = pathDepth; i >= 0; --i, countUp += stepUp, countDown -= stepDown) {
double tmp = nextFractionOne / countDown;
nextFractionOne = getScale(i) - tmp * countUp;
setScale(i, tmp);
}
} else {
double stepDown = (fractionZero + DBL_EPSILON) / (double)(pathDepth + 1);
double countDown = pathDepth * stepDown;
for (int i = 0; i <= pathDepth; ++i, countDown -= stepDown) {
setScale(i, getScale(i) / countDown);
}
}
for (int i = pathIndex; i < pathDepth; ++i) {
PathElement element = getElement(i + 1);
setValues(i, element.fractionOnes, element.fractionZeros, element.featureIndex);
}
return nextIndex - 1;
}
private void setValues(int index, double fractionOnes, double fractionZeros, int featureIndex) {
pathElements[index + elementAndScaleOffset].fractionOnes = fractionOnes;
pathElements[index + elementAndScaleOffset].fractionZeros = fractionZeros;
pathElements[index + elementAndScaleOffset].featureIndex = featureIndex;
}
private double getScale(int offset) {
return scale[offset + elementAndScaleOffset];
}
private void setScale(int offset, double value) {
scale[offset + elementAndScaleOffset] = value;
}
public double fractionOnes(int pathIndex) {
return pathElements[pathIndex + elementAndScaleOffset].fractionOnes;
}
public double fractionZeros(int pathIndex) {
return pathElements[pathIndex + elementAndScaleOffset].fractionZeros;
}
public int findFeatureIndex(int splitFeature, int nextIndex) {
for (int i = elementAndScaleOffset; i < elementAndScaleOffset + nextIndex; i++) {
if (pathElements[i].featureIndex == splitFeature) {
return i - elementAndScaleOffset;
}
}
return -1;
}
public int featureIndex(int pathIndex) {
return pathElements[pathIndex + elementAndScaleOffset].featureIndex;
}
private PathElement getElement(int offset) {
return pathElements[offset + elementAndScaleOffset];
}
public static final class PathElement {
private double fractionOnes = 1.0;
private double fractionZeros = 1.0;
private int featureIndex = -1;
}
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
@ -17,12 +18,16 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
/**
* Infer against the provided fields
*
* NOTE: Must be thread safe
*
* @param fields The fields and their values to infer against
* @param config The configuration options for inference
* @param featureDecoderMap A map for decoding feature value names to their originating feature.
* Necessary for feature influence.
* @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0).
* For regression this is continuous.
*/
InferenceResults infer(Map<String, Object> fields, InferenceConfig config);
InferenceResults infer(Map<String, Object> fields, InferenceConfig config, @Nullable Map<String, String> featureDecoderMap);
/**
* @return {@link TargetType} for the model.
@ -42,4 +47,19 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
* @return The estimated number of operations required at inference time
*/
long estimatedNumOperations();
/**
* @return Does the model support feature importance
*/
boolean supportsFeatureImportance();
/**
* Calculates the importance of each feature referenced by the model for the passed-in field values
*
* NOTE: Must be thread safe
* @param fields The fields to infer against
* @param featureDecoder A Map translating processed feature names to their original feature names
* @return A {@code Map<String, Double>} mapping each featureName to its importance
*/
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
}

View File

@ -37,6 +37,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -133,18 +134,25 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
}
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
List<Double> inferenceResults = this.models.stream().map(model -> {
InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE);
assert results instanceof SingleValueInferenceResults;
return ((SingleValueInferenceResults)results).value();
}).collect(Collectors.toList());
List<Double> inferenceResults = new ArrayList<>(this.models.size());
List<Map<String, Double>> featureInfluence = new ArrayList<>();
NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
this.models.forEach(model -> {
InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
assert result instanceof SingleValueInferenceResults;
SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result;
inferenceResults.add(inferenceResult.value());
if (config.requestingImportance()) {
featureInfluence.add(inferenceResult.getFeatureImportance());
}
});
List<Double> processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed, config);
return buildResults(processed, featureInfluence, config, featureDecoderMap);
}
@Override
@ -152,14 +160,20 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return targetType;
}
private InferenceResults buildResults(List<Double> processedInferences, InferenceConfig config) {
private InferenceResults buildResults(List<Double> processedInferences,
List<Map<String, Double>> featureInfluence,
InferenceConfig config,
Map<String, String> featureDecoderMap) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(outputAggregator.aggregate(processedInferences));
return new RawInferenceResults(outputAggregator.aggregate(processedInferences),
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
}
switch(targetType) {
case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), config);
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
config,
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
@ -172,6 +186,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return new ClassificationInferenceResults((double)topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
config);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
@ -293,6 +308,27 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1);
}
@Override
public boolean supportsFeatureImportance() {
return models.stream().allMatch(TrainedModel::supportsFeatureImportance);
}
Map<String, Double> featureImportance(Map<String, Object> fields) {
return featureImportance(fields, Collections.emptyMap());
}
@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
Map<String, Double> collapsed = mergeFeatureImportances(models.stream()
.map(trainedModel -> trainedModel.featureImportance(fields, Collections.emptyMap()))
.collect(Collectors.toList()));
return InferenceHelpers.decodeFeatureImportances(featureDecoder, collapsed);
}
private static Map<String, Double> mergeFeatureImportances(List<Map<String, Double>> featureImportances) {
return featureImportances.stream().collect(HashMap::new, (a, b) -> b.forEach((k, v) -> a.merge(k, v, Double::sum)), Map::putAll);
}
public static Builder builder() {
return new Builder();
}

View File

@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -104,7 +105,11 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
}
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.requestingImportance()) {
throw ExceptionsHelper.badRequestException("[{}] model does not supports feature importance",
NAME.getPreferredName());
}
if (config instanceof ClassificationConfig == false) {
throw ExceptionsHelper.badRequestException("[{}] model only supports classification",
NAME.getPreferredName());
@ -138,6 +143,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
return new ClassificationInferenceResults(topClasses.v1(),
LANGUAGE_NAMES.get(topClasses.v1()),
topClasses.v2(),
Collections.emptyMap(),
classificationConfig);
}
@ -159,6 +165,16 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
return numOps;
}
@Override
public boolean supportsFeatureImportance() {
return false;
}
@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
throw new UnsupportedOperationException("[lang_ident] does not support feature importance");
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;

View File

@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -44,6 +45,7 @@ import java.util.Objects;
import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
@ -86,6 +88,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
private final TargetType targetType;
private final List<String> classificationLabels;
private final CachedSupplier<Double> highestOrderCategory;
// populated lazily when feature importance is calculated
private double[] nodeEstimates;
private Integer maxDepth;
Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
@ -120,7 +125,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
@ -129,21 +134,23 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
List<Double> features = featureNames.stream()
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
.collect(Collectors.toList());
return infer(features, config);
}
private InferenceResults infer(List<Double> features, InferenceConfig config) {
Map<String, Double> featureImportance = config.requestingImportance() ?
featureImportance(features, featureDecoderMap) :
Collections.emptyMap();
TreeNode node = nodes.get(0);
while(node.isLeaf() == false) {
node = nodes.get(node.compare(features));
}
return buildResult(node.getLeafValue(), config);
return buildResult(node.getLeafValue(), featureImportance, config);
}
private InferenceResults buildResult(Double value, InferenceConfig config) {
private InferenceResults buildResult(Double value, Map<String, Double> featureImportance, InferenceConfig config) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(value);
return new RawInferenceResults(value, featureImportance);
}
switch (targetType) {
case CLASSIFICATION:
@ -156,9 +163,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return new ClassificationInferenceResults(value,
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
featureImportance,
config);
case REGRESSION:
return new RegressionInferenceResults(value, config);
return new RegressionInferenceResults(value, config, featureImportance);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}
@ -192,7 +200,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
// If we are classification, we should assume that the largest leaf value is whole.
assert maxCategory == Math.rint(maxCategory);
List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
// TODO, eventually have TreeNodes contain confidence levels
list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
return list;
}
@ -263,12 +270,138 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
detectCycle();
}
@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) {
throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance");
}
List<Double> features = featureNames.stream()
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
.collect(Collectors.toList());
return featureImportance(features, featureDecoder);
}
private Map<String, Double> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
calculateNodeEstimatesIfNeeded();
double[] featureImportance = new double[fieldValues.size()];
int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
for (int i = 0; i < arrSize; i++) {
elements[i] = new ShapPath.PathElement();
}
double[] scale = new double[arrSize];
ShapPath initialPath = new ShapPath(elements, scale);
shapRecursive(fieldValues, this.nodeEstimates, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
return InferenceHelpers.decodeFeatureImportances(featureDecoder,
IntStream.range(0, featureImportance.length)
.boxed()
.collect(Collectors.toMap(featureNames::get, i -> featureImportance[i])));
}
private void calculateNodeEstimatesIfNeeded() {
if (this.nodeEstimates != null && this.maxDepth != null) {
return;
}
synchronized (this) {
if (this.nodeEstimates != null && this.maxDepth != null) {
return;
}
double[] estimates = new double[nodes.size()];
this.maxDepth = fillNodeEstimates(estimates, 0, 0);
this.nodeEstimates = estimates;
}
}
/**
* Note, this is a port from https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc
*
* If improvements in performance or accuracy have been found, it is probably best that the changes are implemented on the native
* side first and then ported to the Java side.
*/
private void shapRecursive(List<Double> processedFeatures,
double[] nodeValues,
ShapPath parentSplitPath,
int nodeIndex,
double parentFractionZero,
double parentFractionOne,
int parentFeatureIndex,
double[] featureImportance,
int nextIndex) {
ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
TreeNode currNode = nodes.get(nodeIndex);
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
if (currNode.isLeaf()) {
// TODO multi-value????
double leafValue = nodeValues[nodeIndex];
for (int i = 1; i < nextIndex; ++i) {
double scale = splitPath.sumUnwoundPath(i, nextIndex);
int inputColumnIndex = splitPath.featureIndex(i);
featureImportance[inputColumnIndex] += scale * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i)) * leafValue;
}
} else {
int hotIndex = currNode.compare(processedFeatures);
int coldIndex = hotIndex == currNode.getLeftChild() ? currNode.getRightChild() : currNode.getLeftChild();
double incomingFractionZero = 1.0;
double incomingFractionOne = 1.0;
int splitFeature = currNode.getSplitFeature();
int pathIndex = splitPath.findFeatureIndex(splitFeature, nextIndex);
if (pathIndex > -1) {
incomingFractionZero = splitPath.fractionZeros(pathIndex);
incomingFractionOne = splitPath.fractionOnes(pathIndex);
nextIndex = splitPath.unwind(pathIndex, nextIndex);
}
double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples();
double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples();
shapRecursive(processedFeatures, nodeValues, splitPath,
hotIndex, incomingFractionZero * hotFractionZero,
incomingFractionOne, splitFeature, featureImportance, nextIndex);
shapRecursive(processedFeatures, nodeValues, splitPath,
coldIndex, incomingFractionZero * coldFractionZero,
0.0, splitFeature, featureImportance, nextIndex);
}
}
/**
* This recursively populates the provided {@code double[]} with the node estimated values
*
* Used when calculating feature importance.
* @param nodeEstimates Array to update in place with the node estimated values
* @param nodeIndex Current node index
* @param depth Current depth
* @return The current max depth
*/
private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
TreeNode node = nodes.get(nodeIndex);
if (node.isLeaf()) {
nodeEstimates[nodeIndex] = node.getLeafValue();
return 0;
}
int depthLeft = fillNodeEstimates(nodeEstimates, node.getLeftChild(), depth + 1);
int depthRight = fillNodeEstimates(nodeEstimates, node.getRightChild(), depth + 1);
long leftWeight = nodes.get(node.getLeftChild()).getNumberSamples();
long rightWeight = nodes.get(node.getRightChild()).getNumberSamples();
long divisor = leftWeight + rightWeight;
double averageValue = divisor == 0 ?
0.0 :
(leftWeight * nodeEstimates[node.getLeftChild()] + rightWeight * nodeEstimates[node.getRightChild()]) / divisor;
nodeEstimates[nodeIndex] = averageValue;
return Math.max(depthLeft, depthRight) + 1;
}
@Override
public long estimatedNumOperations() {
// Grabbing the features from the doc + the depth of the tree
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
}
@Override
public boolean supportsFeatureImportance() {
return true;
}
/**
* The highest index of a feature used in any of the nodes.
* If no nodes use a feature return -1. This can only happen

View File

@ -342,8 +342,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
}};
assertThat(
((ClassificationInferenceResults)definition.getTrainedModel()
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-setosa"));
@ -354,8 +353,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
put("petal_width", 1.4);
}};
assertThat(
((ClassificationInferenceResults)definition.getTrainedModel()
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-versicolor"));
@ -366,10 +364,8 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
put("petal_width", 2.0);
}};
assertThat(
((ClassificationInferenceResults)definition.getTrainedModel()
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-virginica"));
}
}

View File

@ -8,10 +8,12 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.util.Collections;
public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {
public static RawInferenceResults createRandomResults() {
return new RawInferenceResults(randomDouble());
return new RawInferenceResults(randomDouble(), randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
}
@Override

View File

@ -30,11 +30,12 @@ public class ClassificationConfigTests extends AbstractSerializingTestCase<Class
ClassificationConfig expected = ClassificationConfig.EMPTY_PARAMS;
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));
expected = new ClassificationConfig(3, "foo", "bar");
expected = new ClassificationConfig(3, "foo", "bar", 2);
Map<String, Object> configMap = new HashMap<>();
configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo");
configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar");
configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2);
assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected));
}

View File

@ -24,9 +24,10 @@ public class RegressionConfigTests extends AbstractSerializingTestCase<Regressio
}
public void testFromMap() {
RegressionConfig expected = new RegressionConfig("foo");
RegressionConfig expected = new RegressionConfig("foo", 3);
Map<String, Object> config = new HashMap<String, Object>(){{
put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo");
put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3);
}};
assertThat(RegressionConfig.fromMap(config), equalTo(expected));
}

View File

@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
@ -39,6 +40,7 @@ import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
private final double eps = 1.0E-8;
private boolean lenient;
@ -267,7 +269,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> scores = Arrays.asList(0.230557435, 0.162032651);
double eps = 0.000001;
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -278,7 +281,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
expected = Arrays.asList(0.310025518, 0.6899744811);
scores = Arrays.asList(0.217017863, 0.2069923443);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -289,7 +293,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
expected = Arrays.asList(0.768524783, 0.231475216);
scores = Arrays.asList(0.230557435, 0.162032651);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -303,7 +308,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
expected = Arrays.asList(0.6899744811, 0.3100255188);
scores = Arrays.asList(0.482982136, 0.0930076556);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -361,24 +367,28 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(0.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
}
public void testMultiClassClassificationInference() {
@ -432,24 +442,28 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(2.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.6);
put("bar", null);
}};
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
}
public void testRegressionInference() {
@ -489,12 +503,16 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
// Test with NO aggregator supplied, verifies default behavior of non-weighted sum
ensemble = Ensemble.builder()
@ -506,19 +524,25 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
featureVector = Arrays.asList(0.4, 0.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
}
public void testInferNestedFields() {
@ -564,7 +588,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
}});
}};
assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
@ -575,7 +601,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
}});
}};
assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
}
public void testOperationsEstimations() {
@ -590,6 +618,114 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
assertThat(ensemble.estimatedNumOperations(), equalTo(9L));
}
public void testFeatureImportance() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.55)
.setNumberSamples(10L),
TreeNode.builder(1)
.setSplitFeature(0)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.41)
.setNumberSamples(6L),
TreeNode.builder(2)
.setSplitFeature(1)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.25)
.setNumberSamples(4L),
TreeNode.builder(3).setLeafValue(1.18230136).setNumberSamples(5L),
TreeNode.builder(4).setLeafValue(1.98006658).setNumberSamples(1L),
TreeNode.builder(5).setLeafValue(3.25350885).setNumberSamples(3L),
TreeNode.builder(6).setLeafValue(2.42384369).setNumberSamples(1L)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.45)
.setNumberSamples(10L),
TreeNode.builder(1)
.setSplitFeature(0)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.25)
.setNumberSamples(5L),
TreeNode.builder(2)
.setSplitFeature(0)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.59)
.setNumberSamples(5L),
TreeNode.builder(3).setLeafValue(1.04476388).setNumberSamples(3L),
TreeNode.builder(4).setLeafValue(1.52799228).setNumberSamples(2L),
TreeNode.builder(5).setLeafValue(1.98006658).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(2.950216).setNumberSamples(4L)).build();
Ensemble ensemble = Ensemble.builder().setOutputAggregator(new WeightedSum())
.setTrainedModels(Arrays.asList(tree1, tree2))
.setFeatureNames(featureNames)
.build();
Map<String, Double> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.1, 0.8)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.2, 0.7)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.3, 0.6)));
assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.4, 0.5)));
assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.5, 0.4)));
assertThat(featureImportance.get("foo"), closeTo(0.0798679, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.6, 0.3)));
assertThat(featureImportance.get("foo"), closeTo(1.80491886, eps));
assertThat(featureImportance.get("bar"), closeTo(-0.4355742, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.7, 0.2)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.8, 0.1)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.9, 0.0)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
}
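A property these expectations rely on: with a WeightedSum aggregator and default unit weights, the ensemble's feature importance for a feature is the sum of the per-tree SHAP contributions. For the (0.0, 0.9) case, tree1 contributes roughly -0.70 and tree2 roughly -0.95 to "foo", summing to the asserted -1.6532, while only tree1 touches "bar", giving -0.1244. A minimal aggregation sketch under that assumption (helper names are illustrative, not the production API):
```java
// Sketch: per-feature weighted sum of per-tree SHAP contributions. Assumes each
// member model can report its own feature importance for the same feature vector;
// weights default to 1.0 when the WeightedSum aggregator has none configured.
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class EnsembleImportanceSketch {
    static Map<String, Double> sumImportances(List<Map<String, Double>> perTreeImportances, double[] weights) {
        Map<String, Double> total = new HashMap<>();
        for (int i = 0; i < perTreeImportances.size(); i++) {
            double weight = weights == null ? 1.0 : weights[i];
            for (Map.Entry<String, Double> entry : perTreeImportances.get(i).entrySet()) {
                total.merge(entry.getKey(), weight * entry.getValue(), Double::sum);
            }
        }
        return total;
    }
}
```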
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}

View File

@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceRes
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.junit.Before;
import java.io.IOException;
@ -35,6 +36,7 @@ import static org.hamcrest.Matchers.equalTo;
public class TreeTests extends AbstractSerializingTestCase<Tree> {
private final double eps = 1.0E-8;
private boolean lenient;
@Before
@ -118,7 +120,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump
assertThat(42.0,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}
public void testInfer() {
@ -138,27 +141,31 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
featureVector = Arrays.asList(0.3, 0.9);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should still work if the internal values are strings
List<String> featureVectorStrings = Arrays.asList("0.3", "0.9");
featureMap = zipObjMap(featureNames, featureVectorStrings);
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{
@ -166,7 +173,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
put("bar", null);
}};
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}
public void testInferNestedFields() {
@ -192,7 +200,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}});
}};
assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
@ -205,7 +214,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}});
}};
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
@ -218,7 +228,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}});
}};
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}
public void testTreeClassificationProbability() {
@ -241,7 +252,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
List<String> expectedFields = Arrays.asList("dog", "cat");
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@ -252,7 +264,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@ -264,7 +277,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
put("bar", null);
}};
probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@ -345,6 +359,55 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
assertThat(tree.estimatedNumOperations(), equalTo(7L));
}
public void testFeatureImportance() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.5)
.setNumberSamples(4L),
TreeNode.builder(1)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.5)
.setNumberSamples(2L),
TreeNode.builder(2)
.setSplitFeature(1)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.5)
.setNumberSamples(2L),
TreeNode.builder(3).setLeafValue(3.0).setNumberSamples(1L),
TreeNode.builder(4).setLeafValue(8.0).setNumberSamples(1L),
TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();
Map<String, Double> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.75)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.25)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.75)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
}
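The expected values in this test can be checked by hand: the tree is perfectly symmetric with one sample per leaf, so the exact two-feature Shapley decomposition applies directly. For the point (0.25, 0.25) the prediction is 3.0, the overall mean is (3 + 8 + 13 + 18) / 4 = 10.5, E[f | foo = 0.25] = 5.5, and E[f | bar = 0.25] = 8.0, giving phi(foo) = -5.0 and phi(bar) = -2.5, which sum with the mean back to the prediction. A small self-contained check under those assumptions (a sketch, not the production TreeShap code):
```java
// Hand-check of the (0.25, 0.25) case from testFeatureImportance above, using the
// exact two-feature Shapley decomposition.
public final class TwoFeatureShapleyCheck {
    public static void main(String[] args) {
        // Leaf values of the symmetric tree, one sample each:
        double fLL = 3.0;   // foo < 0.5, bar < 0.5   (the leaf hit by (0.25, 0.25))
        double fLR = 8.0;   // foo < 0.5, bar >= 0.5
        double fRL = 13.0;  // foo >= 0.5, bar < 0.5
        double fRR = 18.0;  // foo >= 0.5, bar >= 0.5

        double vBoth = fLL;                            // v({foo, bar}) = f(x) = 3.0
        double vFoo  = (fLL + fLR) / 2.0;              // E[f | foo = 0.25] = 5.5
        double vBar  = (fLL + fRL) / 2.0;              // E[f | bar = 0.25] = 8.0
        double vNone = (fLL + fLR + fRL + fRR) / 4.0;  // E[f] = 10.5

        double phiFoo = 0.5 * ((vFoo - vNone) + (vBoth - vBar)); // -5.0
        double phiBar = 0.5 * ((vBar - vNone) + (vBoth - vFoo)); // -2.5
        System.out.println("phi(foo)=" + phiFoo + " phi(bar)=" + phiBar);
    }
}
```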
public void testMaxFeatureIndex() {
int numFeatures = randomIntBetween(1, 15);

View File

@ -115,7 +115,10 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"inference\": {\n" +
" \"target_field\": \"ml.classification\",\n" +
" \"inference_config\": {\"classification\": " +
" {\"num_top_classes\":2, \"top_classes_results_field\": \"result_class_prob\"}},\n" +
" {\"num_top_classes\":2, " +
" \"top_classes_results_field\": \"result_class_prob\"," +
" \"num_top_feature_importance_values\": 2" +
" }},\n" +
" \"model_id\": \"test_classification\",\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
@ -153,6 +156,8 @@ public class InferenceIngestIT extends ESRestTestCase {
String responseString = EntityUtils.toString(response.getEntity());
assertThat(responseString, containsString("\"predicted_value\":\"second\""));
assertThat(responseString, containsString("\"predicted_value\":1.0"));
assertThat(responseString, containsString("\"col2\":0.944"));
assertThat(responseString, containsString("\"col1\":0.19999"));
String sourceWithMissingModel = "{\n" +
" \"pipeline\": {\n" +
@ -321,16 +326,19 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"number_samples\": 300,\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 100,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 200,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
@ -352,15 +360,18 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"number_samples\": 150,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 50,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 100,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
@ -445,6 +456,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"number_samples\": 100,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
@ -454,10 +466,12 @@ public class InferenceIngestIT extends ESRestTestCase {
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 80,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 20,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
@ -476,6 +490,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"number_samples\": 180,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
@ -484,10 +499,12 @@ public class InferenceIngestIT extends ESRestTestCase {
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 10,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 170,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +

View File

@ -102,6 +102,43 @@ public class InferenceProcessorTests extends ESTestCase {
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
}
public void testMutateDocumentClassificationFeatureInfluence() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"classification_model",
classificationConfig,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
Map<String, Double> featureInfluence = new HashMap<>();
featureInfluence.put("feature_1", 1.13);
featureInfluence.put("feature_2", -42.0);
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0,
"foo",
classes,
featureInfluence,
classificationConfig)),
true);
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
}
@SuppressWarnings("unchecked")
public void testMutateDocumentClassificationTopNClassesWithSpecificField() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops");
@ -154,6 +191,34 @@ public class InferenceProcessorTests extends ESTestCase {
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
}
public void testMutateDocumentRegressionWithTopFeatures() {
RegressionConfig regressionConfig = new RegressionConfig("foo", 2);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"regression_model",
regressionConfig,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
Map<String, Double> featureInfluence = new HashMap<>();
featureInfluence.put("feature_1", 1.13);
featureInfluence.put("feature_2", -42.0);
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
}
public void testGenerateRequestWithEmptyMapping() {
String modelId = "model";
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);