This adds machine learning model feature importance calculations to the inference processor.

The new flag in the configuration matches the analytics parameter name: `num_top_feature_importance_values`

Example:
```
"inference": {
   "field_mappings": {},
   "model_id": "my_model",
   "inference_config": {
      "regression": {
         "num_top_feature_importance_values": 3
      }
   }
}
```

This will write to the document as follows:
```
"inference" : {
   "feature_importance" : {
      "FlightTimeMin" : -76.90955548511226,
      "FlightDelayType" : 114.13514762158526,
      "DistanceMiles" : 13.731580450792187
   },
   "predicted_value" : 108.33165831875137,
   "model_id" : "my_model"
}
```

This is done through calculating the [SHAP values](https://arxiv.org/abs/1802.03888). It requires that models have populated `number_samples` for each tree node. This is not available to models that were created before 7.7.

Additionally, if the inference config requests feature importance and not all nodes have been upgraded yet, the pipeline cannot be created. This is a safeguard for a mixed-version environment where only some ingest nodes have been upgraded.

NOTE: the algorithm is a Java port of the one laid out in ml-cpp: https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc

usability blocked by: https://github.com/elastic/ml-cpp/pull/991
This commit is contained in: parent f06d692706, commit afd90647c9
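For orientation before the diff: a minimal, self-contained sketch of how an importance map is trimmed to the configured `num_top_feature_importance_values` by absolute magnitude, mirroring the `takeTopFeatureImportances` helper added below. The class and method names in this sketch are illustrative only and not part of the change.

```
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.Map;

public final class TopFeatureImportanceSketch {

    // Keep the N entries with the largest |importance|, in descending order of magnitude.
    static Map<String, Double> topByAbsoluteValue(Map<String, Double> importances, int numTopFeatures) {
        return importances.entrySet().stream()
            .sorted(Comparator.comparingDouble((Map.Entry<String, Double> e) -> Math.abs(e.getValue())).reversed())
            .limit(numTopFeatures)
            .collect(LinkedHashMap::new, (m, e) -> m.put(e.getKey(), e.getValue()), LinkedHashMap::putAll);
    }

    public static void main(String[] args) {
        Map<String, Double> all = Map.of(
            "FlightDelayType", 114.13514762158526,
            "FlightTimeMin", -76.90955548511226,
            "DistanceMiles", 13.731580450792187,
            "Carrier", 0.5);
        // With num_top_feature_importance_values = 3 the small-magnitude entry is dropped,
        // while the large negative contribution is kept because ranking uses absolute value.
        System.out.println(topByAbsoluteValue(all, 3));
    }
}
```

Ranking by absolute value is why the strongly negative `FlightTimeMin` contribution in the example document above is retained alongside the positive ones.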
@@ -44,6 +44,12 @@ include::common-options.asciidoc[]
Specifies the field to which the inference prediction is written. Defaults to
`predicted_value`.

`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.

[discrete]
[[inference-processor-classification-opt]]
@@ -63,6 +69,12 @@ Specifies the number of top class predictions to return. Defaults to 0.
Specifies the field to which the top classes are written. Defaults to
`top_classes`.

`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.

[discrete]
[[inference-processor-config-example]]

@@ -32,6 +32,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -73,6 +74,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco

private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private Map<String, String> decoderMap;

private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
@@ -115,13 +117,35 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
return preProcessors;
}

private void preProcess(Map<String, Object> fields) {
void preProcess(Map<String, Object> fields) {
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
}

public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields);
return trainedModel.infer(fields, config);
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
throw ExceptionsHelper.badRequestException(
"Feature importance is not supported for the configured model of type [{}]",
trainedModel.getName());
}
return trainedModel.infer(fields,
config,
config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
}

private Map<String, String> getDecoderMap() {
if (decoderMap != null) {
return decoderMap;
}
synchronized (this) {
if (decoderMap != null) {
return decoderMap;
}
this.decoderMap = preProcessors.stream()
.map(PreProcessor::reverseLookup)
.collect(HashMap::new, Map::putAll, Map::putAll);
return decoderMap;
}
}

@Override

@@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembeddi
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -235,6 +236,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
fields.put(destField, concatEmbeddings(processedFeatures));
}

@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(destField, fieldName);
}

@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;

@@ -97,6 +97,11 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
return featureName;
}

@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(featureName, field);
}

@Override
public String getName() {
return NAME.getPreferredName();

@@ -18,8 +18,10 @@ import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * PreProcessor for one hot encoding a set of categorical values for a given field.
@@ -80,6 +82,11 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
return hotMap;
}

@Override
public Map<String, String> reverseLookup() {
return hotMap.entrySet().stream().collect(Collectors.toMap(HashMap.Entry::getValue, (entry) -> field));
}

@Override
public String getName() {
return NAME.getPreferredName();

@@ -24,4 +24,9 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
 * @param fields The fields and their values to process
 */
void process(Map<String, Object> fields);

/**
 * @return Reverse lookup map to match resulting features to their original feature name
 */
Map<String, String> reverseLookup();
}

@@ -108,6 +108,11 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
return featureName;
}

@Override
public Map<String, String> reverseLookup() {
return Collections.singletonMap(featureName, field);
}

@Override
public String getName() {
return NAME.getPreferredName();

@@ -35,9 +35,25 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
String classificationLabel,
List<TopClassEntry> topClasses,
InferenceConfig config) {
super(value);
assert config instanceof ClassificationConfig;
ClassificationConfig classificationConfig = (ClassificationConfig)config;
this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
}

public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
InferenceConfig config) {
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
}

private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
ClassificationConfig classificationConfig) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
classificationConfig.getNumTopFeatureImportanceValues()));
this.classificationLabel = classificationLabel;
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
@@ -74,16 +90,17 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
ClassificationInferenceResults that = (ClassificationInferenceResults) object;
return Objects.equals(value(), that.value()) &&
Objects.equals(classificationLabel, that.classificationLabel) &&
Objects.equals(resultsField, that.resultsField) &&
Objects.equals(topNumClassesField, that.topNumClassesField) &&
Objects.equals(topClasses, that.topClasses);
return Objects.equals(value(), that.value())
&& Objects.equals(classificationLabel, that.classificationLabel)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(topNumClassesField, that.topNumClassesField)
&& Objects.equals(topClasses, that.topClasses)
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField);
return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance());
}

@Override
@@ -100,6 +117,9 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
document.setFieldValue(parentResultField + "." + topNumClassesField,
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
}

@Override

@@ -10,18 +10,19 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.ingest.IngestDocument;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class RawInferenceResults extends SingleValueInferenceResults {

public static final String NAME = "raw";

public RawInferenceResults(double value) {
super(value);
public RawInferenceResults(double value, Map<String, Double> featureImportance) {
super(value, featureImportance);
}

public RawInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
super(in);
}

@Override
@@ -34,12 +35,13 @@ public class RawInferenceResults extends SingleValueInferenceResults {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value());
return Objects.equals(value(), that.value())
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value());
return Objects.hash(value(), getFeatureImportance());
}

@Override

@@ -13,6 +13,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

public class RegressionInferenceResults extends SingleValueInferenceResults {
@@ -22,14 +24,22 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField;

public RegressionInferenceResults(double value, InferenceConfig config) {
super(value);
assert config instanceof RegressionConfig;
RegressionConfig regressionConfig = (RegressionConfig)config;
this(value, (RegressionConfig) config, Collections.emptyMap());
}

public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
this(value, (RegressionConfig)config, featureImportance);
}

private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
this.resultsField = regressionConfig.getResultsField();
}

public RegressionInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
super(in);
this.resultsField = in.readString();
}

@@ -44,12 +54,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RegressionInferenceResults that = (RegressionInferenceResults) object;
return Objects.equals(value(), that.value()) && Objects.equals(this.resultsField, that.resultsField);
return Objects.equals(value(), that.value())
&& Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.getFeatureImportance(), that.getFeatureImportance());
}

@Override
public int hashCode() {
return Objects.hash(value(), resultsField);
return Objects.hash(value(), resultsField, getFeatureImportance());
}

@Override
@@ -57,6 +69,9 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, value());
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
}
}

@Override

@@ -5,27 +5,51 @@
 */
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

public abstract class SingleValueInferenceResults implements InferenceResults {

private final double value;
private final Map<String, Double> featureImportance;

static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
return unsortedFeatureImportances.entrySet()
.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
.limit(numTopFeatures)
.collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
}

SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
} else {
this.featureImportance = Collections.emptyMap();
}
}

SingleValueInferenceResults(double value) {
SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
this.value = value;
this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
}

public Double value() {
return value;
}

public Map<String, Double> getFeatureImportance() {
return featureImportance;
}

public String valueAsString() {
return String.valueOf(value);
}
@@ -33,6 +57,9 @@ public abstract class SingleValueInferenceResults implements InferenceResults {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
}
}

}

@@ -31,33 +31,39 @@ public class ClassificationConfig implements InferenceConfig {
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;

public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD);
public static ClassificationConfig EMPTY_PARAMS =
new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null);

private final int numTopClasses;
private final String topClassesResultsField;
private final String resultsField;
private final int numTopFeatureImportanceValues;

public static ClassificationConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName());
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());

if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField);
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, featureImportance);
}

private static final ConstructingObjectParser<ClassificationConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig(
(Integer) args[0], (String) args[1], (String) args[2]));
(Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3]));

static {
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
}

public static ClassificationConfig fromXContent(XContentParser parser) {
@@ -65,19 +71,33 @@ public class ClassificationConfig implements InferenceConfig {
}

public ClassificationConfig(Integer numTopClasses) {
this(numTopClasses, null, null);
this(numTopClasses, null, null, null);
}

public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField) {
this(numTopClasses, resultsField, topClassesResultsField, 0);
}

public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULTS_FIELD : topClassesResultsField;
this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
if (featureImportance != null && featureImportance < 0) {
throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() +
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = featureImportance == null ? 0 : featureImportance;
}

public ClassificationConfig(StreamInput in) throws IOException {
this.numTopClasses = in.readInt();
this.topClassesResultsField = in.readString();
this.resultsField = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.numTopFeatureImportanceValues = in.readVInt();
} else {
this.numTopFeatureImportanceValues = 0;
}
}

public int getNumTopClasses() {
@@ -92,11 +112,23 @@ public class ClassificationConfig implements InferenceConfig {
return resultsField;
}

public int getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}

@Override
public boolean requestingImportance() {
return numTopFeatureImportanceValues > 0;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(numTopClasses);
out.writeString(topClassesResultsField);
out.writeString(resultsField);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeVInt(numTopFeatureImportanceValues);
}
}

@Override
@@ -104,14 +136,15 @@ public class ClassificationConfig implements InferenceConfig {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassificationConfig that = (ClassificationConfig) o;
return Objects.equals(numTopClasses, that.numTopClasses) &&
Objects.equals(topClassesResultsField, that.topClassesResultsField) &&
Objects.equals(resultsField, that.resultsField);
return Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(topClassesResultsField, that.topClassesResultsField)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
}

@Override
public int hashCode() {
return Objects.hash(numTopClasses, topClassesResultsField, resultsField);
return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues);
}

@Override
@@ -122,6 +155,9 @@ public class ClassificationConfig implements InferenceConfig {
}
builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField);
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
if (numTopFeatureImportanceValues > 0) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.endObject();
return builder;
}
@@ -143,7 +179,7 @@ public class ClassificationConfig implements InferenceConfig {

@Override
public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION;
return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
}

}

@@ -18,4 +18,8 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable {
 * All nodes in the cluster must be at least this version
 */
Version getMinimalSupportedVersion();

default boolean requestingImportance() {
return false;
}
}

@@ -13,7 +13,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

@@ -98,4 +100,19 @@ public final class InferenceHelpers {
}
return null;
}

public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, Double> featureImportances) {
if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
return featureImportances;
}

Map<String, Double> originalFeatureImportance = new HashMap<>();
featureImportances.forEach((feature, importance) -> {
String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
});

return originalFeatureImportance;
}
}

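To make the decoding step concrete, here is a small, self-contained usage sketch of the same merge-by-original-field logic as `decodeFeatureImportances` above. The derived column names (`FlightDelayType_Weather`, `FlightDelayType_Carrier`) are hypothetical one-hot features invented for the example; only the helper's behaviour is taken from the diff.

```
import java.util.HashMap;
import java.util.Map;

public final class DecodeImportanceSketch {

    // Same idea as decodeFeatureImportances above: importances of processed features
    // are summed back onto the original field they were derived from.
    static Map<String, Double> decode(Map<String, String> processedToOriginal, Map<String, Double> importances) {
        if (processedToOriginal == null || processedToOriginal.isEmpty()) {
            return importances;
        }
        Map<String, Double> decoded = new HashMap<>();
        importances.forEach((feature, importance) ->
            decoded.merge(processedToOriginal.getOrDefault(feature, feature), importance, Double::sum));
        return decoded;
    }

    public static void main(String[] args) {
        // Hypothetical one-hot columns derived from a single "FlightDelayType" field.
        Map<String, String> reverseLookup = Map.of(
            "FlightDelayType_Weather", "FlightDelayType",
            "FlightDelayType_Carrier", "FlightDelayType");
        Map<String, Double> perColumn = Map.of(
            "FlightDelayType_Weather", 80.0,
            "FlightDelayType_Carrier", 34.1,
            "DistanceMiles", 13.7);
        // Prints {FlightDelayType=114.1, DistanceMiles=13.7} (key order may vary).
        System.out.println(decode(reverseLookup, perColumn));
    }
}
```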
@@ -16,9 +16,12 @@ import java.io.IOException;
 */
public class NullInferenceConfig implements InferenceConfig {

public static final NullInferenceConfig INSTANCE = new NullInferenceConfig();
private final boolean requestingFeatureImportance;

private NullInferenceConfig() { }

public NullInferenceConfig(boolean requestingFeatureImportance) {
this.requestingFeatureImportance = requestingFeatureImportance;
}

@Override
public boolean isTargetTypeSupported(TargetType targetType) {
@@ -37,6 +40,7 @@ public class NullInferenceConfig implements InferenceConfig {

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new UnsupportedOperationException("Unable to serialize NullInferenceConfig objects");
}

@Override
@@ -46,6 +50,11 @@ public class NullInferenceConfig implements InferenceConfig {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder;
throw new UnsupportedOperationException("Unable to write xcontent from NullInferenceConfig objects");
}

@Override
public boolean requestingImportance() {
return requestingFeatureImportance;
}
}

@@ -26,24 +26,27 @@ public class RegressionConfig implements InferenceConfig {
public static final ParseField NAME = new ParseField("regression");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
private static final String DEFAULT_RESULTS_FIELD = "predicted_value";

public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD);
public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null);

public static RegressionConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
}
return new RegressionConfig(resultsField);
return new RegressionConfig(resultsField, featureImportance);
}

private static final ConstructingObjectParser<RegressionConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0]));
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0], (Integer)args[1]));

static {
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
}

public static RegressionConfig fromXContent(XContentParser parser) {
@@ -51,19 +54,43 @@ public class RegressionConfig implements InferenceConfig {
}

private final String resultsField;
private final int numTopFeatureImportanceValues;

public RegressionConfig(String resultsField) {
this(resultsField, 0);
}

public RegressionConfig(String resultsField, Integer numTopFeatureImportanceValues) {
this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) {
throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() +
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues;
}

public RegressionConfig(StreamInput in) throws IOException {
this.resultsField = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.numTopFeatureImportanceValues = in.readVInt();
} else {
this.numTopFeatureImportanceValues = 0;
}
}

public int getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}

public String getResultsField() {
return resultsField;
}

@Override
public boolean requestingImportance() {
return numTopFeatureImportanceValues > 0;
}

@Override
public String getWriteableName() {
return NAME.getPreferredName();
@@ -72,6 +99,9 @@ public class RegressionConfig implements InferenceConfig {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(resultsField);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeVInt(numTopFeatureImportanceValues);
}
}

@Override
@@ -83,6 +113,9 @@ public class RegressionConfig implements InferenceConfig {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
if (numTopFeatureImportanceValues > 0) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.endObject();
return builder;
}
@@ -92,12 +125,13 @@ public class RegressionConfig implements InferenceConfig {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RegressionConfig that = (RegressionConfig)o;
return Objects.equals(this.resultsField, that.resultsField);
return Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
}

@Override
public int hashCode() {
return Objects.hash(resultsField);
return Objects.hash(resultsField, numTopFeatureImportanceValues);
}

@Override
@@ -107,7 +141,7 @@ public class RegressionConfig implements InferenceConfig {

@Override
public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION;
return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
}

}

@@ -0,0 +1,162 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;


/**
 * Ported from https://github.com/elastic/ml-cpp/blob/master/include/maths/CTreeShapFeatureImportance.h Path struct
 */
public class ShapPath {
private static final double DBL_EPSILON = Double.MIN_VALUE;

private final PathElement[] pathElements;
private final double[] scale;
private final int elementAndScaleOffset;

public ShapPath(ShapPath parentPath, int nextIndex) {
this.elementAndScaleOffset = parentPath.elementAndScaleOffset + nextIndex;
this.pathElements = parentPath.pathElements;
this.scale = parentPath.scale;
for (int i = 0; i < nextIndex; i++) {
pathElements[elementAndScaleOffset + i].featureIndex = parentPath.getElement(i).featureIndex;
pathElements[elementAndScaleOffset + i].fractionZeros = parentPath.getElement(i).fractionZeros;
pathElements[elementAndScaleOffset + i].fractionOnes = parentPath.getElement(i).fractionOnes;
scale[elementAndScaleOffset + i] = parentPath.getScale(i);
}
}

public ShapPath(PathElement[] elements, double[] scale) {
this.pathElements = elements;
this.scale = scale;
this.elementAndScaleOffset = 0;
}

// Update binomial coefficients to be able to compute Equation (2) from the paper. In particular,
// we have in the line path.scale[i + 1] += fractionOne * path.scale[i] * (i + 1.0) / (pathDepth +
// 1.0) that if we're on the "one" path, i.e. if the last feature selects this path if we include that
// feature in S (then fractionOne is 1), and we need to consider all the additional ways we now have of
// constructing each S of each given cardinality i + 1. Each of these come by adding the last feature
// to sets of size i and we **also** need to scale by the difference in binomial coefficients as both M
// increases by one and i increases by one. So we get additive term 1{last feature selects path if in S}
// * scale(i) * (i+1)! (M+1-(i+1)-1)!/(M+1)! / (i! (M-i-1)!/ M!), whence += scale(i) * (i+1) / (M+1).
public int extend(double fractionZero, double fractionOne, int featureIndex, int nextIndex) {
setValues(nextIndex, fractionOne, fractionZero, featureIndex);
setScale(nextIndex, nextIndex == 0 ? 1.0 : 0.0);
double stepDown = fractionOne / (double)(nextIndex + 1);
double stepUp = fractionZero / (double)(nextIndex + 1);
double countDown = nextIndex * stepDown;
double countUp = stepUp;
for (int i = (nextIndex - 1); i >= 0; --i, countDown -= stepDown, countUp += stepUp) {
setScale(i + 1, getScale(i + 1) + getScale(i) * countDown);
setScale(i, getScale(i) * countUp);
}
return nextIndex + 1;
}

public double sumUnwoundPath(int pathIndex, int nextIndex) {
double total = 0.0;
int pathDepth = nextIndex - 1;
double nextFractionOne = getScale(pathDepth);
double fractionOne = fractionOnes(pathIndex);
double fractionZero = fractionZeros(pathIndex);
if (fractionOne != 0) {
double pD = pathDepth + 1;
double stepUp = fractionZero / pD;
double stepDown = fractionOne / pD;
double countUp = stepUp;
double countDown = (pD - 1.0) * stepDown;
for (int i = pathDepth - 1; i >= 0; --i, countUp += stepUp, countDown -= stepDown) {
double tmp = nextFractionOne / countDown;
nextFractionOne = getScale(i) - tmp * countUp;
total += tmp;
}
} else {
double pD = pathDepth;

for(int i = 0; i < pathDepth; i++) {
total += getScale(i) / pD--;
}
total *= (pathDepth + 1) / (fractionZero + DBL_EPSILON);
}

return total;
}

public int unwind(int pathIndex, int nextIndex) {
int pathDepth = nextIndex - 1;
double nextFractionOne = getScale(pathDepth);
double fractionOne = fractionOnes(pathIndex);
double fractionZero = fractionZeros(pathIndex);

if (fractionOne != 0) {
double stepUp = fractionZero / (double)(pathDepth + 1);
double stepDown = fractionOne / (double)nextIndex;
double countUp = 0.0;
double countDown = nextIndex * stepDown;
for (int i = pathDepth; i >= 0; --i, countUp += stepUp, countDown -= stepDown) {
double tmp = nextFractionOne / countDown;
nextFractionOne = getScale(i) - tmp * countUp;
setScale(i, tmp);
}
} else {
double stepDown = (fractionZero + DBL_EPSILON) / (double)(pathDepth + 1);
double countDown = pathDepth * stepDown;
for (int i = 0; i <= pathDepth; ++i, countDown -= stepDown) {
setScale(i, getScale(i) / countDown);
}
}
for (int i = pathIndex; i < pathDepth; ++i) {
PathElement element = getElement(i + 1);
setValues(i, element.fractionOnes, element.fractionZeros, element.featureIndex);
}
return nextIndex - 1;
}

private void setValues(int index, double fractionOnes, double fractionZeros, int featureIndex) {
pathElements[index + elementAndScaleOffset].fractionOnes = fractionOnes;
pathElements[index + elementAndScaleOffset].fractionZeros = fractionZeros;
pathElements[index + elementAndScaleOffset].featureIndex = featureIndex;
}

private double getScale(int offset) {
return scale[offset + elementAndScaleOffset];
}

private void setScale(int offset, double value) {
scale[offset + elementAndScaleOffset] = value;
}

public double fractionOnes(int pathIndex) {
return pathElements[pathIndex + elementAndScaleOffset].fractionOnes;
}

public double fractionZeros(int pathIndex) {
return pathElements[pathIndex + elementAndScaleOffset].fractionZeros;
}

public int findFeatureIndex(int splitFeature, int nextIndex) {
for (int i = elementAndScaleOffset; i < elementAndScaleOffset + nextIndex; i++) {
if (pathElements[i].featureIndex == splitFeature) {
return i - elementAndScaleOffset;
}
}
return -1;
}

public int featureIndex(int pathIndex) {
return pathElements[pathIndex + elementAndScaleOffset].featureIndex;
}

private PathElement getElement(int offset) {
return pathElements[offset + elementAndScaleOffset];
}

public static final class PathElement {
private double fractionOnes = 1.0;
private double fractionZeros = 1.0;
private int featureIndex = -1;
}
}

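For readers following the `extend`/`unwind` bookkeeping above: the scale values track the binomial weights of the standard Shapley value from the linked paper (arXiv:1802.03888). Restated here for reference only (this formula is not code from the change):

```
\phi_i \;=\; \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,\bigl(M - |S| - 1\bigr)!}{M!}\;\Bigl[f_x\bigl(S \cup \{i\}\bigr) - f_x(S)\Bigr]
```

where `F` is the set of `M` features used by the tree and `f_x(S)` is the expected tree output when only the features in `S` are fixed to the document's values. The `scale(i) * (i + 1) / (M + 1)` term in the comment is the ratio of consecutive such weights when both the subset size and `M` grow by one.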
@@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
@@ -17,12 +18,16 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
/**
 * Infer against the provided fields
 *
 * NOTE: Must be thread safe
 *
 * @param fields The fields and their values to infer against
 * @param config The configuration options for inference
 * @param featureDecoderMap A map for decoding feature value names to their originating feature.
 *                          Necessary for feature influence.
 * @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0).
 *         For regression this is continuous.
 */
InferenceResults infer(Map<String, Object> fields, InferenceConfig config);
InferenceResults infer(Map<String, Object> fields, InferenceConfig config, @Nullable Map<String, String> featureDecoderMap);

/**
 * @return {@link TargetType} for the model.
@@ -42,4 +47,19 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
 * @return The estimated number of operations required at inference time
 */
long estimatedNumOperations();

/**
 * @return Does the model support feature importance
 */
boolean supportsFeatureImportance();

/**
 * Calculates the importance of each feature reference by the model for the passed in field values
 *
 * NOTE: Must be thread safe
 * @param fields The fields inferring against
 * @param featureDecoder A Map translating processed feature names to their original feature names
 * @return A {@code Map<String, Double>} mapping each featureName to its importance
 */
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
}

@@ -37,6 +37,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -133,18 +134,25 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
}

@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
List<Double> inferenceResults = this.models.stream().map(model -> {
InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE);
assert results instanceof SingleValueInferenceResults;
return ((SingleValueInferenceResults)results).value();
}).collect(Collectors.toList());
List<Double> inferenceResults = new ArrayList<>(this.models.size());
List<Map<String, Double>> featureInfluence = new ArrayList<>();
NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
this.models.forEach(model -> {
InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
assert result instanceof SingleValueInferenceResults;
SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result;
inferenceResults.add(inferenceResult.value());
if (config.requestingImportance()) {
featureInfluence.add(inferenceResult.getFeatureImportance());
}
});
List<Double> processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed, config);
return buildResults(processed, featureInfluence, config, featureDecoderMap);
}

@Override
@@ -152,14 +160,20 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return targetType;
}

private InferenceResults buildResults(List<Double> processedInferences, InferenceConfig config) {
private InferenceResults buildResults(List<Double> processedInferences,
List<Map<String, Double>> featureInfluence,
InferenceConfig config,
Map<String, String> featureDecoderMap) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(outputAggregator.aggregate(processedInferences));
return new RawInferenceResults(outputAggregator.aggregate(processedInferences),
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
}
switch(targetType) {
case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), config);
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
config,
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
@@ -172,6 +186,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return new ClassificationInferenceResults((double)topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
config);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
@@ -293,6 +308,27 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1);
}

@Override
public boolean supportsFeatureImportance() {
return models.stream().allMatch(TrainedModel::supportsFeatureImportance);
}

Map<String, Double> featureImportance(Map<String, Object> fields) {
return featureImportance(fields, Collections.emptyMap());
}

@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
Map<String, Double> collapsed = mergeFeatureImportances(models.stream()
.map(trainedModel -> trainedModel.featureImportance(fields, Collections.emptyMap()))
.collect(Collectors.toList()));
return InferenceHelpers.decodeFeatureImportances(featureDecoder, collapsed);
}

private static Map<String, Double> mergeFeatureImportances(List<Map<String, Double>> featureImportances) {
return featureImportances.stream().collect(HashMap::new, (a, b) -> b.forEach((k, v) -> a.merge(k, v, Double::sum)), Map::putAll);
}

public static Builder builder() {
return new Builder();
}

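A quick illustration of the `mergeFeatureImportances` step above: per-model importance maps are summed key by key before being decoded back to original field names. The numbers below are made up for the example; the helper simply mirrors the private method in the diff.

```
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public final class MergeImportanceSketch {

    // Same shape as Ensemble#mergeFeatureImportances above: sum per-model maps key by key.
    static Map<String, Double> merge(List<Map<String, Double>> perModel) {
        return perModel.stream()
            .collect(HashMap::new, (acc, m) -> m.forEach((k, v) -> acc.merge(k, v, Double::sum)), Map::putAll);
    }

    public static void main(String[] args) {
        // Two hypothetical trees contributing to the same ensemble prediction.
        Map<String, Double> tree1 = Map.of("FlightTimeMin", -50.0, "DistanceMiles", 10.0);
        Map<String, Double> tree2 = Map.of("FlightTimeMin", -26.9, "FlightDelayType", 114.1);
        // FlightTimeMin sums to -76.9; the other keys pass through unchanged.
        System.out.println(merge(List.of(tree1, tree2)));
    }
}
```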
@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -104,7 +105,11 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
}

@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.requestingImportance()) {
throw ExceptionsHelper.badRequestException("[{}] model does not supports feature importance",
NAME.getPreferredName());
}
if (config instanceof ClassificationConfig == false) {
throw ExceptionsHelper.badRequestException("[{}] model only supports classification",
NAME.getPreferredName());
@@ -138,6 +143,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
return new ClassificationInferenceResults(topClasses.v1(),
LANGUAGE_NAMES.get(topClasses.v1()),
topClasses.v2(),
Collections.emptyMap(),
classificationConfig);
}

@@ -159,6 +165,16 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
return numOps;
}

@Override
public boolean supportsFeatureImportance() {
return false;
}

@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
throw new UnsupportedOperationException("[lang_ident] does not support feature importance");
}

@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;

@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -44,6 +45,7 @@ import java.util.Objects;
import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;

@@ -86,6 +88,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
private final TargetType targetType;
private final List<String> classificationLabels;
private final CachedSupplier<Double> highestOrderCategory;
// populated lazily when feature importance is calculated
private double[] nodeEstimates;
private Integer maxDepth;

Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
@@ -120,7 +125,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}

@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
@@ -129,21 +134,23 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
List<Double> features = featureNames.stream()
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
.collect(Collectors.toList());
return infer(features, config);
}

private InferenceResults infer(List<Double> features, InferenceConfig config) {
Map<String, Double> featureImportance = config.requestingImportance() ?
featureImportance(features, featureDecoderMap) :
Collections.emptyMap();

TreeNode node = nodes.get(0);
while(node.isLeaf() == false) {
node = nodes.get(node.compare(features));
}
return buildResult(node.getLeafValue(), config);

return buildResult(node.getLeafValue(), featureImportance, config);
}

private InferenceResults buildResult(Double value, InferenceConfig config) {
private InferenceResults buildResult(Double value, Map<String, Double> featureImportance, InferenceConfig config) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(value);
return new RawInferenceResults(value, featureImportance);
}
switch (targetType) {
case CLASSIFICATION:
@@ -156,9 +163,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return new ClassificationInferenceResults(value,
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
featureImportance,
config);
case REGRESSION:
return new RegressionInferenceResults(value, config);
return new RegressionInferenceResults(value, config, featureImportance);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}
@@ -192,7 +200,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
// If we are classification, we should assume that the largest leaf value is whole.
assert maxCategory == Math.rint(maxCategory);
List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
// TODO, eventually have TreeNodes contain confidence levels
list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
return list;
}
@@ -263,12 +270,138 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
detectCycle();
}

@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) {
throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance");
}
List<Double> features = featureNames.stream()
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
.collect(Collectors.toList());
return featureImportance(features, featureDecoder);
}

private Map<String, Double> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
calculateNodeEstimatesIfNeeded();
double[] featureImportance = new double[fieldValues.size()];
int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
for (int i = 0; i < arrSize; i++) {
elements[i] = new ShapPath.PathElement();
}
double[] scale = new double[arrSize];
ShapPath initialPath = new ShapPath(elements, scale);
shapRecursive(fieldValues, this.nodeEstimates, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
return InferenceHelpers.decodeFeatureImportances(featureDecoder,
IntStream.range(0, featureImportance.length)
.boxed()
.collect(Collectors.toMap(featureNames::get, i -> featureImportance[i])));
}

private void calculateNodeEstimatesIfNeeded() {
if (this.nodeEstimates != null && this.maxDepth != null) {
return;
}
synchronized (this) {
if (this.nodeEstimates != null && this.maxDepth != null) {
return;
}
double[] estimates = new double[nodes.size()];
this.maxDepth = fillNodeEstimates(estimates, 0, 0);
this.nodeEstimates = estimates;
}
}

/**
 * Note, this is a port from https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc
 *
 * If improvements in performance or accuracy have been found, it is probably best that the changes are implemented on the native
 * side first and then ported to the Java side.
 */
private void shapRecursive(List<Double> processedFeatures,
double[] nodeValues,
ShapPath parentSplitPath,
int nodeIndex,
double parentFractionZero,
double parentFractionOne,
int parentFeatureIndex,
double[] featureImportance,
int nextIndex) {
ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
TreeNode currNode = nodes.get(nodeIndex);
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
if (currNode.isLeaf()) {
// TODO multi-value????
double leafValue = nodeValues[nodeIndex];
for (int i = 1; i < nextIndex; ++i) {
double scale = splitPath.sumUnwoundPath(i, nextIndex);
int inputColumnIndex = splitPath.featureIndex(i);
featureImportance[inputColumnIndex] += scale * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i)) * leafValue;
}
} else {
int hotIndex = currNode.compare(processedFeatures);
int coldIndex = hotIndex == currNode.getLeftChild() ? currNode.getRightChild() : currNode.getLeftChild();

double incomingFractionZero = 1.0;
double incomingFractionOne = 1.0;
int splitFeature = currNode.getSplitFeature();
int pathIndex = splitPath.findFeatureIndex(splitFeature, nextIndex);
if (pathIndex > -1) {
incomingFractionZero = splitPath.fractionZeros(pathIndex);
incomingFractionOne = splitPath.fractionOnes(pathIndex);
nextIndex = splitPath.unwind(pathIndex, nextIndex);
}

double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples();
double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples();
shapRecursive(processedFeatures, nodeValues, splitPath,
hotIndex, incomingFractionZero * hotFractionZero,
incomingFractionOne, splitFeature, featureImportance, nextIndex);
shapRecursive(processedFeatures, nodeValues, splitPath,
coldIndex, incomingFractionZero * coldFractionZero,
0.0, splitFeature, featureImportance, nextIndex);
}
}

/**
 * This recursively populates the provided {@code double[]} with the node estimated values
 *
 * Used when calculating feature importance.
 * @param nodeEstimates Array to update in place with the node estimated values
 * @param nodeIndex Current node index
 * @param depth Current depth
 * @return The current max depth
 */
private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
TreeNode node = nodes.get(nodeIndex);
if (node.isLeaf()) {
nodeEstimates[nodeIndex] = node.getLeafValue();
return 0;
}

int depthLeft = fillNodeEstimates(nodeEstimates, node.getLeftChild(), depth + 1);
int depthRight = fillNodeEstimates(nodeEstimates, node.getRightChild(), depth + 1);
long leftWeight = nodes.get(node.getLeftChild()).getNumberSamples();
long rightWeight = nodes.get(node.getRightChild()).getNumberSamples();
long divisor = leftWeight + rightWeight;
double averageValue = divisor == 0 ?
0.0 :
(leftWeight * nodeEstimates[node.getLeftChild()] + rightWeight * nodeEstimates[node.getRightChild()]) / divisor;
|
||||
nodeEstimates[nodeIndex] = averageValue;
|
||||
return Math.max(depthLeft, depthRight) + 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long estimatedNumOperations() {
|
||||
// Grabbing the features from the doc + the depth of the tree
|
||||
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supportsFeatureImportance() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* The highest index of a feature used any of the nodes.
|
||||
* If no nodes use a feature return -1. This can only happen
|
||||
|
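The node estimate that `fillNodeEstimates` stores for an internal node is simply the `number_samples`-weighted average of its children's estimates, computed bottom-up; the root estimate is the tree's expected value and is what the SHAP recursion uses as its baseline. A minimal standalone sketch of that bottom-up fill (illustrative names only, not code from this commit):

```
final class NodeEstimateSketch {
    static final class Node {
        final Integer left;          // child indices, null for a leaf
        final Integer right;
        final double leafValue;      // only meaningful for leaves
        final long numberSamples;    // training samples that reached this node
        Node(Integer left, Integer right, double leafValue, long numberSamples) {
            this.left = left;
            this.right = right;
            this.leafValue = leafValue;
            this.numberSamples = numberSamples;
        }
        boolean isLeaf() {
            return left == null;
        }
    }

    // Mirrors the idea of Tree#fillNodeEstimates above: leaves keep their leaf value, internal
    // nodes get the number_samples-weighted average of their children; returns the subtree depth.
    static int fillEstimates(Node[] nodes, double[] estimates, int index) {
        Node node = nodes[index];
        if (node.isLeaf()) {
            estimates[index] = node.leafValue;
            return 0;
        }
        int depthLeft = fillEstimates(nodes, estimates, node.left);
        int depthRight = fillEstimates(nodes, estimates, node.right);
        long leftWeight = nodes[node.left].numberSamples;
        long rightWeight = nodes[node.right].numberSamples;
        long divisor = leftWeight + rightWeight;
        estimates[index] = divisor == 0
            ? 0.0
            : (leftWeight * estimates[node.left] + rightWeight * estimates[node.right]) / divisor;
        return Math.max(depthLeft, depthRight) + 1;
    }

    public static void main(String[] args) {
        // A stump whose 10 training samples split 6/4 between leaves valued 1.0 and 2.0.
        Node[] nodes = {
            new Node(1, 2, 0.0, 10L),
            new Node(null, null, 1.0, 6L),
            new Node(null, null, 2.0, 4L)
        };
        double[] estimates = new double[nodes.length];
        int maxDepth = fillEstimates(nodes, estimates, 0);
        System.out.println(estimates[0] + ", depth " + maxDepth); // 1.4, depth 1
    }
}
```

For the toy stump the root estimate is (6 * 1.0 + 4 * 2.0) / 10 = 1.4, and the returned depth is what sizes the `ShapPath` scratch arrays.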
@ -342,8 +342,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
}};

assertThat(
((ClassificationInferenceResults)definition.getTrainedModel()
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-setosa"));

@ -354,8 +353,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
put("petal_width", 1.4);
}};
assertThat(
((ClassificationInferenceResults)definition.getTrainedModel()
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-versicolor"));

@ -366,10 +364,8 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
put("petal_width", 2.0);
}};
assertThat(
((ClassificationInferenceResults)definition.getTrainedModel()
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-virginica"));
}

}
@ -8,10 +8,12 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.util.Collections;

public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {

public static RawInferenceResults createRandomResults() {
return new RawInferenceResults(randomDouble());
return new RawInferenceResults(randomDouble(), randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
}

@Override
@ -30,11 +30,12 @@ public class ClassificationConfigTests extends AbstractSerializingTestCase<Class
ClassificationConfig expected = ClassificationConfig.EMPTY_PARAMS;
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));

expected = new ClassificationConfig(3, "foo", "bar");
expected = new ClassificationConfig(3, "foo", "bar", 2);
Map<String, Object> configMap = new HashMap<>();
configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo");
configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar");
configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2);
assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected));
}

@ -24,9 +24,10 @@ public class RegressionConfigTests extends AbstractSerializingTestCase<Regressio
}

public void testFromMap() {
RegressionConfig expected = new RegressionConfig("foo");
RegressionConfig expected = new RegressionConfig("foo", 3);
Map<String, Object> config = new HashMap<String, Object>(){{
put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo");
put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3);
}};
assertThat(RegressionConfig.fromMap(config), equalTo(expected));
}
@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
@ -39,6 +40,7 @@ import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;

public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
private final double eps = 1.0E-8;

private boolean lenient;

@ -267,7 +269,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> scores = Arrays.asList(0.230557435, 0.162032651);
double eps = 0.000001;
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -278,7 +281,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
expected = Arrays.asList(0.310025518, 0.6899744811);
scores = Arrays.asList(0.217017863, 0.2069923443);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -289,7 +293,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
expected = Arrays.asList(0.768524783, 0.231475216);
scores = Arrays.asList(0.230557435, 0.162032651);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -303,7 +308,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
expected = Arrays.asList(0.6899744811, 0.3100255188);
scores = Arrays.asList(0.482982136, 0.0930076556);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@ -361,24 +367,28 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));

featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));

featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));

featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(0.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
}

public void testMultiClassClassificationInference() {
@ -432,24 +442,28 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(2.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));

featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));

featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));

featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.6);
put("bar", null);
}};
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
}

public void testRegressionInference() {
@ -489,12 +503,16 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));

featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));

// Test with NO aggregator supplied, verifies default behavior of non-weighted sum
ensemble = Ensemble.builder()
@ -506,19 +524,25 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
featureVector = Arrays.asList(0.4, 0.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));

featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));

featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
}

public void testInferNestedFields() {
@ -564,7 +588,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
}});
}};
assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));

featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
@ -575,7 +601,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
}});
}};
assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
}

public void testOperationsEstimations() {
@ -590,6 +618,114 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
assertThat(ensemble.estimatedNumOperations(), equalTo(9L));
}

public void testFeatureImportance() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.55)
.setNumberSamples(10L),
TreeNode.builder(1)
.setSplitFeature(0)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.41)
.setNumberSamples(6L),
TreeNode.builder(2)
.setSplitFeature(1)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.25)
.setNumberSamples(4L),
TreeNode.builder(3).setLeafValue(1.18230136).setNumberSamples(5L),
TreeNode.builder(4).setLeafValue(1.98006658).setNumberSamples(1L),
TreeNode.builder(5).setLeafValue(3.25350885).setNumberSamples(3L),
TreeNode.builder(6).setLeafValue(2.42384369).setNumberSamples(1L)).build();

Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.45)
.setNumberSamples(10L),
TreeNode.builder(1)
.setSplitFeature(0)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.25)
.setNumberSamples(5L),
TreeNode.builder(2)
.setSplitFeature(0)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.59)
.setNumberSamples(5L),
TreeNode.builder(3).setLeafValue(1.04476388).setNumberSamples(3L),
TreeNode.builder(4).setLeafValue(1.52799228).setNumberSamples(2L),
TreeNode.builder(5).setLeafValue(1.98006658).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(2.950216).setNumberSamples(4L)).build();

Ensemble ensemble = Ensemble.builder().setOutputAggregator(new WeightedSum())
.setTrainedModels(Arrays.asList(tree1, tree2))
.setFeatureNames(featureNames)
.build();


Map<String, Double> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));

featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.1, 0.8)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));

featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.2, 0.7)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));

featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.3, 0.6)));
assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));

featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.4, 0.5)));
assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));

featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.5, 0.4)));
assertThat(featureImportance.get("foo"), closeTo(0.0798679, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));

featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.6, 0.3)));
assertThat(featureImportance.get("foo"), closeTo(1.80491886, eps));
assertThat(featureImportance.get("bar"), closeTo(-0.4355742, eps));

featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.7, 0.2)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));

featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.8, 0.1)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));

featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.9, 0.0)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
}


private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}
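Because Shapley values are linear in the model, the expectations above for a `weighted_sum` ensemble are consistent with summing each tree's importances, scaled by its weight (here the default of 1.0). A rough standalone sketch of that combination step, assuming unit weights as in the test (names are illustrative, not the actual Ensemble implementation):

```
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class EnsembleImportanceSketch {
    // Sums per-tree importance maps, scaling each tree by its aggregator weight.
    static Map<String, Double> combine(List<Map<String, Double>> perTree, double[] weights) {
        Map<String, Double> total = new HashMap<>();
        for (int i = 0; i < perTree.size(); i++) {
            double w = weights == null ? 1.0 : weights[i]; // no weights supplied -> unit weights
            for (Map.Entry<String, Double> e : perTree.get(i).entrySet()) {
                total.merge(e.getKey(), w * e.getValue(), Double::sum);
            }
        }
        return total;
    }

    public static void main(String[] args) {
        // Two hypothetical per-tree importance maps for features "foo" and "bar".
        Map<String, Double> tree1 = Map.of("foo", -1.0, "bar", 0.25);
        Map<String, Double> tree2 = Map.of("foo", -0.5, "bar", -0.35);
        // With unit weights the ensemble importance is the element-wise sum: foo=-1.5, bar≈-0.1.
        System.out.println(combine(List.of(tree1, tree2), null));
    }
}
```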
@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceRes
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.junit.Before;

import java.io.IOException;
@ -35,6 +36,7 @@ import static org.hamcrest.Matchers.equalTo;

public class TreeTests extends AbstractSerializingTestCase<Tree> {

private final double eps = 1.0E-8;
private boolean lenient;

@Before
@ -118,7 +120,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump
assertThat(42.0,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}

public void testInfer() {
@ -138,27 +141,31 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));

// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));

// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
featureVector = Arrays.asList(0.3, 0.9);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));

// This should still work if the internal values are strings
List<String> featureVectorStrings = Arrays.asList("0.3", "0.9");
featureMap = zipObjMap(featureNames, featureVectorStrings);
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));

// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{
@ -166,7 +173,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
put("bar", null);
}};
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}

public void testInferNestedFields() {
@ -192,7 +200,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}});
}};
assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));

// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
@ -205,7 +214,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}});
}};
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));

// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
@ -218,7 +228,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}});
}};
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}

public void testTreeClassificationProbability() {
@ -241,7 +252,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
List<String> expectedFields = Arrays.asList("dog", "cat");
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@ -252,7 +264,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@ -264,7 +277,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
put("bar", null);
}};
probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@ -345,6 +359,55 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
assertThat(tree.estimatedNumOperations(), equalTo(7L));
}

public void testFeatureImportance() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.5)
.setNumberSamples(4L),
TreeNode.builder(1)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.5)
.setNumberSamples(2L),
TreeNode.builder(2)
.setSplitFeature(1)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.5)
.setNumberSamples(2L),
TreeNode.builder(3).setLeafValue(3.0).setNumberSamples(1L),
TreeNode.builder(4).setLeafValue(8.0).setNumberSamples(1L),
TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();

Map<String, Double> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));

featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.75)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(2.5, eps));

featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.25)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));

featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.75)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
}

public void testMaxFeatureIndex() {

int numFeatures = randomIntBetween(1, 15);
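The expected numbers in `testFeatureImportance` above can be reproduced by brute-force Shapley values on that two-feature tree, since every split is 50/50 by `number_samples`. A small standalone check, illustrative only and not part of the commit:

```
final class ShapSanityCheck {

    // The TreeTests tree: split on foo at 0.5, then bar at 0.5; leaves 3, 8, 13, 18,
    // one training sample each, so every split is weighted 50/50 by number_samples.
    static double predict(double foo, double bar) {
        if (foo < 0.5) {
            return bar < 0.5 ? 3.0 : 8.0;
        }
        return bar < 0.5 ? 13.0 : 18.0;
    }

    // Conditional expectation with a subset of features fixed; null means the feature is
    // averaged over the training distribution (one representative point per half of its split).
    static double value(Double foo, Double bar) {
        double[] grid = {0.25, 0.75};
        double sum = 0.0;
        int count = 0;
        for (double f : (foo == null ? grid : new double[] {foo})) {
            for (double b : (bar == null ? grid : new double[] {bar})) {
                sum += predict(f, b);
                count++;
            }
        }
        return sum / count;
    }

    public static void main(String[] args) {
        double foo = 0.25;
        double bar = 0.25;
        // Exact two-feature Shapley values.
        double phiFoo = 0.5 * (value(foo, bar) - value(null, bar))
            + 0.5 * (value(foo, null) - value(null, null));
        double phiBar = 0.5 * (value(foo, bar) - value(foo, null))
            + 0.5 * (value(null, bar) - value(null, null));
        System.out.println(phiFoo + " " + phiBar); // -5.0 -2.5, matching the assertions above
    }
}
```

Local accuracy also holds here: -5.0 + -2.5 equals 3.0 - 10.5, the prediction at (0.25, 0.25) minus the tree's expected value.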
@ -115,7 +115,10 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"inference\": {\n" +
" \"target_field\": \"ml.classification\",\n" +
" \"inference_config\": {\"classification\": " +
" {\"num_top_classes\":2, \"top_classes_results_field\": \"result_class_prob\"}},\n" +
" {\"num_top_classes\":2, " +
" \"top_classes_results_field\": \"result_class_prob\"," +
" \"num_top_feature_importance_values\": 2" +
" }},\n" +
" \"model_id\": \"test_classification\",\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
@ -153,6 +156,8 @@ public class InferenceIngestIT extends ESRestTestCase {
String responseString = EntityUtils.toString(response.getEntity());
assertThat(responseString, containsString("\"predicted_value\":\"second\""));
assertThat(responseString, containsString("\"predicted_value\":1.0"));
assertThat(responseString, containsString("\"col2\":0.944"));
assertThat(responseString, containsString("\"col1\":0.19999"));

String sourceWithMissingModel = "{\n" +
" \"pipeline\": {\n" +
@ -321,16 +326,19 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"number_samples\": 300,\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 100,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 200,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
@ -352,15 +360,18 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"number_samples\": 150,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 50,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 100,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
@ -445,6 +456,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"number_samples\": 100,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
@ -454,10 +466,12 @@ public class InferenceIngestIT extends ESRestTestCase {
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 80,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 20,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
@ -476,6 +490,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"number_samples\": 180,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
@ -484,10 +499,12 @@ public class InferenceIngestIT extends ESRestTestCase {
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 10,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 170,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
@ -102,6 +102,43 @@ public class InferenceProcessorTests extends ESTestCase {
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
}

public void testMutateDocumentClassificationFeatureInfluence() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"classification_model",
classificationConfig,
Collections.emptyMap());

Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);

List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));

Map<String, Double> featureInfluence = new HashMap<>();
featureInfluence.put("feature_1", 1.13);
featureInfluence.put("feature_2", -42.0);

InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0,
"foo",
classes,
featureInfluence,
classificationConfig)),
true);
inferenceProcessor.mutateDocument(response, document);

assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
}

@SuppressWarnings("unchecked")
public void testMutateDocumentClassificationTopNClassesWithSpecificField() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops");
@ -154,6 +191,34 @@ public class InferenceProcessorTests extends ESTestCase {
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
}

public void testMutateDocumentRegressionWithTopFetures() {
RegressionConfig regressionConfig = new RegressionConfig("foo", 2);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"regression_model",
regressionConfig,
Collections.emptyMap());

Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);

Map<String, Double> featureInfluence = new HashMap<>();
featureInfluence.put("feature_1", 1.13);
featureInfluence.put("feature_2", -42.0);

InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);
inferenceProcessor.mutateDocument(response, document);

assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
}

public void testGenerateRequestWithEmptyMapping() {
String modelId = "model";
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);