[ML] Adds feature importance option to inference processor (#52218) (#52666)

This adds machine learning model feature importance calculations to the inference processor.

The new flag in the configuration matches the data frame analytics parameter name: `num_top_feature_importance_values`
Example:
```
"inference": {
   "field_mappings": {},
   "model_id": "my_model",
   "inference_config": {
      "regression": {
         "num_top_feature_importance_values": 3
      }
   }
}
```

This will write to the document as follows:
```
"inference" : {
   "feature_importance" : {
      "FlightTimeMin" : -76.90955548511226,
      "FlightDelayType" : 114.13514762158526,
      "DistanceMiles" : 13.731580450792187
   },
   "predicted_value" : 108.33165831875137,
   "model_id" : "my_model"
}
```

This is done by calculating [SHAP values](https://arxiv.org/abs/1802.03888).
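
For intuition, SHAP assigns each field a signed contribution to an individual prediction, and those contributions decompose the prediction additively around a baseline (roughly the model's average prediction). Below is a minimal, self-contained sketch of that property using the numbers from the example document above; it is illustrative only, and the baseline is not part of the processor output:

```
import java.util.Map;

// Illustrative only: the numbers are copied from the example document above;
// the baseline is the standard SHAP notion and is not returned by the processor.
public class ShapAdditivitySketch {
    public static void main(String[] args) {
        Map<String, Double> featureImportance = Map.of(
            "FlightTimeMin", -76.90955548511226,
            "FlightDelayType", 114.13514762158526,
            "DistanceMiles", 13.731580450792187);

        // SHAP decomposes a single prediction additively:
        //   prediction ~= baseline + sum(per-feature contributions)
        double summed = featureImportance.values().stream()
            .mapToDouble(Double::doubleValue)
            .sum();
        System.out.printf("sum of reported contributions: %.3f%n", summed);

        // Only the top `num_top_feature_importance_values` entries are written back,
        // so smaller contributions may be missing from this sum.
    }
}
```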

It requires that models have `number_samples` populated for each tree node; this is not available for models created before 7.7.

Additionally, if the inference config requests feature importance and not all nodes in the cluster have been upgraded yet, pipeline creation is rejected. This safeguards against failures in a mixed-version environment where only some ingest nodes have been upgraded.
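
The gate keys off the config's minimal supported version, which the diff below raises to 7.7.0 whenever feature importance is requested. Here is a hedged sketch of the kind of check pipeline creation can make; the helper class and its name are hypothetical, and only the `InferenceConfig`, `Version`, and `ExceptionsHelper` types come from the codebase:

```
import org.elasticsearch.Version;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

// Hypothetical helper, not part of this PR: reject a config whose minimal supported
// version is newer than the oldest node currently in the cluster.
final class FeatureImportanceVersionGate {
    static void checkMinNodeVersion(InferenceConfig config, Version minNodeVersion) {
        if (minNodeVersion.before(config.getMinimalSupportedVersion())) {
            throw ExceptionsHelper.badRequestException(
                "Configuration [{}] requires a minimum node version of [{}], but the oldest node is [{}]",
                config.getName(),
                config.getMinimalSupportedVersion(),
                minNodeVersion);
        }
    }
}
```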

NOTE: the algorithm is a Java port of the one laid out in ml-cpp: https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc

Usability is blocked by: https://github.com/elastic/ml-cpp/pull/991
29 changed files with 980 additions and 104 deletions


@ -44,6 +44,12 @@ include::common-options.asciidoc[]
Specifies the field to which the inference prediction is written. Defaults to
`predicted_value`.

`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.

[discrete]
[[inference-processor-classification-opt]]

@ -63,6 +69,12 @@ Specifies the number of top class predictions to return. Defaults to 0.
Specifies the field to which the top classes are written. Defaults to
`top_classes`.

`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.

[discrete]
[[inference-processor-config-example]]


@ -32,6 +32,7 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -73,6 +74,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
private final TrainedModel trainedModel; private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors; private final List<PreProcessor> preProcessors;
private Map<String, String> decoderMap;
private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) { private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL); this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
@ -115,13 +117,35 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
return preProcessors; return preProcessors;
} }
private void preProcess(Map<String, Object> fields) { void preProcess(Map<String, Object> fields) {
preProcessors.forEach(preProcessor -> preProcessor.process(fields)); preProcessors.forEach(preProcessor -> preProcessor.process(fields));
} }
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) { public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields); preProcess(fields);
return trainedModel.infer(fields, config); if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
throw ExceptionsHelper.badRequestException(
"Feature importance is not supported for the configured model of type [{}]",
trainedModel.getName());
}
return trainedModel.infer(fields,
config,
config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
}
private Map<String, String> getDecoderMap() {
if (decoderMap != null) {
return decoderMap;
}
synchronized (this) {
if (decoderMap != null) {
return decoderMap;
}
this.decoderMap = preProcessors.stream()
.map(PreProcessor::reverseLookup)
.collect(HashMap::new, Map::putAll, Map::putAll);
return decoderMap;
}
} }
@Override @Override


@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembeddi
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -235,6 +236,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
fields.put(destField, concatEmbeddings(processedFeatures)); fields.put(destField, concatEmbeddings(processedFeatures));
} }
@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(destField, fieldName);
}
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
long size = SHALLOW_SIZE; long size = SHALLOW_SIZE;


@ -97,6 +97,11 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
return featureName; return featureName;
} }
@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(featureName, field);
}
@Override @Override
public String getName() { public String getName() {
return NAME.getPreferredName(); return NAME.getPreferredName();


@ -18,8 +18,10 @@ import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors;
/** /**
* PreProcessor for one hot encoding a set of categorical values for a given field. * PreProcessor for one hot encoding a set of categorical values for a given field.
@ -80,6 +82,11 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
return hotMap; return hotMap;
} }
@Override
public Map<String, String> reverseLookup() {
return hotMap.entrySet().stream().collect(Collectors.toMap(HashMap.Entry::getValue, (entry) -> field));
}
@Override @Override
public String getName() { public String getName() {
return NAME.getPreferredName(); return NAME.getPreferredName();


@ -24,4 +24,9 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
* @param fields The fields and their values to process * @param fields The fields and their values to process
*/ */
void process(Map<String, Object> fields); void process(Map<String, Object> fields);
/**
* @return Reverse lookup map to match resulting features to their original feature name
*/
Map<String, String> reverseLookup();
} }


@ -108,6 +108,11 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
return featureName; return featureName;
} }
@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(featureName, field);
}
@Override @Override
public String getName() { public String getName() {
return NAME.getPreferredName(); return NAME.getPreferredName();


@ -35,9 +35,25 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
String classificationLabel, String classificationLabel,
List<TopClassEntry> topClasses, List<TopClassEntry> topClasses,
InferenceConfig config) { InferenceConfig config) {
super(value); this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
assert config instanceof ClassificationConfig; }
ClassificationConfig classificationConfig = (ClassificationConfig)config;
public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
InferenceConfig config) {
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
}
private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
ClassificationConfig classificationConfig) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
classificationConfig.getNumTopFeatureImportanceValues()));
this.classificationLabel = classificationLabel; this.classificationLabel = classificationLabel;
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses); this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.topNumClassesField = classificationConfig.getTopClassesResultsField(); this.topNumClassesField = classificationConfig.getTopClassesResultsField();
@ -74,16 +90,17 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
if (object == this) { return true; } if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; } if (object == null || getClass() != object.getClass()) { return false; }
ClassificationInferenceResults that = (ClassificationInferenceResults) object; ClassificationInferenceResults that = (ClassificationInferenceResults) object;
return Objects.equals(value(), that.value()) && return Objects.equals(value(), that.value())
Objects.equals(classificationLabel, that.classificationLabel) && && Objects.equals(classificationLabel, that.classificationLabel)
Objects.equals(resultsField, that.resultsField) && && Objects.equals(resultsField, that.resultsField)
Objects.equals(topNumClassesField, that.topNumClassesField) && && Objects.equals(topNumClassesField, that.topNumClassesField)
Objects.equals(topClasses, that.topClasses); && Objects.equals(topClasses, that.topClasses)
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField); return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance());
} }
@Override @Override
@ -100,6 +117,9 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
document.setFieldValue(parentResultField + "." + topNumClassesField, document.setFieldValue(parentResultField + "." + topNumClassesField,
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())); topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
} }
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
} }
@Override @Override


@ -10,18 +10,19 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.IngestDocument;
import java.io.IOException; import java.io.IOException;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
public class RawInferenceResults extends SingleValueInferenceResults { public class RawInferenceResults extends SingleValueInferenceResults {
public static final String NAME = "raw"; public static final String NAME = "raw";
public RawInferenceResults(double value) { public RawInferenceResults(double value, Map<String, Double> featureImportance) {
super(value); super(value, featureImportance);
} }
public RawInferenceResults(StreamInput in) throws IOException { public RawInferenceResults(StreamInput in) throws IOException {
super(in.readDouble()); super(in);
} }
@Override @Override
@ -34,12 +35,13 @@ public class RawInferenceResults extends SingleValueInferenceResults {
if (object == this) { return true; } if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; } if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object; RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value()); return Objects.equals(value(), that.value())
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(value()); return Objects.hash(value(), getFeatureImportance());
} }
@Override @Override


@ -13,6 +13,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
public class RegressionInferenceResults extends SingleValueInferenceResults { public class RegressionInferenceResults extends SingleValueInferenceResults {
@ -22,14 +24,22 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField; private final String resultsField;
public RegressionInferenceResults(double value, InferenceConfig config) { public RegressionInferenceResults(double value, InferenceConfig config) {
super(value); this(value, (RegressionConfig) config, Collections.emptyMap());
assert config instanceof RegressionConfig; }
RegressionConfig regressionConfig = (RegressionConfig)config;
public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
this(value, (RegressionConfig)config, featureImportance);
}
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
this.resultsField = regressionConfig.getResultsField(); this.resultsField = regressionConfig.getResultsField();
} }
public RegressionInferenceResults(StreamInput in) throws IOException { public RegressionInferenceResults(StreamInput in) throws IOException {
super(in.readDouble()); super(in);
this.resultsField = in.readString(); this.resultsField = in.readString();
} }
@ -44,12 +54,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
if (object == this) { return true; } if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; } if (object == null || getClass() != object.getClass()) { return false; }
RegressionInferenceResults that = (RegressionInferenceResults) object; RegressionInferenceResults that = (RegressionInferenceResults) object;
return Objects.equals(value(), that.value()) && Objects.equals(this.resultsField, that.resultsField); return Objects.equals(value(), that.value())
&& Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.getFeatureImportance(), that.getFeatureImportance());
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(value(), resultsField); return Objects.hash(value(), resultsField, getFeatureImportance());
} }
@Override @Override
@ -57,6 +69,9 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
ExceptionsHelper.requireNonNull(document, "document"); ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField"); ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, value()); document.setFieldValue(parentResultField + "." + this.resultsField, value());
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
} }
@Override @Override


@ -5,27 +5,51 @@
*/ */
package org.elasticsearch.xpack.core.ml.inference.results; package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
public abstract class SingleValueInferenceResults implements InferenceResults { public abstract class SingleValueInferenceResults implements InferenceResults {
private final double value; private final double value;
private final Map<String, Double> featureImportance;
static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
return unsortedFeatureImportances.entrySet()
.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
.limit(numTopFeatures)
.collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
}
SingleValueInferenceResults(StreamInput in) throws IOException { SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble(); value = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
} else {
this.featureImportance = Collections.emptyMap();
}
} }
SingleValueInferenceResults(double value) { SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
this.value = value; this.value = value;
this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
} }
public Double value() { public Double value() {
return value; return value;
} }
public Map<String, Double> getFeatureImportance() {
return featureImportance;
}
public String valueAsString() { public String valueAsString() {
return String.valueOf(value); return String.valueOf(value);
} }
@ -33,6 +57,9 @@ public abstract class SingleValueInferenceResults implements InferenceResults {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value); out.writeDouble(value);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
}
} }
} }


@ -31,33 +31,39 @@ public class ClassificationConfig implements InferenceConfig {
public static final ParseField RESULTS_FIELD = new ParseField("results_field"); public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field"); public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0; private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD); public static ClassificationConfig EMPTY_PARAMS =
new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null);
private final int numTopClasses; private final int numTopClasses;
private final String topClassesResultsField; private final String topClassesResultsField;
private final String resultsField; private final String resultsField;
private final int numTopFeatureImportanceValues;
public static ClassificationConfig fromMap(Map<String, Object> map) { public static ClassificationConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map); Map<String, Object> options = new HashMap<>(map);
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName()); Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName()); String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName());
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
if (options.isEmpty() == false) { if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet()); throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
} }
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField); return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, featureImportance);
} }
private static final ConstructingObjectParser<ClassificationConfig, Void> PARSER = private static final ConstructingObjectParser<ClassificationConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig( new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig(
(Integer) args[0], (String) args[1], (String) args[2])); (Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3]));
static { static {
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD); PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
} }
public static ClassificationConfig fromXContent(XContentParser parser) { public static ClassificationConfig fromXContent(XContentParser parser) {
@ -65,19 +71,33 @@ public class ClassificationConfig implements InferenceConfig {
} }
public ClassificationConfig(Integer numTopClasses) { public ClassificationConfig(Integer numTopClasses) {
this(numTopClasses, null, null); this(numTopClasses, null, null, null);
} }
public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField) { public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField) {
this(numTopClasses, resultsField, topClassesResultsField, 0);
}
public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses; this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULTS_FIELD : topClassesResultsField; this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULTS_FIELD : topClassesResultsField;
this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField; this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
if (featureImportance != null && featureImportance < 0) {
throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() +
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = featureImportance == null ? 0 : featureImportance;
} }
public ClassificationConfig(StreamInput in) throws IOException { public ClassificationConfig(StreamInput in) throws IOException {
this.numTopClasses = in.readInt(); this.numTopClasses = in.readInt();
this.topClassesResultsField = in.readString(); this.topClassesResultsField = in.readString();
this.resultsField = in.readString(); this.resultsField = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.numTopFeatureImportanceValues = in.readVInt();
} else {
this.numTopFeatureImportanceValues = 0;
}
} }
public int getNumTopClasses() { public int getNumTopClasses() {
@ -92,11 +112,23 @@ public class ClassificationConfig implements InferenceConfig {
return resultsField; return resultsField;
} }
public int getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}
@Override
public boolean requestingImportance() {
return numTopFeatureImportanceValues > 0;
}
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeInt(numTopClasses); out.writeInt(numTopClasses);
out.writeString(topClassesResultsField); out.writeString(topClassesResultsField);
out.writeString(resultsField); out.writeString(resultsField);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeVInt(numTopFeatureImportanceValues);
}
} }
@Override @Override
@ -104,14 +136,15 @@ public class ClassificationConfig implements InferenceConfig {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
ClassificationConfig that = (ClassificationConfig) o; ClassificationConfig that = (ClassificationConfig) o;
return Objects.equals(numTopClasses, that.numTopClasses) && return Objects.equals(numTopClasses, that.numTopClasses)
Objects.equals(topClassesResultsField, that.topClassesResultsField) && && Objects.equals(topClassesResultsField, that.topClassesResultsField)
Objects.equals(resultsField, that.resultsField); && Objects.equals(resultsField, that.resultsField)
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(numTopClasses, topClassesResultsField, resultsField); return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues);
} }
@Override @Override
@ -122,6 +155,9 @@ public class ClassificationConfig implements InferenceConfig {
} }
builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField); builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField);
builder.field(RESULTS_FIELD.getPreferredName(), resultsField); builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
if (numTopFeatureImportanceValues > 0) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -143,7 +179,7 @@ public class ClassificationConfig implements InferenceConfig {
@Override @Override
public Version getMinimalSupportedVersion() { public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION; return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
} }
} }


@ -18,4 +18,8 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable {
* All nodes in the cluster must be at least this version * All nodes in the cluster must be at least this version
*/ */
Version getMinimalSupportedVersion(); Version getMinimalSupportedVersion();
default boolean requestingImportance() {
return false;
}
} }


@ -13,7 +13,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
@ -98,4 +100,19 @@ public final class InferenceHelpers {
} }
return null; return null;
} }
public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, Double> featureImportances) {
if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
return featureImportances;
}
Map<String, Double> originalFeatureImportance = new HashMap<>();
featureImportances.forEach((feature, importance) -> {
String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
});
return originalFeatureImportance;
}
} }


@ -16,9 +16,12 @@ import java.io.IOException;
*/ */
public class NullInferenceConfig implements InferenceConfig { public class NullInferenceConfig implements InferenceConfig {
public static final NullInferenceConfig INSTANCE = new NullInferenceConfig(); private final boolean requestingFeatureImportance;
private NullInferenceConfig() { }
public NullInferenceConfig(boolean requestingFeatureImportance) {
this.requestingFeatureImportance = requestingFeatureImportance;
}
@Override @Override
public boolean isTargetTypeSupported(TargetType targetType) { public boolean isTargetTypeSupported(TargetType targetType) {
@ -37,6 +40,7 @@ public class NullInferenceConfig implements InferenceConfig {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
throw new UnsupportedOperationException("Unable to serialize NullInferenceConfig objects");
} }
@Override @Override
@ -46,6 +50,11 @@ public class NullInferenceConfig implements InferenceConfig {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder; throw new UnsupportedOperationException("Unable to write xcontent from NullInferenceConfig objects");
}
@Override
public boolean requestingImportance() {
return requestingFeatureImportance;
} }
} }


@ -26,24 +26,27 @@ public class RegressionConfig implements InferenceConfig {
public static final ParseField NAME = new ParseField("regression"); public static final ParseField NAME = new ParseField("regression");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0; private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static final ParseField RESULTS_FIELD = new ParseField("results_field"); public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
private static final String DEFAULT_RESULTS_FIELD = "predicted_value"; private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD); public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null);
public static RegressionConfig fromMap(Map<String, Object> map) { public static RegressionConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map); Map<String, Object> options = new HashMap<>(map);
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
if (options.isEmpty() == false) { if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
} }
return new RegressionConfig(resultsField); return new RegressionConfig(resultsField, featureImportance);
} }
private static final ConstructingObjectParser<RegressionConfig, Void> PARSER = private static final ConstructingObjectParser<RegressionConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0])); new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0], (Integer)args[1]));
static { static {
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
} }
public static RegressionConfig fromXContent(XContentParser parser) { public static RegressionConfig fromXContent(XContentParser parser) {
@ -51,19 +54,43 @@ public class RegressionConfig implements InferenceConfig {
} }
private final String resultsField; private final String resultsField;
private final int numTopFeatureImportanceValues;
public RegressionConfig(String resultsField) { public RegressionConfig(String resultsField) {
this(resultsField, 0);
}
public RegressionConfig(String resultsField, Integer numTopFeatureImportanceValues) {
this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField; this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) {
throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() +
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues;
} }
public RegressionConfig(StreamInput in) throws IOException { public RegressionConfig(StreamInput in) throws IOException {
this.resultsField = in.readString(); this.resultsField = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.numTopFeatureImportanceValues = in.readVInt();
} else {
this.numTopFeatureImportanceValues = 0;
}
}
public int getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
} }
public String getResultsField() { public String getResultsField() {
return resultsField; return resultsField;
} }
@Override
public boolean requestingImportance() {
return numTopFeatureImportanceValues > 0;
}
@Override @Override
public String getWriteableName() { public String getWriteableName() {
return NAME.getPreferredName(); return NAME.getPreferredName();
@ -72,6 +99,9 @@ public class RegressionConfig implements InferenceConfig {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeString(resultsField); out.writeString(resultsField);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeVInt(numTopFeatureImportanceValues);
}
} }
@Override @Override
@ -83,6 +113,9 @@ public class RegressionConfig implements InferenceConfig {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(RESULTS_FIELD.getPreferredName(), resultsField); builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
if (numTopFeatureImportanceValues > 0) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -92,12 +125,13 @@ public class RegressionConfig implements InferenceConfig {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
RegressionConfig that = (RegressionConfig)o; RegressionConfig that = (RegressionConfig)o;
return Objects.equals(this.resultsField, that.resultsField); return Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(resultsField); return Objects.hash(resultsField, numTopFeatureImportanceValues);
} }
@Override @Override
@ -107,7 +141,7 @@ public class RegressionConfig implements InferenceConfig {
@Override @Override
public Version getMinimalSupportedVersion() { public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION; return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
} }
} }


@ -0,0 +1,162 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
/**
* Ported from https://github.com/elastic/ml-cpp/blob/master/include/maths/CTreeShapFeatureImportance.h Path struct
*/
public class ShapPath {
private static final double DBL_EPSILON = Double.MIN_VALUE;
private final PathElement[] pathElements;
private final double[] scale;
private final int elementAndScaleOffset;
public ShapPath(ShapPath parentPath, int nextIndex) {
this.elementAndScaleOffset = parentPath.elementAndScaleOffset + nextIndex;
this.pathElements = parentPath.pathElements;
this.scale = parentPath.scale;
for (int i = 0; i < nextIndex; i++) {
pathElements[elementAndScaleOffset + i].featureIndex = parentPath.getElement(i).featureIndex;
pathElements[elementAndScaleOffset + i].fractionZeros = parentPath.getElement(i).fractionZeros;
pathElements[elementAndScaleOffset + i].fractionOnes = parentPath.getElement(i).fractionOnes;
scale[elementAndScaleOffset + i] = parentPath.getScale(i);
}
}
public ShapPath(PathElement[] elements, double[] scale) {
this.pathElements = elements;
this.scale = scale;
this.elementAndScaleOffset = 0;
}
// Update binomial coefficients to be able to compute Equation (2) from the paper. In particular,
// we have in the line path.scale[i + 1] += fractionOne * path.scale[i] * (i + 1.0) / (pathDepth +
// 1.0) that if we're on the "one" path, i.e. if the last feature selects this path if we include that
// feature in S (then fractionOne is 1), and we need to consider all the additional ways we now have of
// constructing each S of each given cardinality i + 1. Each of these come by adding the last feature
// to sets of size i and we **also** need to scale by the difference in binomial coefficients as both M
// increases by one and i increases by one. So we get additive term 1{last feature selects path if in S}
// * scale(i) * (i+1)! (M+1-(i+1)-1)!/(M+1)! / (i! (M-i-1)!/ M!), whence += scale(i) * (i+1) / (M+1).
public int extend(double fractionZero, double fractionOne, int featureIndex, int nextIndex) {
setValues(nextIndex, fractionOne, fractionZero, featureIndex);
setScale(nextIndex, nextIndex == 0 ? 1.0 : 0.0);
double stepDown = fractionOne / (double)(nextIndex + 1);
double stepUp = fractionZero / (double)(nextIndex + 1);
double countDown = nextIndex * stepDown;
double countUp = stepUp;
for (int i = (nextIndex - 1); i >= 0; --i, countDown -= stepDown, countUp += stepUp) {
setScale(i + 1, getScale(i + 1) + getScale(i) * countDown);
setScale(i, getScale(i) * countUp);
}
return nextIndex + 1;
}
public double sumUnwoundPath(int pathIndex, int nextIndex) {
double total = 0.0;
int pathDepth = nextIndex - 1;
double nextFractionOne = getScale(pathDepth);
double fractionOne = fractionOnes(pathIndex);
double fractionZero = fractionZeros(pathIndex);
if (fractionOne != 0) {
double pD = pathDepth + 1;
double stepUp = fractionZero / pD;
double stepDown = fractionOne / pD;
double countUp = stepUp;
double countDown = (pD - 1.0) * stepDown;
for (int i = pathDepth - 1; i >= 0; --i, countUp += stepUp, countDown -= stepDown) {
double tmp = nextFractionOne / countDown;
nextFractionOne = getScale(i) - tmp * countUp;
total += tmp;
}
} else {
double pD = pathDepth;
for(int i = 0; i < pathDepth; i++) {
total += getScale(i) / pD--;
}
total *= (pathDepth + 1) / (fractionZero + DBL_EPSILON);
}
return total;
}
public int unwind(int pathIndex, int nextIndex) {
int pathDepth = nextIndex - 1;
double nextFractionOne = getScale(pathDepth);
double fractionOne = fractionOnes(pathIndex);
double fractionZero = fractionZeros(pathIndex);
if (fractionOne != 0) {
double stepUp = fractionZero / (double)(pathDepth + 1);
double stepDown = fractionOne / (double)nextIndex;
double countUp = 0.0;
double countDown = nextIndex * stepDown;
for (int i = pathDepth; i >= 0; --i, countUp += stepUp, countDown -= stepDown) {
double tmp = nextFractionOne / countDown;
nextFractionOne = getScale(i) - tmp * countUp;
setScale(i, tmp);
}
} else {
double stepDown = (fractionZero + DBL_EPSILON) / (double)(pathDepth + 1);
double countDown = pathDepth * stepDown;
for (int i = 0; i <= pathDepth; ++i, countDown -= stepDown) {
setScale(i, getScale(i) / countDown);
}
}
for (int i = pathIndex; i < pathDepth; ++i) {
PathElement element = getElement(i + 1);
setValues(i, element.fractionOnes, element.fractionZeros, element.featureIndex);
}
return nextIndex - 1;
}
private void setValues(int index, double fractionOnes, double fractionZeros, int featureIndex) {
pathElements[index + elementAndScaleOffset].fractionOnes = fractionOnes;
pathElements[index + elementAndScaleOffset].fractionZeros = fractionZeros;
pathElements[index + elementAndScaleOffset].featureIndex = featureIndex;
}
private double getScale(int offset) {
return scale[offset + elementAndScaleOffset];
}
private void setScale(int offset, double value) {
scale[offset + elementAndScaleOffset] = value;
}
public double fractionOnes(int pathIndex) {
return pathElements[pathIndex + elementAndScaleOffset].fractionOnes;
}
public double fractionZeros(int pathIndex) {
return pathElements[pathIndex + elementAndScaleOffset].fractionZeros;
}
public int findFeatureIndex(int splitFeature, int nextIndex) {
for (int i = elementAndScaleOffset; i < elementAndScaleOffset + nextIndex; i++) {
if (pathElements[i].featureIndex == splitFeature) {
return i - elementAndScaleOffset;
}
}
return -1;
}
public int featureIndex(int pathIndex) {
return pathElements[pathIndex + elementAndScaleOffset].featureIndex;
}
private PathElement getElement(int offset) {
return pathElements[offset + elementAndScaleOffset];
}
public static final class PathElement {
private double fractionOnes = 1.0;
private double fractionZeros = 1.0;
private int featureIndex = -1;
}
}


@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel; package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
@ -17,12 +18,16 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
/** /**
* Infer against the provided fields * Infer against the provided fields
* *
* NOTE: Must be thread safe
*
* @param fields The fields and their values to infer against * @param fields The fields and their values to infer against
* @param config The configuration options for inference * @param config The configuration options for inference
* @param featureDecoderMap A map for decoding feature value names to their originating feature.
* Necessary for feature influence.
* @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0). * @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0).
* For regression this is continuous. * For regression this is continuous.
*/ */
InferenceResults infer(Map<String, Object> fields, InferenceConfig config); InferenceResults infer(Map<String, Object> fields, InferenceConfig config, @Nullable Map<String, String> featureDecoderMap);
/** /**
* @return {@link TargetType} for the model. * @return {@link TargetType} for the model.
@ -42,4 +47,19 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
* @return The estimated number of operations required at inference time * @return The estimated number of operations required at inference time
*/ */
long estimatedNumOperations(); long estimatedNumOperations();
/**
* @return Does the model support feature importance
*/
boolean supportsFeatureImportance();
/**
* Calculates the importance of each feature reference by the model for the passed in field values
*
* NOTE: Must be thread safe
* @param fields The fields inferring against
* @param featureDecoder A Map translating processed feature names to their original feature names
* @return A {@code Map<String, Double>} mapping each featureName to its importance
*/
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
} }


@ -37,6 +37,7 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -133,18 +134,25 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
} }
@Override @Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) { public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) { if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException( throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
} }
List<Double> inferenceResults = this.models.stream().map(model -> { List<Double> inferenceResults = new ArrayList<>(this.models.size());
InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE); List<Map<String, Double>> featureInfluence = new ArrayList<>();
assert results instanceof SingleValueInferenceResults; NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
return ((SingleValueInferenceResults)results).value(); this.models.forEach(model -> {
}).collect(Collectors.toList()); InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
assert result instanceof SingleValueInferenceResults;
SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result;
inferenceResults.add(inferenceResult.value());
if (config.requestingImportance()) {
featureInfluence.add(inferenceResult.getFeatureImportance());
}
});
List<Double> processed = outputAggregator.processValues(inferenceResults); List<Double> processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed, config); return buildResults(processed, featureInfluence, config, featureDecoderMap);
} }
@Override @Override
@ -152,14 +160,20 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return targetType; return targetType;
} }
private InferenceResults buildResults(List<Double> processedInferences, InferenceConfig config) { private InferenceResults buildResults(List<Double> processedInferences,
List<Map<String, Double>> featureInfluence,
InferenceConfig config,
Map<String, String> featureDecoderMap) {
// Indicates that the config is useless and the caller just wants the raw value // Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) { if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(outputAggregator.aggregate(processedInferences)); return new RawInferenceResults(outputAggregator.aggregate(processedInferences),
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
} }
switch(targetType) { switch(targetType) {
case REGRESSION: case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), config); return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
config,
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
case CLASSIFICATION: case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config; ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.size() == classificationWeights.length; assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
@ -172,6 +186,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return new ClassificationInferenceResults((double)topClasses.v1(), return new ClassificationInferenceResults((double)topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels), classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(), topClasses.v2(),
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
config); config);
default: default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model"); throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
@ -293,6 +308,27 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1); return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1);
} }
@Override
public boolean supportsFeatureImportance() {
return models.stream().allMatch(TrainedModel::supportsFeatureImportance);
}
Map<String, Double> featureImportance(Map<String, Object> fields) {
return featureImportance(fields, Collections.emptyMap());
}
@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
Map<String, Double> collapsed = mergeFeatureImportances(models.stream()
.map(trainedModel -> trainedModel.featureImportance(fields, Collections.emptyMap()))
.collect(Collectors.toList()));
return InferenceHelpers.decodeFeatureImportances(featureDecoder, collapsed);
}
private static Map<String, Double> mergeFeatureImportances(List<Map<String, Double>> featureImportances) {
return featureImportances.stream().collect(HashMap::new, (a, b) -> b.forEach((k, v) -> a.merge(k, v, Double::sum)), Map::putAll);
}
public static Builder builder() { public static Builder builder() {
return new Builder(); return new Builder();
} }


@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -104,7 +105,11 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
} }
@Override @Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) { public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.requestingImportance()) {
throw ExceptionsHelper.badRequestException("[{}] model does not supports feature importance",
NAME.getPreferredName());
}
if (config instanceof ClassificationConfig == false) { if (config instanceof ClassificationConfig == false) {
throw ExceptionsHelper.badRequestException("[{}] model only supports classification", throw ExceptionsHelper.badRequestException("[{}] model only supports classification",
NAME.getPreferredName()); NAME.getPreferredName());
@ -138,6 +143,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
return new ClassificationInferenceResults(topClasses.v1(), return new ClassificationInferenceResults(topClasses.v1(),
LANGUAGE_NAMES.get(topClasses.v1()), LANGUAGE_NAMES.get(topClasses.v1()),
topClasses.v2(), topClasses.v2(),
Collections.emptyMap(),
classificationConfig); classificationConfig);
} }
@ -159,6 +165,16 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
return numOps; return numOps;
} }
@Override
public boolean supportsFeatureImportance() {
return false;
}
@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
throw new UnsupportedOperationException("[lang_ident] does not support feature importance");
}
@Override @Override
public long ramBytesUsed() { public long ramBytesUsed() {
long size = SHALLOW_SIZE; long size = SHALLOW_SIZE;


@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -44,6 +45,7 @@ import java.util.Objects;
import java.util.Queue; import java.util.Queue;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
@ -86,6 +88,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
private final TargetType targetType; private final TargetType targetType;
private final List<String> classificationLabels; private final List<String> classificationLabels;
private final CachedSupplier<Double> highestOrderCategory; private final CachedSupplier<Double> highestOrderCategory;
// populated lazily when feature importance is calculated
private double[] nodeEstimates;
private Integer maxDepth;
Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) { Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
@ -120,7 +125,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
} }
@Override @Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) { public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) { if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException( throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
@ -129,21 +134,23 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
List<Double> features = featureNames.stream() List<Double> features = featureNames.stream()
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields))) .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
.collect(Collectors.toList()); .collect(Collectors.toList());
return infer(features, config);
}
private InferenceResults infer(List<Double> features, InferenceConfig config) { Map<String, Double> featureImportance = config.requestingImportance() ?
featureImportance(features, featureDecoderMap) :
Collections.emptyMap();
TreeNode node = nodes.get(0); TreeNode node = nodes.get(0);
while(node.isLeaf() == false) { while(node.isLeaf() == false) {
node = nodes.get(node.compare(features)); node = nodes.get(node.compare(features));
} }
return buildResult(node.getLeafValue(), config);
return buildResult(node.getLeafValue(), featureImportance, config);
} }
private InferenceResults buildResult(Double value, InferenceConfig config) { private InferenceResults buildResult(Double value, Map<String, Double> featureImportance, InferenceConfig config) {
// Indicates that the config is useless and the caller just wants the raw value // Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) { if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(value); return new RawInferenceResults(value, featureImportance);
} }
switch (targetType) { switch (targetType) {
case CLASSIFICATION: case CLASSIFICATION:
@ -156,9 +163,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return new ClassificationInferenceResults(value, return new ClassificationInferenceResults(value,
classificationLabel(topClasses.v1(), classificationLabels), classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(), topClasses.v2(),
featureImportance,
config); config);
case REGRESSION: case REGRESSION:
return new RegressionInferenceResults(value, config); return new RegressionInferenceResults(value, config, featureImportance);
default: default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model"); throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
} }
@ -192,7 +200,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
// If we are classification, we should assume that the largest leaf value is whole. // If we are classification, we should assume that the largest leaf value is whole.
assert maxCategory == Math.rint(maxCategory); assert maxCategory == Math.rint(maxCategory);
List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)); List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
// TODO, eventually have TreeNodes contain confidence levels
list.set(Double.valueOf(inferenceValue).intValue(), 1.0); list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
return list; return list;
} }
@ -263,12 +270,138 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
detectCycle(); detectCycle();
} }
@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) {
throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance");
}
List<Double> features = featureNames.stream()
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
.collect(Collectors.toList());
return featureImportance(features, featureDecoder);
}
private Map<String, Double> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
calculateNodeEstimatesIfNeeded();
double[] featureImportance = new double[fieldValues.size()];
int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
for (int i = 0; i < arrSize; i++) {
elements[i] = new ShapPath.PathElement();
}
double[] scale = new double[arrSize];
ShapPath initialPath = new ShapPath(elements, scale);
shapRecursive(fieldValues, this.nodeEstimates, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
return InferenceHelpers.decodeFeatureImportances(featureDecoder,
IntStream.range(0, featureImportance.length)
.boxed()
.collect(Collectors.toMap(featureNames::get, i -> featureImportance[i])));
}
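
For readers tracing the buffer sizing above: `arrSize` is a triangular number because, as in the ml-cpp original, each level of the recursion appears to keep its own copy of the SHAP path (at most depth + 1 elements) at a fresh offset within the same array. A minimal sketch of that bound, with an illustrative helper name that is not part of this change:

```java
// Sketch only: capacity for the flattened SHAP path buffers.
// Levels 0..maxDepth each hold at most (level + 1) path elements,
// so the total is 1 + 2 + ... + (maxDepth + 1).
static int shapPathCapacity(int maxDepth) {
    return ((maxDepth + 1) * (maxDepth + 2)) / 2;
}
// e.g. shapPathCapacity(1) == 3 for a stump, shapPathCapacity(2) == 6.
```
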
private void calculateNodeEstimatesIfNeeded() {
if (this.nodeEstimates != null && this.maxDepth != null) {
return;
}
synchronized (this) {
if (this.nodeEstimates != null && this.maxDepth != null) {
return;
}
double[] estimates = new double[nodes.size()];
this.maxDepth = fillNodeEstimates(estimates, 0, 0);
this.nodeEstimates = estimates;
}
}
/**
* Note, this is a port from https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc
*
* If improvements in performance or accuracy are found, implement the changes on the native
* side first and then port them to the Java side.
*/
private void shapRecursive(List<Double> processedFeatures,
double[] nodeValues,
ShapPath parentSplitPath,
int nodeIndex,
double parentFractionZero,
double parentFractionOne,
int parentFeatureIndex,
double[] featureImportance,
int nextIndex) {
ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
TreeNode currNode = nodes.get(nodeIndex);
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
if (currNode.isLeaf()) {
// TODO multi-value????
double leafValue = nodeValues[nodeIndex];
for (int i = 1; i < nextIndex; ++i) {
double scale = splitPath.sumUnwoundPath(i, nextIndex);
int inputColumnIndex = splitPath.featureIndex(i);
featureImportance[inputColumnIndex] += scale * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i)) * leafValue;
}
} else {
int hotIndex = currNode.compare(processedFeatures);
int coldIndex = hotIndex == currNode.getLeftChild() ? currNode.getRightChild() : currNode.getLeftChild();
double incomingFractionZero = 1.0;
double incomingFractionOne = 1.0;
int splitFeature = currNode.getSplitFeature();
int pathIndex = splitPath.findFeatureIndex(splitFeature, nextIndex);
if (pathIndex > -1) {
incomingFractionZero = splitPath.fractionZeros(pathIndex);
incomingFractionOne = splitPath.fractionOnes(pathIndex);
nextIndex = splitPath.unwind(pathIndex, nextIndex);
}
double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples();
double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples();
shapRecursive(processedFeatures, nodeValues, splitPath,
hotIndex, incomingFractionZero * hotFractionZero,
incomingFractionOne, splitFeature, featureImportance, nextIndex);
shapRecursive(processedFeatures, nodeValues, splitPath,
coldIndex, incomingFractionZero * coldFractionZero,
0.0, splitFeature, featureImportance, nextIndex);
}
}
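
To make the hot/cold bookkeeping concrete, a small numeric sketch with made-up sample counts (nothing from this change):

```java
// Suppose the current internal node saw 10 training samples, the branch the
// document actually follows ("hot") saw 6 of them and the other ("cold") saw 4.
long parentSamples = 10, hotSamples = 6, coldSamples = 4;
double hotFractionZero  = hotSamples  / (double) parentSamples; // 0.6
double coldFractionZero = coldSamples / (double) parentSamples; // 0.4
// The "ones" fraction is passed through unchanged on the hot branch and set to
// 0.0 on the cold branch, because the document can only ever reach the hot child.
```
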
/**
* Recursively populates the provided {@code double[]} with the estimated value for each node.
*
* Used when calculating feature importance.
* @param nodeEstimates Array to update in place with the estimated node values
* @param nodeIndex Current node index
* @param depth Current depth
* @return The max depth of the subtree rooted at {@code nodeIndex}
*/
private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
TreeNode node = nodes.get(nodeIndex);
if (node.isLeaf()) {
nodeEstimates[nodeIndex] = node.getLeafValue();
return 0;
}
int depthLeft = fillNodeEstimates(nodeEstimates, node.getLeftChild(), depth + 1);
int depthRight = fillNodeEstimates(nodeEstimates, node.getRightChild(), depth + 1);
long leftWeight = nodes.get(node.getLeftChild()).getNumberSamples();
long rightWeight = nodes.get(node.getRightChild()).getNumberSamples();
long divisor = leftWeight + rightWeight;
double averageValue = divisor == 0 ?
0.0 :
(leftWeight * nodeEstimates[node.getLeftChild()] + rightWeight * nodeEstimates[node.getRightChild()]) / divisor;
nodeEstimates[nodeIndex] = averageValue;
return Math.max(depthLeft, depthRight) + 1;
}
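
A tiny worked example of the weighted average computed here, using made-up numbers:

```java
// Hypothetical stump: the root saw 100 samples, split into a left leaf with
// 80 samples (value 1.0) and a right leaf with 20 samples (value 0.0).
long leftWeight = 80, rightWeight = 20;
double leftValue = 1.0, rightValue = 0.0;
double rootEstimate =
    (leftWeight * leftValue + rightWeight * rightValue) / (double) (leftWeight + rightWeight); // 0.8
// For deeper trees the same sample-weighted mean is applied bottom-up, and a
// zero divisor falls back to 0.0, as in fillNodeEstimates above.
```
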
@Override @Override
public long estimatedNumOperations() { public long estimatedNumOperations() {
// Grabbing the features from the doc + the depth of the tree // Grabbing the features from the doc + the depth of the tree
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size(); return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
} }
@Override
public boolean supportsFeatureImportance() {
return true;
}
/** /**
* The highest index of a feature used in any of the nodes. * The highest index of a feature used in any of the nodes.
* If no nodes use a feature return -1. This can only happen * If no nodes use a feature return -1. This can only happen


@ -342,8 +342,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
}}; }};
assertThat( assertThat(
((ClassificationInferenceResults)definition.getTrainedModel() ((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(), .getClassificationLabel(),
equalTo("Iris-setosa")); equalTo("Iris-setosa"));
@ -354,8 +353,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
put("petal_width", 1.4); put("petal_width", 1.4);
}}; }};
assertThat( assertThat(
((ClassificationInferenceResults)definition.getTrainedModel() ((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(), .getClassificationLabel(),
equalTo("Iris-versicolor")); equalTo("Iris-versicolor"));
@ -366,10 +364,8 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
put("petal_width", 2.0); put("petal_width", 2.0);
}}; }};
assertThat( assertThat(
((ClassificationInferenceResults)definition.getTrainedModel() ((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(), .getClassificationLabel(),
equalTo("Iris-virginica")); equalTo("Iris-virginica"));
} }
} }


@ -8,10 +8,12 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.util.Collections;
public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> { public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {
public static RawInferenceResults createRandomResults() { public static RawInferenceResults createRandomResults() {
return new RawInferenceResults(randomDouble()); return new RawInferenceResults(randomDouble(), randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
} }
@Override @Override


@ -30,11 +30,12 @@ public class ClassificationConfigTests extends AbstractSerializingTestCase<Class
ClassificationConfig expected = ClassificationConfig.EMPTY_PARAMS; ClassificationConfig expected = ClassificationConfig.EMPTY_PARAMS;
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected)); assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));
expected = new ClassificationConfig(3, "foo", "bar"); expected = new ClassificationConfig(3, "foo", "bar", 2);
Map<String, Object> configMap = new HashMap<>(); Map<String, Object> configMap = new HashMap<>();
configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3); configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo"); configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo");
configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar"); configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar");
configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2);
assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected)); assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected));
} }


@ -24,9 +24,10 @@ public class RegressionConfigTests extends AbstractSerializingTestCase<Regressio
} }
public void testFromMap() { public void testFromMap() {
RegressionConfig expected = new RegressionConfig("foo"); RegressionConfig expected = new RegressionConfig("foo", 3);
Map<String, Object> config = new HashMap<String, Object>(){{ Map<String, Object> config = new HashMap<String, Object>(){{
put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo"); put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo");
put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3);
}}; }};
assertThat(RegressionConfig.fromMap(config), equalTo(expected)); assertThat(RegressionConfig.fromMap(config), equalTo(expected));
} }


@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.junit.Before; import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -39,6 +40,7 @@ import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> { public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
private final double eps = 1.0E-8;
private boolean lenient; private boolean lenient;
@ -267,7 +269,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> scores = Arrays.asList(0.230557435, 0.162032651); List<Double> scores = Arrays.asList(0.230557435, 0.162032651);
double eps = 0.000001; double eps = 0.000001;
List<ClassificationInferenceResults.TopClassEntry> probabilities = List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) { for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -278,7 +281,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
expected = Arrays.asList(0.310025518, 0.6899744811); expected = Arrays.asList(0.310025518, 0.6899744811);
scores = Arrays.asList(0.217017863, 0.2069923443); scores = Arrays.asList(0.217017863, 0.2069923443);
probabilities = probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) { for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -289,7 +293,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
expected = Arrays.asList(0.768524783, 0.231475216); expected = Arrays.asList(0.768524783, 0.231475216);
scores = Arrays.asList(0.230557435, 0.162032651); scores = Arrays.asList(0.230557435, 0.162032651);
probabilities = probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) { for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -303,7 +308,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
expected = Arrays.asList(0.6899744811, 0.3100255188); expected = Arrays.asList(0.6899744811, 0.3100255188);
scores = Arrays.asList(0.482982136, 0.0930076556); scores = Arrays.asList(0.482982136, 0.0930076556);
probabilities = probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) { for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -361,24 +367,28 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> featureVector = Arrays.asList(0.4, 0.0); List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0, assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7); featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector); featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0, assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(0.0, 1.0); featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector); featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0, assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureMap = new HashMap<String, Object>(2) {{ featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3); put("foo", 0.3);
put("bar", null); put("bar", null);
}}; }};
assertThat(0.0, assertThat(0.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
} }
public void testMultiClassClassificationInference() { public void testMultiClassClassificationInference() {
@ -432,24 +442,28 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> featureVector = Arrays.asList(0.4, 0.0); List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(2.0, assertThat(2.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7); featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector); featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0, assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(0.0, 1.0); featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector); featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0, assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureMap = new HashMap<String, Object>(2) {{ featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.6); put("foo", 0.6);
put("bar", null); put("bar", null);
}}; }};
assertThat(1.0, assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
} }
public void testRegressionInference() { public void testRegressionInference() {
@ -489,12 +503,16 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> featureVector = Arrays.asList(0.4, 0.0); List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.9, assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7); featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector); featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.5, assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
// Test with NO aggregator supplied, verifies default behavior of non-weighted sum // Test with NO aggregator supplied, verifies default behavior of non-weighted sum
ensemble = Ensemble.builder() ensemble = Ensemble.builder()
@ -506,19 +524,25 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
featureVector = Arrays.asList(0.4, 0.0); featureVector = Arrays.asList(0.4, 0.0);
featureMap = zipObjMap(featureNames, featureVector); featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.8, assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7); featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector); featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0, assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureMap = new HashMap<String, Object>(2) {{ featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3); put("foo", 0.3);
put("bar", null); put("bar", null);
}}; }};
assertThat(1.8, assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
} }
public void testInferNestedFields() { public void testInferNestedFields() {
@ -564,7 +588,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
}}); }});
}}; }};
assertThat(0.9, assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureMap = new HashMap<String, Object>() {{ featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{ put("foo", new HashMap<String, Object>(){{
@ -575,7 +601,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
}}); }});
}}; }};
assertThat(0.5, assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
} }
public void testOperationsEstimations() { public void testOperationsEstimations() {
@ -590,6 +618,114 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
assertThat(ensemble.estimatedNumOperations(), equalTo(9L)); assertThat(ensemble.estimatedNumOperations(), equalTo(9L));
} }
public void testFeatureImportance() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.55)
.setNumberSamples(10L),
TreeNode.builder(1)
.setSplitFeature(0)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.41)
.setNumberSamples(6L),
TreeNode.builder(2)
.setSplitFeature(1)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.25)
.setNumberSamples(4L),
TreeNode.builder(3).setLeafValue(1.18230136).setNumberSamples(5L),
TreeNode.builder(4).setLeafValue(1.98006658).setNumberSamples(1L),
TreeNode.builder(5).setLeafValue(3.25350885).setNumberSamples(3L),
TreeNode.builder(6).setLeafValue(2.42384369).setNumberSamples(1L)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.45)
.setNumberSamples(10L),
TreeNode.builder(1)
.setSplitFeature(0)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.25)
.setNumberSamples(5L),
TreeNode.builder(2)
.setSplitFeature(0)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.59)
.setNumberSamples(5L),
TreeNode.builder(3).setLeafValue(1.04476388).setNumberSamples(3L),
TreeNode.builder(4).setLeafValue(1.52799228).setNumberSamples(2L),
TreeNode.builder(5).setLeafValue(1.98006658).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(2.950216).setNumberSamples(4L)).build();
Ensemble ensemble = Ensemble.builder().setOutputAggregator(new WeightedSum())
.setTrainedModels(Arrays.asList(tree1, tree2))
.setFeatureNames(featureNames)
.build();
Map<String, Double> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.1, 0.8)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.2, 0.7)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.3, 0.6)));
assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.4, 0.5)));
assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.5, 0.4)));
assertThat(featureImportance.get("foo"), closeTo(0.0798679, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.6, 0.3)));
assertThat(featureImportance.get("foo"), closeTo(1.80491886, eps));
assertThat(featureImportance.get("bar"), closeTo(-0.4355742, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.7, 0.2)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.8, 0.1)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.9, 0.0)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
}
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) { private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
} }


@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceRes
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.junit.Before; import org.junit.Before;
import java.io.IOException; import java.io.IOException;
@ -35,6 +36,7 @@ import static org.hamcrest.Matchers.equalTo;
public class TreeTests extends AbstractSerializingTestCase<Tree> { public class TreeTests extends AbstractSerializingTestCase<Tree> {
private final double eps = 1.0E-8;
private boolean lenient; private boolean lenient;
@Before @Before
@ -118,7 +120,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
List<Double> featureVector = Arrays.asList(0.6, 0.0); List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump
assertThat(42.0, assertThat(42.0,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
} }
public void testInfer() { public void testInfer() {
@ -138,27 +141,31 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
List<Double> featureVector = Arrays.asList(0.6, 0.0); List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.3, assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the left child of the left child of the root node // This should hit the left child of the left child of the root node
// i.e. it takes the path left, left // i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7); featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector); featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.1, assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the right child of the left child of the root node // This should hit the right child of the left child of the root node
// i.e. it takes the path left, right // i.e. it takes the path left, right
featureVector = Arrays.asList(0.3, 0.9); featureVector = Arrays.asList(0.3, 0.9);
featureMap = zipObjMap(featureNames, featureVector); featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.2, assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should still work if the internal values are strings // This should still work if the internal values are strings
List<String> featureVectorStrings = Arrays.asList("0.3", "0.9"); List<String> featureVectorStrings = Arrays.asList("0.3", "0.9");
featureMap = zipObjMap(featureNames, featureVectorStrings); featureMap = zipObjMap(featureNames, featureVectorStrings);
assertThat(0.2, assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should handle missing values and take the default_left path // This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{ featureMap = new HashMap<String, Object>(2) {{
@ -166,7 +173,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
put("bar", null); put("bar", null);
}}; }};
assertThat(0.1, assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
} }
public void testInferNestedFields() { public void testInferNestedFields() {
@ -192,7 +200,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}}); }});
}}; }};
assertThat(0.3, assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the left child of the left child of the root node // This should hit the left child of the left child of the root node
// i.e. it takes the path left, left // i.e. it takes the path left, left
@ -205,7 +214,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}}); }});
}}; }};
assertThat(0.1, assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the right child of the left child of the root node // This should hit the right child of the left child of the root node
// i.e. it takes the path left, right // i.e. it takes the path left, right
@ -218,7 +228,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}}); }});
}}; }};
assertThat(0.2, assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
} }
public void testTreeClassificationProbability() { public void testTreeClassificationProbability() {
@ -241,7 +252,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
List<String> expectedFields = Arrays.asList("dog", "cat"); List<String> expectedFields = Arrays.asList("dog", "cat");
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
List<ClassificationInferenceResults.TopClassEntry> probabilities = List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) { for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@ -252,7 +264,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
featureVector = Arrays.asList(0.3, 0.7); featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector); featureMap = zipObjMap(featureNames, featureVector);
probabilities = probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) { for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@ -264,7 +277,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
put("bar", null); put("bar", null);
}}; }};
probabilities = probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) { for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@ -345,6 +359,55 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
assertThat(tree.estimatedNumOperations(), equalTo(7L)); assertThat(tree.estimatedNumOperations(), equalTo(7L));
} }
public void testFeatureImportance() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.5)
.setNumberSamples(4L),
TreeNode.builder(1)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.5)
.setNumberSamples(2L),
TreeNode.builder(2)
.setSplitFeature(1)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.5)
.setNumberSamples(2L),
TreeNode.builder(3).setLeafValue(3.0).setNumberSamples(1L),
TreeNode.builder(4).setLeafValue(8.0).setNumberSamples(1L),
TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();
Map<String, Double> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.75)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.25)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.75)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
}
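
As a rough sanity check on the first pair of assertions above: because this tree is perfectly balanced and every split sees equal sample counts, the SHAP values reduce to simple differences of subtree means (a simplification used here only to aid the reader):

```java
double treeMean        = (3.0 + 8.0 + 13.0 + 18.0) / 4; // 10.5
double leftSubtreeMean = (3.0 + 8.0) / 2;               //  5.5, the foo < 0.5 branch
double fooImportance   = leftSubtreeMean - treeMean;    // -5.0
double barImportance   = 3.0 - leftSubtreeMean;         // -2.5, the bar < 0.5 leaf in that subtree
// treeMean + fooImportance + barImportance == 3.0, the leaf value the sample
// (0.25, 0.25) lands on, matching the expected -5.0 / -2.5 above.
```
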
public void testMaxFeatureIndex() { public void testMaxFeatureIndex() {
int numFeatures = randomIntBetween(1, 15); int numFeatures = randomIntBetween(1, 15);


@ -115,7 +115,10 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"inference\": {\n" + " \"inference\": {\n" +
" \"target_field\": \"ml.classification\",\n" + " \"target_field\": \"ml.classification\",\n" +
" \"inference_config\": {\"classification\": " + " \"inference_config\": {\"classification\": " +
" {\"num_top_classes\":2, \"top_classes_results_field\": \"result_class_prob\"}},\n" + " {\"num_top_classes\":2, " +
" \"top_classes_results_field\": \"result_class_prob\"," +
" \"num_top_feature_importance_values\": 2" +
" }},\n" +
" \"model_id\": \"test_classification\",\n" + " \"model_id\": \"test_classification\",\n" +
" \"field_mappings\": {\n" + " \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" + " \"col1\": \"col1\",\n" +
@ -153,6 +156,8 @@ public class InferenceIngestIT extends ESRestTestCase {
String responseString = EntityUtils.toString(response.getEntity()); String responseString = EntityUtils.toString(response.getEntity());
assertThat(responseString, containsString("\"predicted_value\":\"second\"")); assertThat(responseString, containsString("\"predicted_value\":\"second\""));
assertThat(responseString, containsString("\"predicted_value\":1.0")); assertThat(responseString, containsString("\"predicted_value\":1.0"));
assertThat(responseString, containsString("\"col2\":0.944"));
assertThat(responseString, containsString("\"col1\":0.19999"));
String sourceWithMissingModel = "{\n" + String sourceWithMissingModel = "{\n" +
" \"pipeline\": {\n" + " \"pipeline\": {\n" +
@ -321,16 +326,19 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"split_gain\": 12.0,\n" + " \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" + " \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" + " \"decision_type\": \"lte\",\n" +
" \"number_samples\": 300,\n" +
" \"default_left\": true,\n" + " \"default_left\": true,\n" +
" \"left_child\": 1,\n" + " \"left_child\": 1,\n" +
" \"right_child\": 2\n" + " \"right_child\": 2\n" +
" },\n" + " },\n" +
" {\n" + " {\n" +
" \"node_index\": 1,\n" + " \"node_index\": 1,\n" +
" \"number_samples\": 100,\n" +
" \"leaf_value\": 1\n" + " \"leaf_value\": 1\n" +
" },\n" + " },\n" +
" {\n" + " {\n" +
" \"node_index\": 2,\n" + " \"node_index\": 2,\n" +
" \"number_samples\": 200,\n" +
" \"leaf_value\": 2\n" + " \"leaf_value\": 2\n" +
" }\n" + " }\n" +
" ],\n" + " ],\n" +
@ -352,15 +360,18 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"threshold\": 10.0,\n" + " \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" + " \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" + " \"default_left\": true,\n" +
" \"number_samples\": 150,\n" +
" \"left_child\": 1,\n" + " \"left_child\": 1,\n" +
" \"right_child\": 2\n" + " \"right_child\": 2\n" +
" },\n" + " },\n" +
" {\n" + " {\n" +
" \"node_index\": 1,\n" + " \"node_index\": 1,\n" +
" \"number_samples\": 50,\n" +
" \"leaf_value\": 1\n" + " \"leaf_value\": 1\n" +
" },\n" + " },\n" +
" {\n" + " {\n" +
" \"node_index\": 2,\n" + " \"node_index\": 2,\n" +
" \"number_samples\": 100,\n" +
" \"leaf_value\": 2\n" + " \"leaf_value\": 2\n" +
" }\n" + " }\n" +
" ],\n" + " ],\n" +
@ -445,6 +456,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" {\n" + " {\n" +
" \"node_index\": 0,\n" + " \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" + " \"split_feature\": 0,\n" +
" \"number_samples\": 100,\n" +
" \"split_gain\": 12.0,\n" + " \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" + " \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" + " \"decision_type\": \"lte\",\n" +
@ -454,10 +466,12 @@ public class InferenceIngestIT extends ESRestTestCase {
" },\n" + " },\n" +
" {\n" + " {\n" +
" \"node_index\": 1,\n" + " \"node_index\": 1,\n" +
" \"number_samples\": 80,\n" +
" \"leaf_value\": 1\n" + " \"leaf_value\": 1\n" +
" },\n" + " },\n" +
" {\n" + " {\n" +
" \"node_index\": 2,\n" + " \"node_index\": 2,\n" +
" \"number_samples\": 20,\n" +
" \"leaf_value\": 0\n" + " \"leaf_value\": 0\n" +
" }\n" + " }\n" +
" ],\n" + " ],\n" +
@ -476,6 +490,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"node_index\": 0,\n" + " \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" + " \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" + " \"split_gain\": 12.0,\n" +
" \"number_samples\": 180,\n" +
" \"threshold\": 10.0,\n" + " \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" + " \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" + " \"default_left\": true,\n" +
@ -484,10 +499,12 @@ public class InferenceIngestIT extends ESRestTestCase {
" },\n" + " },\n" +
" {\n" + " {\n" +
" \"node_index\": 1,\n" + " \"node_index\": 1,\n" +
" \"number_samples\": 10,\n" +
" \"leaf_value\": 1\n" + " \"leaf_value\": 1\n" +
" },\n" + " },\n" +
" {\n" + " {\n" +
" \"node_index\": 2,\n" + " \"node_index\": 2,\n" +
" \"number_samples\": 170,\n" +
" \"leaf_value\": 0\n" + " \"leaf_value\": 0\n" +
" }\n" + " }\n" +
" ],\n" + " ],\n" +


@ -102,6 +102,43 @@ public class InferenceProcessorTests extends ESTestCase {
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo")); assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
} }
public void testMutateDocumentClassificationFeatureInfluence() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"classification_model",
classificationConfig,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
Map<String, Double> featureInfluence = new HashMap<>();
featureInfluence.put("feature_1", 1.13);
featureInfluence.put("feature_2", -42.0);
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0,
"foo",
classes,
featureInfluence,
classificationConfig)),
true);
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testMutateDocumentClassificationTopNClassesWithSpecificField() { public void testMutateDocumentClassificationTopNClassesWithSpecificField() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops"); ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops");
@ -154,6 +191,34 @@ public class InferenceProcessorTests extends ESTestCase {
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model")); assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
} }
public void testMutateDocumentRegressionWithTopFeatures() {
RegressionConfig regressionConfig = new RegressionConfig("foo", 2);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"regression_model",
regressionConfig,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
Map<String, Double> featureInfluence = new HashMap<>();
featureInfluence.put("feature_1", 1.13);
featureInfluence.put("feature_2", -42.0);
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
}
public void testGenerateRequestWithEmptyMapping() { public void testGenerateRequestWithEmptyMapping() {
String modelId = "model"; String modelId = "model";
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);