[ML] inference performance optimizations and refactor (#57674) (#57753)

This is a major refactor of the underlying inference logic.

The main change is that the model configuration is now separated from
the inference interfaces.

This has the following benefits:
 - we can store extra things with the model that are not
   necessary for inference (e.g. tree node split information gain)
 - we can optimize inference separately from model serialization and storage.
 - the user is oblivious to the optimizations (other than seeing the benefits).

A major part of this commit is removing all inference-related methods from the
trained model configurations (ensemble, tree, etc.) and moving them to a new class.

This new class satisfies a new interface that is ONLY for inference.
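
Condensed, the split looks roughly like the two interfaces below (signatures abridged from the diff in this commit; the referenced types are the existing ML classes and the comments are editorial):

import java.util.Map;

// Storage/serialization side: keeps everything needed to persist the model,
// including details inference never touches (e.g. tree node split gain).
interface TrainedModel /* extends NamedXContentObject, NamedWriteable, Accountable */ {
    TargetType targetType();
    long estimatedNumOperations();
}

// Inference-only side (new): parsed once from the stored definition and
// optimized purely for evaluating documents and computing feature importance.
interface InferenceModel /* extends Accountable */ {
    String[] getFeatureNames();
    TargetType targetType();
    InferenceResults infer(Map<String, Object> fields, InferenceConfig config,
                           Map<String, String> featureDecoderMap);
    InferenceResults infer(double[] features, InferenceConfig config);
    boolean supportsFeatureImportance();
    void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping);
}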

The optimizations applied currently are:
- feature maps are flattened once
- feature extraction only happens once at the highest level
  (improves inference + feature importance throughput; see the sketch after this list)
- only what is needed for inference + feature importance is kept on heap
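
To illustrate the first two optimizations, here is a minimal, self-contained sketch (hypothetical names, not the actual Elasticsearch classes): the field map is resolved to a double[] once at the top level, and every sub-model works only against that array.

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class FlattenedInferenceSketch {

    // Resolve the field map to a dense feature vector exactly once, at the top level.
    static double[] extractFeatures(String[] featureNames, Map<String, Object> fields) {
        double[] features = new double[featureNames.length];
        for (int i = 0; i < featureNames.length; i++) {
            Object value = fields.get(featureNames[i]);
            features[i] = value instanceof Number ? ((Number) value).doubleValue() : Double.NaN;
        }
        return features;
    }

    // Sub-models only ever see the already-extracted array, so an ensemble of N
    // members does one map-lookup pass instead of N.
    interface SubModel {
        double infer(double[] features);
    }

    public static void main(String[] args) {
        String[] featureNames = {"f1", "f2"};
        Map<String, Object> doc = new LinkedHashMap<>();
        doc.put("f1", 0.5);
        doc.put("f2", 2.0);

        double[] features = extractFeatures(featureNames, doc);

        List<SubModel> models = List.of(f -> f[0] + f[1], f -> f[0] * f[1]);
        double sum = 0;
        for (SubModel m : models) {
            sum += m.infer(features);
        }
        System.out.println("aggregated prediction: " + sum / models.size());
    }
}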
Benjamin Trent 2020-06-05 14:20:58 -04:00 committed by GitHub
parent f170b52e64
commit 9666a895f7
39 changed files with 2356 additions and 1608 deletions

View File

@ -10,6 +10,8 @@ import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContentObject;
@ -37,7 +39,7 @@ public final class InferenceToXContentCompressor {
// Either 10% of the configured JVM heap, or 1 GB, whichever is smaller
private static final long MAX_INFLATED_BYTES = Math.min(
(long)((0.10) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()),
1_000_000_000); // 1 gb maximum
new ByteSizeValue(1, ByteSizeUnit.GB).getBytes());
private InferenceToXContentCompressor() {}
@ -46,9 +48,9 @@ public final class InferenceToXContentCompressor {
return deflate(reference);
}
static <T> T inflate(String compressedString,
CheckedFunction<XContentParser, T, IOException> parserFunction,
NamedXContentRegistry xContentRegistry) throws IOException {
public static <T> T inflate(String compressedString,
CheckedFunction<XContentParser, T, IOException> parserFunction,
NamedXContentRegistry xContentRegistry) throws IOException {
try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
LoggingDeprecationHandler.INSTANCE,
inflate(compressedString, MAX_INFLATED_BYTES))) {

View File

@ -30,6 +30,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAgg
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
@ -119,6 +122,13 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
RegressionConfigUpdate::fromXContentStrict));
// Inference models
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Tree.NAME, TreeInferenceModel::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class,
LangIdentNeuralNetwork.NAME,
LangIdentNeuralNetwork::fromXContentLenient));
return namedXContent;
}

View File

@ -33,6 +33,7 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@ -607,6 +608,14 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
if (input != null && input.getFieldNames().isEmpty()) {
validationException = addValidationError("[input.field_names] must not be empty", validationException);
}
if (input != null && input.getFieldNames()
.stream()
.filter(s -> s.contains("."))
.flatMap(s -> Arrays.stream(Strings.delimitedListToStringArray(s, ".")))
.anyMatch(String::isEmpty)) {
validationException = addValidationError("[input.field_names] must only contain valid dot delimited field names",
validationException);
}
if (forCreation) {
validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);

View File

@ -20,8 +20,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
@ -32,9 +30,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class TrainedModelDefinition implements ToXContentObject, Writeable, Accountable {
@ -73,7 +69,6 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private Map<String, String> decoderMap;
private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
@ -116,37 +111,6 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
return preProcessors;
}
void preProcess(Map<String, Object> fields) {
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
}
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields);
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
throw ExceptionsHelper.badRequestException(
"Feature importance is not supported for the configured model of type [{}]",
trainedModel.getName());
}
return trainedModel.infer(fields,
config,
config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
}
private Map<String, String> getDecoderMap() {
if (decoderMap != null) {
return decoderMap;
}
synchronized (this) {
if (decoderMap != null) {
return decoderMap;
}
this.decoderMap = preProcessors.stream()
.map(PreProcessor::reverseLookup)
.collect(HashMap::new, Map::putAll, Map::putAll);
return decoderMap;
}
}
@Override
public String toString() {
return Strings.toString(this);
@ -218,14 +182,6 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
return this;
}
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
if (trainedModel.size() != 1) {
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
TRAINED_MODEL.getPreferredName());
}
return setTrainedModel(trainedModel.get(0));
}
private void setProcessorsInOrder(boolean value) {
this.processorsInOrder = value;
}

View File

@ -14,7 +14,6 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.io.IOException;
import java.util.Collections;
@ -109,7 +108,7 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
@Override
public void process(Map<String, Object> fields) {
Object value = MapHelper.dig(field, fields);
Object value = fields.get(field);
if (value == null) {
return;
}

View File

@ -14,7 +14,6 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.io.IOException;
import java.util.Collections;
@ -94,7 +93,7 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
@Override
public void process(Map<String, Object> fields) {
Object value = MapHelper.dig(field, fields);
Object value = fields.get(field);
if (value == null) {
return;
}

View File

@ -14,7 +14,6 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.io.IOException;
import java.util.Collections;
@ -120,7 +119,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
@Override
public void process(Map<String, Object> fields) {
Object value = MapHelper.dig(field, fields);
Object value = fields.get(field);
if (value == null) {
return;
}

View File

@ -10,7 +10,6 @@ import org.elasticsearch.ingest.IngestDocument;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
public class RawInferenceResults implements InferenceResults {
@ -18,9 +17,9 @@ public class RawInferenceResults implements InferenceResults {
public static final String NAME = "raw";
private final double[] value;
private final Map<String, double[]> featureImportance;
private final double[][] featureImportance;
public RawInferenceResults(double[] value, Map<String, double[]> featureImportance) {
public RawInferenceResults(double[] value, double[][] featureImportance) {
this.value = value;
this.featureImportance = featureImportance;
}
@ -29,7 +28,7 @@ public class RawInferenceResults implements InferenceResults {
return value;
}
public Map<String, double[]> getFeatureImportance() {
public double[][] getFeatureImportance() {
return featureImportance;
}
@ -44,7 +43,7 @@ public class RawInferenceResults implements InferenceResults {
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Arrays.equals(value, that.value)
&& Objects.equals(featureImportance, that.featureImportance);
&& Arrays.deepEquals(featureImportance, that.featureImportance);
}
@Override

View File

@ -7,29 +7,12 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import java.util.Map;
public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {
/**
* Infer against the provided fields
*
* NOTE: Must be thread safe
*
* @param fields The fields and their values to infer against
* @param config The configuration options for inference
* @param featureDecoderMap A map for decoding feature value names to their originating feature.
* Necessary for feature influence.
* @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0).
* For regression this is continuous.
*/
InferenceResults infer(Map<String, Object> fields, InferenceConfig config, @Nullable Map<String, String> featureDecoderMap);
/**
* @return {@link TargetType} for the model.
*/
@ -49,21 +32,6 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
*/
long estimatedNumOperations();
/**
* @return Does the model support feature importance
*/
boolean supportsFeatureImportance();
/**
* Calculates the importance of each feature reference by the model for the passed in field values
*
* NOTE: Must be thread safe
* @param fields The fields inferring against
* @param featureDecoder A Map translating processed feature names to their original feature names
* @return A {@code Map<String, double[]>} mapping each featureName to its importance
*/
Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
default Version getMinimalCompatibilityVersion() {
return Version.V_7_6_0;
}

View File

@ -11,21 +11,12 @@ import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
@ -37,16 +28,10 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
@ -134,70 +119,11 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
}
}
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
double[][] inferenceResults = new double[this.models.size()][];
List<Map<String, double[]>> featureInfluence = new ArrayList<>();
int i = 0;
NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
for (TrainedModel model : models) {
InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
assert result instanceof RawInferenceResults;
RawInferenceResults inferenceResult = (RawInferenceResults) result;
inferenceResults[i++] = inferenceResult.getValue();
if (config.requestingImportance()) {
featureInfluence.add(inferenceResult.getFeatureImportance());
}
}
double[] processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed,
decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
config);
}
@Override
public TargetType targetType() {
return targetType;
}
private InferenceResults buildResults(double[] processedInferences,
Map<String, double[]> featureInfluence,
InferenceConfig config) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(
new double[] {outputAggregator.aggregate(processedInferences)},
featureInfluence);
}
switch(targetType) {
case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
config,
transformFeatureImportance(featureInfluence, null));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
// Adjust the probabilities according to the thresholds
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
processedInferences,
classificationLabels,
classificationWeights,
classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType());
return new ClassificationInferenceResults((double)topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
transformFeatureImportance(featureInfluence, classificationLabels),
config);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
}
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -313,30 +239,6 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1);
}
@Override
public boolean supportsFeatureImportance() {
return models.stream().allMatch(TrainedModel::supportsFeatureImportance);
}
Map<String, double[]> featureImportance(Map<String, Object> fields) {
return featureImportance(fields, Collections.emptyMap());
}
@Override
public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
Map<String, double[]> collapsed = mergeFeatureImportances(models.stream()
.map(trainedModel -> trainedModel.featureImportance(fields, Collections.emptyMap()))
.collect(Collectors.toList()));
return decodeFeatureImportances(featureDecoder, collapsed);
}
private static Map<String, double[]> mergeFeatureImportances(List<Map<String, double[]>> featureImportances) {
return featureImportances.stream()
.collect(HashMap::new,
(a, b) -> b.forEach((k, v) -> a.merge(k, v, InferenceHelpers::sumDoubleArrays)),
Map::putAll);
}
public static Builder builder() {
return new Builder();
}

View File

@ -0,0 +1,278 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.FEATURE_NAMES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TARGET_TYPE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;
public class EnsembleInferenceModel implements InferenceModel {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
"ensemble_inference_model",
true,
a -> new EnsembleInferenceModel((List<String>)a[0],
(List<InferenceModel>)a[1],
(OutputAggregator)a[2],
TargetType.fromString((String)a[3]),
(List<String>)a[4],
(List<Double>)a[5]));
static {
PARSER.declareStringArray(constructorArg(), FEATURE_NAMES);
PARSER.declareNamedObjects(constructorArg(),
(p, c, n) -> p.namedObject(InferenceModel.class, n, null),
(ensembleBuilder) -> {},
TRAINED_MODELS);
PARSER.declareNamedObject(constructorArg(),
(p, c, n) -> p.namedObject(LenientlyParsedOutputAggregator.class, n, null),
AGGREGATE_OUTPUT);
PARSER.declareString(constructorArg(), TARGET_TYPE);
PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS);
PARSER.declareDoubleArray(optionalConstructorArg(), CLASSIFICATION_WEIGHTS);
}
public static EnsembleInferenceModel fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private String[] featureNames;
private final List<InferenceModel> models;
private final OutputAggregator outputAggregator;
private final TargetType targetType;
private final List<String> classificationLabels;
private final double[] classificationWeights;
EnsembleInferenceModel(List<String> featureNames,
List<InferenceModel> models,
OutputAggregator outputAggregator,
TargetType targetType,
List<String> classificationLabels,
List<Double> classificationWeights) {
this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
this.models = ExceptionsHelper.requireNonNull(models, TRAINED_MODELS);
this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : classificationLabels;
this.classificationWeights = classificationWeights == null ?
null :
classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
}
@Override
public String[] getFeatureNames() {
return featureNames;
}
@Override
public TargetType targetType() {
return targetType;
}
private double[] getFeatures(Map<String, Object> fields) {
double[] features = new double[featureNames.length];
int i = 0;
for (String featureName : featureNames) {
Double val = InferenceHelpers.toDouble(fields.get(featureName));
features[i++] = val == null ? Double.NaN : val;
}
return features;
}
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
return innerInfer(getFeatures(fields), config, featureDecoderMap);
}
@Override
public InferenceResults infer(double[] features, InferenceConfig config) {
return innerInfer(features, config, Collections.emptyMap());
}
private InferenceResults innerInfer(double[] features, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
double[][] inferenceResults = new double[this.models.size()][];
double[][] featureInfluence = new double[features.length][];
int i = 0;
NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
for (InferenceModel model : models) {
InferenceResults result = model.infer(features, subModelInferenceConfig);
assert result instanceof RawInferenceResults;
RawInferenceResults inferenceResult = (RawInferenceResults) result;
inferenceResults[i++] = inferenceResult.getValue();
if (config.requestingImportance()) {
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
assert modelFeatureImportance.length == featureInfluence.length;
for (int j = 0; j < modelFeatureImportance.length; j++) {
if (featureInfluence[j] == null) {
featureInfluence[j] = new double[modelFeatureImportance[j].length];
}
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
}
}
}
double[] processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed, featureInfluence, featureDecoderMap, config);
}
//For testing
double[][] featureImportance(double[] features) {
double[][] featureInfluence = new double[features.length][];
NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(true);
for (InferenceModel model : models) {
InferenceResults result = model.infer(features, subModelInferenceConfig);
assert result instanceof RawInferenceResults;
RawInferenceResults inferenceResult = (RawInferenceResults) result;
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
assert modelFeatureImportance.length == featureInfluence.length;
for (int j = 0; j < modelFeatureImportance.length; j++) {
if (featureInfluence[j] == null) {
featureInfluence[j] = new double[modelFeatureImportance[j].length];
}
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
}
}
return featureInfluence;
}
private InferenceResults buildResults(double[] processedInferences,
double[][] featureImportance,
Map<String, String> featureDecoderMap,
InferenceConfig config) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(
new double[] {outputAggregator.aggregate(processedInferences)},
featureImportance);
}
Map<String, double[]> decodedFeatureImportance = config.requestingImportance() ?
decodeFeatureImportances(featureDecoderMap,
IntStream.range(0, featureImportance.length)
.boxed()
.collect(Collectors.toMap(i -> featureNames[i], i -> featureImportance[i]))) :
Collections.emptyMap();
switch(targetType) {
case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
config,
transformFeatureImportance(decodedFeatureImportance, null));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
// Adjust the probabilities according to the thresholds
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
processedInferences,
classificationLabels,
classificationWeights,
classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType());
return new ClassificationInferenceResults((double)topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
transformFeatureImportance(decodedFeatureImportance, classificationLabels),
config);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
}
}
@Override
public boolean supportsFeatureImportance() {
return models.stream().allMatch(InferenceModel::supportsFeatureImportance);
}
@Override
public String getName() {
return "ensemble";
}
@Override
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
Set<String> referencedFeatures = subModelFeatures();
int newFeatureIndex = 0;
newFeatureIndexMapping = new HashMap<>();
this.featureNames = new String[referencedFeatures.size()];
for (String featureName : referencedFeatures) {
newFeatureIndexMapping.put(featureName, newFeatureIndex);
this.featureNames[newFeatureIndex++] = featureName;
}
} else {
this.featureNames = new String[0];
}
for (InferenceModel model : models) {
model.rewriteFeatureIndices(newFeatureIndexMapping);
}
}
private Set<String> subModelFeatures() {
Set<String> referencedFeatures = new LinkedHashSet<>();
for (InferenceModel model : models) {
if (model instanceof EnsembleInferenceModel) {
referencedFeatures.addAll(((EnsembleInferenceModel) model).subModelFeatures());
} else {
for (String featureName : model.getFeatureNames()) {
referencedFeatures.add(featureName);
}
}
}
return referencedFeatures;
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(featureNames);
size += RamUsageEstimator.sizeOfCollection(classificationLabels);
size += RamUsageEstimator.sizeOfCollection(models);
if (classificationWeights != null) {
size += RamUsageEstimator.sizeOf(classificationWeights);
}
size += outputAggregator.ramBytesUsed();
return size;
}
}

View File

@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition.PREPROCESSORS;
import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition.TRAINED_MODEL;
public class InferenceDefinition {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InferenceDefinition.class);
public static final String NAME = "inference_model_definition";
private final InferenceModel trainedModel;
private final List<PreProcessor> preProcessors;
private Map<String, String> decoderMap;
private static final ObjectParser<InferenceDefinition.Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
InferenceDefinition.Builder::new);
static {
PARSER.declareNamedObject(InferenceDefinition.Builder::setTrainedModel,
(p, c, n) -> p.namedObject(InferenceModel.class, n, null),
TRAINED_MODEL);
PARSER.declareNamedObjects(InferenceDefinition.Builder::setPreProcessors,
(p, c, n) -> p.namedObject(LenientlyParsedPreProcessor.class, n, null),
(trainedModelDefBuilder) -> {},
PREPROCESSORS);
}
public static InferenceDefinition fromXContent(XContentParser parser) {
return PARSER.apply(parser, null).build();
}
public InferenceDefinition(InferenceModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
}
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(trainedModel);
size += RamUsageEstimator.sizeOfCollection(preProcessors);
return size;
}
InferenceModel getTrainedModel() {
return trainedModel;
}
private void preProcess(Map<String, Object> fields) {
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
}
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields);
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
throw ExceptionsHelper.badRequestException(
"Feature importance is not supported for the configured model of type [{}]",
trainedModel.getName());
}
return trainedModel.infer(fields,
config,
config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
}
public TargetType getTargetType() {
return this.trainedModel.targetType();
}
private Map<String, String> getDecoderMap() {
if (decoderMap != null) {
return decoderMap;
}
synchronized (this) {
if (decoderMap != null) {
return decoderMap;
}
this.decoderMap = preProcessors.stream()
.map(PreProcessor::reverseLookup)
.collect(HashMap::new, Map::putAll, Map::putAll);
return decoderMap;
}
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private List<PreProcessor> preProcessors;
private InferenceModel inferenceModel;
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
this.preProcessors = preProcessors;
return this;
}
public Builder setTrainedModel(InferenceModel trainedModel) {
this.inferenceModel = trainedModel;
return this;
}
public InferenceDefinition build() {
this.inferenceModel.rewriteFeatureIndices(Collections.emptyMap());
return new InferenceDefinition(this.inferenceModel, this.preProcessors);
}
}
}

View File

@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.util.Map;
public interface InferenceModel extends Accountable {
/**
* @return The feature names in their desired order
*/
String[] getFeatureNames();
/**
* @return {@link TargetType} for the model.
*/
TargetType targetType();
/**
* Infer against the provided fields
*
* @param fields The fields and their values to infer against
* @param config The configuration options for inference
* @param featureDecoderMap A map for decoding feature value names to their originating feature.
* Necessary for feature influence.
* @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0).
* For regression this is continuous.
*/
InferenceResults infer(Map<String, Object> fields, InferenceConfig config, @Nullable Map<String, String> featureDecoderMap);
/**
* Same as {@link InferenceModel#infer(Map, InferenceConfig, Map)} but the features are already extracted.
*/
InferenceResults infer(double[] features, InferenceConfig config);
/**
* @return Does the model support feature importance
*/
boolean supportsFeatureImportance();
String getName();
/**
* Rewrites underlying feature index mappings.
* This is to allow optimization of the underlying models.
*/
void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping);
}

View File

@ -0,0 +1,521 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.CLASSIFICATION_LABELS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.FEATURE_NAMES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.TARGET_TYPE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.TREE_STRUCTURE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.DECISION_TYPE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.DEFAULT_LEFT;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.LEAF_VALUE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.LEFT_CHILD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.NUMBER_SAMPLES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.RIGHT_CHILD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.SPLIT_FEATURE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.THRESHOLD;
public class TreeInferenceModel implements InferenceModel {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TreeInferenceModel.class);
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
"tree_inference_model",
true,
a -> new TreeInferenceModel((List<String>)a[0], (List<NodeBuilder>)a[1], TargetType.fromString((String)a[2]), (List<String>)a[3]));
static {
PARSER.declareStringArray(constructorArg(), FEATURE_NAMES);
PARSER.declareObjectArray(constructorArg(), NodeBuilder.PARSER::apply, TREE_STRUCTURE);
PARSER.declareString(constructorArg(), TARGET_TYPE);
PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS);
}
public static TreeInferenceModel fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final Node[] nodes;
private String[] featureNames;
private final TargetType targetType;
private final List<String> classificationLabels;
private final double highOrderCategory;
private final int maxDepth;
private final int leafSize;
TreeInferenceModel(List<String> featureNames, List<NodeBuilder> nodes, TargetType targetType, List<String> classificationLabels) {
this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) {
throw new IllegalArgumentException("[tree_structure] must not be empty");
}
this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
this.highOrderCategory = maxLeafValue();
int leafSize = 1;
for (Node node : this.nodes) {
if (node instanceof LeafNode) {
leafSize = ((LeafNode)node).leafValue.length;
break;
}
}
this.leafSize = leafSize;
this.maxDepth = getDepth(this.nodes, 0, 0);
}
@Override
public String[] getFeatureNames() {
return featureNames;
}
@Override
public TargetType targetType() {
return targetType;
}
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
return innerInfer(getFeatures(fields), config, featureDecoderMap);
}
@Override
public InferenceResults infer(double[] features, InferenceConfig config) {
return innerInfer(features, config, Collections.emptyMap());
}
private double[] getFeatures(Map<String, Object> fields) {
double[] features = new double[featureNames.length];
int i = 0;
for (String featureName : featureNames) {
Double val = InferenceHelpers.toDouble(fields.get(featureName));
features[i++] = val == null ? Double.NaN : val;
}
return features;
}
private InferenceResults innerInfer(double[] features, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
double[][] featureImportance = config.requestingImportance() ?
featureImportance(features) :
new double[0][];
return buildResult(getLeaf(features), featureImportance, featureDecoderMap, config);
}
private InferenceResults buildResult(double[] value,
double[][] featureImportance,
Map<String, String> featureDecoderMap,
InferenceConfig config) {
assert value != null && value.length > 0;
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(value, featureImportance);
}
Map<String, double[]> decodedFeatureImportance = config.requestingImportance() ?
decodeFeatureImportances(featureDecoderMap,
IntStream.range(0, featureImportance.length)
.boxed()
.collect(Collectors.toMap(i -> featureNames[i], i -> featureImportance[i]))) :
Collections.emptyMap();
switch (targetType) {
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
classificationProbability(value),
classificationLabels,
null,
classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType());
return new ClassificationInferenceResults(topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels),
config);
case REGRESSION:
return new RegressionInferenceResults(value[0],
config,
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, null));
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}
}
private double[] classificationProbability(double[] inferenceValue) {
// Multi-value leaves, indicates that the leaves contain an array of values.
// The index of which corresponds to classification values
if (inferenceValue.length > 1) {
return Statistics.softMax(inferenceValue);
}
// If we are classification, we should assume that the inference return value is whole.
assert inferenceValue[0] == Math.rint(inferenceValue[0]);
double maxCategory = this.highOrderCategory;
// If we are classification, we should assume that the largest leaf value is whole.
assert maxCategory == Math.rint(maxCategory);
double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
.stream()
.mapToDouble(Double::doubleValue)
.toArray();
list[Double.valueOf(inferenceValue[0]).intValue()] = 1.0;
return list;
}
private double[] getLeaf(double[] features) {
Node node = nodes[0];
while(node.isLeaf() == false) {
node = nodes[node.compare(features)];
}
return ((LeafNode)node).leafValue;
}
public double[][] featureImportance(double[] fieldValues) {
double[][] featureImportance = new double[fieldValues.length][leafSize];
for (int i = 0; i < fieldValues.length; i++) {
featureImportance[i] = new double[leafSize];
}
int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
for (int i = 0; i < arrSize; i++) {
elements[i] = new ShapPath.PathElement();
}
double[] scale = new double[arrSize];
ShapPath initialPath = new ShapPath(elements, scale);
shapRecursive(fieldValues, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
return featureImportance;
}
/**
* Note, this is a port from https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc
*
* If improvements in performance or accuracy have been found, it is probably best that the changes are implemented on the native
* side first and then ported to the Java side.
*/
private void shapRecursive(double[] processedFeatures,
ShapPath parentSplitPath,
int nodeIndex,
double parentFractionZero,
double parentFractionOne,
int parentFeatureIndex,
double[][] featureImportance,
int nextIndex) {
ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
Node currNode = nodes[nodeIndex];
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
if (currNode.isLeaf()) {
double[] leafValue = ((LeafNode)currNode).leafValue;
for (int i = 1; i < nextIndex; ++i) {
int inputColumnIndex = splitPath.featureIndex(i);
double scaled = splitPath.sumUnwoundPath(i, nextIndex) * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i));
for (int j = 0; j < leafValue.length; j++) {
featureImportance[inputColumnIndex][j] += scaled * leafValue[j];
}
}
} else {
InnerNode innerNode = (InnerNode)currNode;
int hotIndex = currNode.compare(processedFeatures);
int coldIndex = hotIndex == innerNode.leftChild ? innerNode.rightChild : innerNode.leftChild;
double incomingFractionZero = 1.0;
double incomingFractionOne = 1.0;
int splitFeature = innerNode.splitFeature;
int pathIndex = splitPath.findFeatureIndex(splitFeature, nextIndex);
if (pathIndex > -1) {
incomingFractionZero = splitPath.fractionZeros(pathIndex);
incomingFractionOne = splitPath.fractionOnes(pathIndex);
nextIndex = splitPath.unwind(pathIndex, nextIndex);
}
double hotFractionZero = nodes[hotIndex].getNumberSamples() / (double)currNode.getNumberSamples();
double coldFractionZero = nodes[coldIndex].getNumberSamples() / (double)currNode.getNumberSamples();
shapRecursive(processedFeatures, splitPath,
hotIndex, incomingFractionZero * hotFractionZero,
incomingFractionOne, splitFeature, featureImportance, nextIndex);
shapRecursive(processedFeatures, splitPath,
coldIndex, incomingFractionZero * coldFractionZero,
0.0, splitFeature, featureImportance, nextIndex);
}
}
@Override
public boolean supportsFeatureImportance() {
return true;
}
@Override
public String getName() {
return "tree";
}
@Override
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
return;
}
for (Node node : nodes) {
if (node.isLeaf()) {
continue;
}
InnerNode treeNode = (InnerNode)node;
Integer newSplitFeatureIndex = newFeatureIndexMapping.get(featureNames[treeNode.splitFeature]);
if (newSplitFeatureIndex == null) {
throw new IllegalArgumentException("[tree] failed to optimize for inference");
}
treeNode.splitFeature = newSplitFeatureIndex;
}
this.featureNames = new String[0];
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOfCollection(classificationLabels);
size += RamUsageEstimator.sizeOf(featureNames);
size += RamUsageEstimator.shallowSizeOf(nodes);
for (Node node : nodes) {
size += node.ramBytesUsed();
}
size += RamUsageEstimator.sizeOfCollection(Arrays.asList(nodes));
return size;
}
private double maxLeafValue() {
if (targetType != TargetType.CLASSIFICATION) {
return Double.NaN;
}
double max = 0.0;
for (Node node : this.nodes) {
if (node instanceof LeafNode) {
LeafNode leafNode = (LeafNode) node;
if (leafNode.leafValue.length > 1) {
return (double)leafNode.leafValue.length;
} else {
max = Math.max(leafNode.leafValue[0], max);
}
}
}
return max;
}
private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
Node node = nodes[nodeIndex];
if (node instanceof LeafNode) {
return 0;
}
InnerNode innerNode = (InnerNode)node;
int depthLeft = getDepth(nodes, innerNode.leftChild, depth + 1);
int depthRight = getDepth(nodes, innerNode.rightChild, depth + 1);
return Math.max(depthLeft, depthRight) + 1;
}
private static class NodeBuilder {
private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(
"tree_inference_model_node",
true,
NodeBuilder::new);
static {
PARSER.declareDouble(NodeBuilder::setThreshold, THRESHOLD);
PARSER.declareField(NodeBuilder::setOperator,
p -> Operator.fromString(p.text()),
DECISION_TYPE,
ObjectParser.ValueType.STRING);
PARSER.declareInt(NodeBuilder::setLeftChild, LEFT_CHILD);
PARSER.declareInt(NodeBuilder::setRightChild, RIGHT_CHILD);
PARSER.declareBoolean(NodeBuilder::setDefaultLeft, DEFAULT_LEFT);
PARSER.declareInt(NodeBuilder::setSplitFeature, SPLIT_FEATURE);
PARSER.declareDoubleArray(NodeBuilder::setLeafValue, LEAF_VALUE);
PARSER.declareLong(NodeBuilder::setNumberSamples, NUMBER_SAMPLES);
}
private Operator operator = Operator.LTE;
private double threshold = Double.NaN;
private int splitFeature = -1;
private boolean defaultLeft = false;
private int leftChild = -1;
private int rightChild = -1;
private long numberSamples;
private double[] leafValue = new double[0];
public NodeBuilder setOperator(Operator operator) {
this.operator = operator;
return this;
}
public NodeBuilder setThreshold(double threshold) {
this.threshold = threshold;
return this;
}
public NodeBuilder setSplitFeature(int splitFeature) {
this.splitFeature = splitFeature;
return this;
}
public NodeBuilder setDefaultLeft(boolean defaultLeft) {
this.defaultLeft = defaultLeft;
return this;
}
public NodeBuilder setLeftChild(int leftChild) {
this.leftChild = leftChild;
return this;
}
public NodeBuilder setRightChild(int rightChild) {
this.rightChild = rightChild;
return this;
}
public NodeBuilder setNumberSamples(long numberSamples) {
this.numberSamples = numberSamples;
return this;
}
private NodeBuilder setLeafValue(List<Double> leafValue) {
return setLeafValue(leafValue.stream().mapToDouble(Double::doubleValue).toArray());
}
public NodeBuilder setLeafValue(double[] leafValue) {
this.leafValue = leafValue;
return this;
}
Node build() {
if (this.leftChild < 0) {
return new LeafNode(leafValue, numberSamples);
}
return new InnerNode(operator,
threshold,
splitFeature,
defaultLeft,
leftChild,
rightChild,
numberSamples);
}
}
private abstract static class Node implements Accountable {
int compare(double[] features) {
throw new IllegalArgumentException("cannot call compare against a leaf node.");
}
abstract long getNumberSamples();
boolean isLeaf() {
return this instanceof LeafNode;
}
}
private static class InnerNode extends Node {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InnerNode.class);
private final Operator operator;
private final double threshold;
// Allowed to be adjusted for inference optimization
private int splitFeature;
private final boolean defaultLeft;
private final int leftChild;
private final int rightChild;
private final long numberSamples;
InnerNode(Operator operator,
double threshold,
int splitFeature,
boolean defaultLeft,
int leftChild,
int rightChild,
long numberSamples) {
this.operator = operator;
this.threshold = threshold;
this.splitFeature = splitFeature;
this.defaultLeft = defaultLeft;
this.leftChild = leftChild;
this.rightChild = rightChild;
this.numberSamples = numberSamples;
}
@Override
public int compare(double[] features) {
double feature = features[splitFeature];
if (isMissing(feature)) {
return defaultLeft ? leftChild : rightChild;
}
return operator.test(feature, threshold) ? leftChild : rightChild;
}
@Override
long getNumberSamples() {
return numberSamples;
}
private static boolean isMissing(double feature) {
return Numbers.isValidDouble(feature) == false;
}
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE;
}
}
private static class LeafNode extends Node {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LeafNode.class);
private final double[] leafValue;
private final long numberSamples;
LeafNode(double[] leafValue, long numberSamples) {
this.leafValue = leafValue;
this.numberSamples = numberSamples;
}
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE;
}
@Override
long getNumberSamples() {
return numberSamples;
}
}
}

View File

@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTra
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -35,7 +36,7 @@ import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel {
public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel, InferenceModel {
public static final ParseField NAME = new ParseField("lang_ident_neural_network");
public static final ParseField EMBEDDED_VECTOR_FEATURE_NAME = new ParseField("embedded_vector_feature_name");
@ -148,6 +149,23 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
classificationConfig);
}
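// lang_ident consumes the raw embedded text vector directly, so it is never nested inside an ensemble;
// the double[]-based inference entry point and non-trivial feature index rewriting are therefore unsupported.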
@Override
public InferenceResults infer(double[] embeddedVector, InferenceConfig config) {
throw new UnsupportedOperationException("[lang_ident] does not support nested inference");
}
@Override
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
if (newFeatureIndexMapping != null && newFeatureIndexMapping.isEmpty() == false) {
throw new UnsupportedOperationException("[lang_ident] does not support nested inference");
}
}
@Override
public String[] getFeatureNames() {
return new String[] {embeddedVectorFeatureName};
}
@Override
public TargetType targetType() {
return TargetType.CLASSIFICATION;
@ -171,11 +189,6 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
return false;
}
@Override
public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
throw new UnsupportedOperationException("[lang_ident] does not support feature importance");
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;

View File

@ -11,28 +11,15 @@ import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.io.IOException;
import java.util.ArrayDeque;
@ -42,14 +29,10 @@ import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel, Accountable {
@ -89,10 +72,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
private final List<TreeNode> nodes;
private final TargetType targetType;
private final List<String> classificationLabels;
private final CachedSupplier<Double> highestOrderCategory;
// populated lazily when feature importance is calculated
private Integer maxDepth;
private Integer leafSize;
Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
@ -102,7 +81,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
this.nodes = Collections.unmodifiableList(nodes);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
}
public Tree(StreamInput in) throws IOException {
@ -114,7 +92,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
} else {
this.classificationLabels = null;
}
this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
}
@Override
@ -122,102 +99,11 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return NAME.getPreferredName();
}
public List<TreeNode> getNodes() {
return nodes;
}
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
List<Double> features = featureNames.stream()
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
.collect(Collectors.toList());
Map<String, double[]> featureImportance = config.requestingImportance() ?
featureImportance(features, featureDecoderMap) :
Collections.emptyMap();
TreeNode node = nodes.get(0);
while(node.isLeaf() == false) {
node = nodes.get(node.compare(features));
}
return buildResult(node.getLeafValue(), featureImportance, config);
}
private InferenceResults buildResult(double[] value, Map<String, double[]> featureImportance, InferenceConfig config) {
assert value != null && value.length > 0;
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(value, featureImportance);
}
switch (targetType) {
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
classificationProbability(value),
classificationLabels,
null,
classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType());
return new ClassificationInferenceResults(topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
InferenceHelpers.transformFeatureImportance(featureImportance, classificationLabels),
config);
case REGRESSION:
return new RegressionInferenceResults(value[0],
config,
InferenceHelpers.transformFeatureImportance(featureImportance, null));
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}
}
/**
* Trace the route that prediction on the feature vector takes.
* @param features The feature vector
* @return The list of traversed nodes ordered from root to leaf
*/
public List<TreeNode> trace(List<Double> features) {
List<TreeNode> visited = new ArrayList<>();
TreeNode node = nodes.get(0);
visited.add(node);
while(node.isLeaf() == false) {
node = nodes.get(node.compare(features));
visited.add(node);
}
return visited;
}
@Override
public TargetType targetType() {
return targetType;
}
private double[] classificationProbability(double[] inferenceValue) {
// Multi-value leaves indicate that each leaf contains an array of values,
// the index of which corresponds to the classification value
if (inferenceValue.length > 1) {
return Statistics.softMax(inferenceValue);
}
// For classification we should assume that the inference return value is whole.
assert inferenceValue[0] == Math.rint(inferenceValue[0]);
double maxCategory = this.highestOrderCategory.get();
// For classification we should assume that the largest leaf value is whole.
assert maxCategory == Math.rint(maxCategory);
double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
.stream()
.mapToDouble(Double::doubleValue)
.toArray();
list[Double.valueOf(inferenceValue[0]).intValue()] = 1.0;
return list;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -285,131 +171,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
verifyLeafNodeUniformity();
}
@Override
public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) {
throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance");
}
List<Double> features = featureNames.stream()
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
.collect(Collectors.toList());
return featureImportance(features, featureDecoder);
}
private Map<String, double[]> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
calculateDepthAndLeafValueSize();
double[][] featureImportance = new double[fieldValues.size()][leafSize];
for (int i = 0; i < fieldValues.size(); i++) {
featureImportance[i] = new double[leafSize];
}
int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
for (int i = 0; i < arrSize; i++) {
elements[i] = new ShapPath.PathElement();
}
double[] scale = new double[arrSize];
ShapPath initialPath = new ShapPath(elements, scale);
shapRecursive(fieldValues, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
return InferenceHelpers.decodeFeatureImportances(featureDecoder,
IntStream.range(0, featureImportance.length)
.boxed()
.collect(Collectors.toMap(featureNames::get, i -> featureImportance[i])));
}
private void calculateDepthAndLeafValueSize() {
if (this.maxDepth != null && this.leafSize != null) {
return;
}
synchronized (this) {
if (this.maxDepth != null && this.leafSize != null) {
return;
}
this.maxDepth = getDepth(0, 0);
}
}
/**
* Note, this is a port from https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc
*
* If improvements in performance or accuracy have been found, it is probably best that the changes are implemented on the native
* side first and then ported to the Java side.
*/
private void shapRecursive(List<Double> processedFeatures,
ShapPath parentSplitPath,
int nodeIndex,
double parentFractionZero,
double parentFractionOne,
int parentFeatureIndex,
double[][] featureImportance,
int nextIndex) {
ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
TreeNode currNode = nodes.get(nodeIndex);
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
if (currNode.isLeaf()) {
double[] leafValue = currNode.getLeafValue();
for (int i = 1; i < nextIndex; ++i) {
int inputColumnIndex = splitPath.featureIndex(i);
double scaled = splitPath.sumUnwoundPath(i, nextIndex) * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i));
for (int j = 0; j < leafValue.length; j++) {
featureImportance[inputColumnIndex][j] += scaled * leafValue[j];
}
}
} else {
int hotIndex = currNode.compare(processedFeatures);
int coldIndex = hotIndex == currNode.getLeftChild() ? currNode.getRightChild() : currNode.getLeftChild();
double incomingFractionZero = 1.0;
double incomingFractionOne = 1.0;
int splitFeature = currNode.getSplitFeature();
int pathIndex = splitPath.findFeatureIndex(splitFeature, nextIndex);
if (pathIndex > -1) {
incomingFractionZero = splitPath.fractionZeros(pathIndex);
incomingFractionOne = splitPath.fractionOnes(pathIndex);
nextIndex = splitPath.unwind(pathIndex, nextIndex);
}
double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples();
double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples();
shapRecursive(processedFeatures, splitPath,
hotIndex, incomingFractionZero * hotFractionZero,
incomingFractionOne, splitFeature, featureImportance, nextIndex);
shapRecursive(processedFeatures, splitPath,
coldIndex, incomingFractionZero * coldFractionZero,
0.0, splitFeature, featureImportance, nextIndex);
}
}
/**
* Gets the depth of the tree and sets leafSize if it is null
*
* @param nodeIndex Current node index
* @param depth Current depth
* @return The current max depth
*/
private int getDepth(int nodeIndex, int depth) {
TreeNode node = nodes.get(nodeIndex);
if (node.isLeaf()) {
if (leafSize == null) {
this.leafSize = node.getLeafValue().length;
}
return 0;
}
int depthLeft = getDepth(node.getLeftChild(), depth + 1);
int depthRight = getDepth(node.getRightChild(), depth + 1);
return Math.max(depthLeft, depthRight) + 1;
}
@Override
public long estimatedNumOperations() {
// Grabbing the features from the doc + the depth of the tree
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
}
@Override
public boolean supportsFeatureImportance() {
return true;
}
/**
* The highest index of a feature used in any of the nodes.
* If no nodes use a feature return -1. This can only happen
@ -495,23 +262,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return nodeIdx >= nodes.size();
}
private Double maxLeafValue() {
if (targetType != TargetType.CLASSIFICATION) {
return null;
}
double max = 0.0;
for (TreeNode node : this.nodes) {
if (node.isLeaf()) {
if (node.getLeafValue().length > 1) {
return (double)node.getLeafValue().length;
} else {
max = Math.max(node.getLeafValue()[0], max);
}
}
}
return max;
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
@ -600,7 +350,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
* @param decisionThreshold The decision threshold
* @return The created node
*/
TreeNode.Builder addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) {
public TreeNode.Builder addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) {
int leftChild = numNodes++;
int rightChild = numNodes++;
nodes.ensureCapacity(nodeIndex + 1);
@ -630,11 +380,11 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
* @param value The prediction value
* @return this
*/
Tree.Builder addLeaf(int nodeIndex, double value) {
public Tree.Builder addLeaf(int nodeIndex, double value) {
return addLeaf(nodeIndex, Arrays.asList(value));
}
Tree.Builder addLeaf(int nodeIndex, List<Double> value) {
public Tree.Builder addLeaf(int nodeIndex, List<Double> value) {
for (int i = nodes.size(); i < nodeIndex + 1; i++) {
nodes.add(null);
}

View File

@ -129,7 +129,6 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
}
}
public Operator getOperator() {
return operator;
}
@ -174,21 +173,6 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
return numberSamples;
}
public int compare(List<Double> features) {
if (isLeaf()) {
throw new IllegalArgumentException("cannot call compare against a leaf node.");
}
Double feature = features.get(splitFeature);
if (isMissing(feature)) {
return defaultLeft ? leftChild : rightChild;
}
return operator.test(feature, threshold) ? leftChild : rightChild;
}
private boolean isMissing(Double feature) {
return feature == null;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
operator.writeTo(out);
@ -359,7 +343,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
return this;
}
Integer getLeftChild() {
public Integer getLeftChild() {
return leftChild;
}
@ -368,7 +352,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
return this;
}
Integer getRightChild() {
public Integer getRightChild() {
return rightChild;
}

View File

@ -7,8 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.utils;
import org.elasticsearch.common.Numbers;
import java.util.Arrays;
public final class Statistics {
private Statistics(){}
@ -23,7 +21,12 @@ public final class Statistics {
*/
public static double[] softMax(double[] values) {
double expSum = 0.0;
double max = Arrays.stream(values).filter(Statistics::isValid).max().orElse(Double.NaN);
double max = Double.NEGATIVE_INFINITY;
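// A primitive loop replaces the previous boxed stream: softMax runs once per document during inference,
// so avoiding Stream/autoboxing overhead on this hot path is worthwhile.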
for (double val : values) {
if (isValid(val)) {
max = Math.max(max, val);
}
}
if (isValid(max) == false) {
throw new IllegalArgumentException("no valid values present");
}

View File

@ -6,8 +6,10 @@
package org.elasticsearch.xpack.core.ml.utils;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Stack;
@ -57,25 +59,51 @@ public final class MapHelper {
*
* Instead we lazily create potential paths once we know that they are possibilities.
*
* @param path Dot delimited path containing the field desired
* @param path Dot delimited path containing the field desired. Assumes that the path contains no empty strings
* @param map The {@link Map} map to dig
* @return The found object. Returns {@code null} if not found
*/
@Nullable
public static Object dig(String path, Map<String, Object> map) {
// short cut before search
if (map.keySet().contains(path)) {
return map.get(path);
}
String[] fields = path.split("\\.");
if (Arrays.stream(fields).anyMatch(String::isEmpty)) {
throw new IllegalArgumentException("Empty path detected. Invalid field name");
Object obj = map.get(path);
if (obj != null) {
return obj;
}
String[] fields = Strings.delimitedListToStringArray(path, ".");
Stack<PotentialPath> pathStack = new Stack<>();
pathStack.push(new PotentialPath(map, 0));
return explore(fields, pathStack);
}
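// Illustrative sketch (not part of the original change): a hypothetical document showing how dig is
// expected to resolve a dotted path even when part of it is stored as a flattened key.
private static Object digExample() {
Map<String, Object> child = new HashMap<>();
child.put("b.c", 2);      // partially flattened key
Map<String, Object> doc = new HashMap<>();
doc.put("a", child);      // {"a": {"b.c": 2}}
return dig("a.b.c", doc); // expected: 2; dig("a.b.x", doc) would return null
}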
/**
* Collapses dot delimited fields so that the map is a single layer.
*
* Example:
* {
* "a" :{"b": {"c": {"d" : 2}}}
* }
* becomes:
* {
* "a.b.c.d": 2
* }
*
* @param map The map that has nested and/or collapsed paths
* @param pathsToCollapse The desired paths to collapse
* @return A fully collapsed map
*/
public static Map<String, Object> dotCollapse(Map<String, Object> map, Collection<String> pathsToCollapse) {
// default load factor is 0.75 (3/4).
Map<String, Object> collapsed = new HashMap<>(((pathsToCollapse.size() * 4)/3) + 1);
for (String path : pathsToCollapse) {
Object dug = dig(path, map);
if (dug != null) {
collapsed.put(path, dug);
}
}
return collapsed;
}
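// Illustrative sketch (not part of the original change), mirroring the Javadoc example above; the sample
// paths and values are hypothetical.
private static Map<String, Object> dotCollapseExample() {
Map<String, Object> inner = new HashMap<>();
inner.put("c", 2);
Map<String, Object> doc = new HashMap<>();
doc.put("a", inner); // {"a": {"c": 2}}
return dotCollapse(doc, Arrays.asList("a.c", "does.not.exist")); // expected: {"a.c"=2}; unresolved paths are skipped
}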
@SuppressWarnings("unchecked")
private static Object explore(String[] path, Stack<PotentialPath> pathStack) {
while (pathStack.empty() == false) {
@ -95,8 +123,11 @@ public final class MapHelper {
}
endPos++;
}
if (candidateKey != null && map.containsKey(candidateKey)) { //exit early
return map.get(candidateKey);
if (candidateKey != null) { // exit early
Object val = map.get(candidateKey);
if (val != null) {
return val;
}
}
}

View File

@ -18,8 +18,6 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
@ -29,9 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -79,7 +75,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
return createRandomBuilder(randomFrom(TargetType.values()));
}
private static final String ENSEMBLE_MODEL = "" +
public static final String ENSEMBLE_MODEL = "" +
"{\n" +
" \"preprocessors\": [\n" +
" {\n" +
@ -199,7 +195,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
" }\n" +
"}";
private static final String TREE_MODEL = "" +
public static final String TREE_MODEL = "" +
"{\n" +
" \"preprocessors\": [\n" +
" {\n" +
@ -310,67 +306,4 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
TrainedModelDefinition test = createTestInstance();
assertThat(test.ramBytesUsed(), greaterThan(0L));
}
public void testMultiClassIrisInference() throws IOException {
// Fairly simple, random forest classification model built to fit in our format
// Trained on the well known Iris dataset
String compressedDef = "H4sIAPbiMl4C/+1b246bMBD9lVWet8jjG3b/oN9QVYgmToLEkghIL6r23wukl90" +
"YxRMGlt2WPKwEC/gYe2bOnBl+rOoyzQq3SR4OG5ev3t/9WLmicg+fc9cd1Gm5c3VSfz+2x6t1nlZVts3Wa" +
"Z0ditX93Wrr0vpUuqRIH1zVXPJxVbljmie5K3b1vr3ifPw125wPj65+9u/z8fnfn+4vh0jy9LPLzw/+UGb" +
"Vu8rVhyptb+wOv7iyytaH/FD+PZWVu6xo7u8e92x+3XOaSZVurtm1QydVXZ7W7XPPcIoGWpIVG/etOWbNR" +
"Ru3zqp28r+B5bVrH5a7bZ2s91m+aU5Cc6LMdvu/Z3gL55hndfILdnNOtGPuS1ftD901LDKs+wFYziy3j/d" +
"3FwjgKoJ0m3xJ81N7kvn3cix64aEH1gOfX8CXkVEtemFAahvz2IcgsBCkB0GhEMTKH1Ri3xn49yosYO0Bj" +
"hErDpGy3Y9JLbjSRvoQNAF+jIVvPPi2Bz67gK8iK1v0ptmsWoHoWXFDQG+x9/IeQ8Hbqm+swBGT15dr1wM" +
"CKDNA2yv0GKxE7b4+cwFBWDKQ+BlfDSgsat43tH94xD49diMtoeEVhgaN2mi6iwzMKqFjKUDPEBqCrmq6O" +
"HHd0PViMreajEEFJxlaccAi4B4CgdhzHBHdOcFqCSYTI14g2WS2z0007DfAe4Hy7DdkrI2I+9yGIhitJhh" +
"tTBjXYN+axcX1Ab7Oom2P+RgAtffDLj/A0a5vfkAbL/jWCwJHj9jT3afMzSQtQJYEhR6ibQ984+McsYQqg" +
"m4baTBKMB6LHhDo/Aj8BInDcI6q0ePG/rgMx+57hkXnU+AnVGBxCWH3zq3ijclwI/tW3lC2jSVsWM4oN1O" +
"SIc4XkjRGXjGEosylOUkUQ7AhhkBgSXYc1YvAksw4PG1kGWsAT5tOxbruOKbTnwIkSYxD1MbXsWAIUwMKz" +
"eGUeDUbRwI9Fkek5CiwqAM3Bz6NUgdUt+vBslhIo8UM6kDQac4kDiicpHfe+FwY2SQI5q3oadvnoQ3hMHE" +
"pCaHUgkqoVcRCG5aiKzCUCN03cUtJ4ikJxZTVlcWvDvarL626DiiVLH71pf0qG1y9H7mEPSQBNoTtQpFba" +
"NzfDFfXSNJqPFJBkFb/1iiNLxhSAW3u4Ns7qHHi+i1F9fmyj1vV0sDIZonP0wh+waxjLr1vOPcmxORe7n3" +
"pKOKIhVp9Rtb4+Owa3xCX/TpFPnrig6nKTNisNl8aNEKQRfQITh9kG/NhTzcvpwRZoARZvkh8S6h7Oz1zI" +
"atZeuYWk5nvC4TJ2aFFJXBCTkcO9UuQQ0qb3FXdx4xTPH6dBeApP0CQ43QejN8kd7l64jI1krMVgJfPEf7" +
"h3uq3o/K/ztZqP1QKFagz/G+t1XxwjeIFuqkRbXoTdlOTGnwCIoKZ6ku1AbrBoN6oCdX56w3UEOO0y2B9g" +
"aLbAYWcAdpeweKa2IfIT2jz5QzXxD6AoP+DrdXtxeluV7pdWrvkcKqPp7rjS19d+wp/fff/5Ez3FPjzFNy" +
"fdpTi9JB0sDp2JR7b309mn5HuPkEAAA==";
TrainedModelDefinition definition = InferenceToXContentCompressor.inflate(compressedDef,
parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
xContentRegistry());
Map<String, Object> fields = new HashMap<String, Object>(){{
put("sepal_length", 5.1);
put("sepal_width", 3.5);
put("petal_length", 1.4);
put("petal_width", 0.2);
}};
assertThat(
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-setosa"));
fields = new HashMap<String, Object>(){{
put("sepal_length", 7.0);
put("sepal_width", 3.2);
put("petal_length", 4.7);
put("petal_width", 1.4);
}};
assertThat(
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-versicolor"));
fields = new HashMap<String, Object>(){{
put("sepal_length", 6.5);
put("sepal_width", 3.0);
put("petal_length", 5.2);
put("petal_width", 2.0);
}};
assertThat(
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-virginica"));
}
}

View File

@ -65,22 +65,4 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
testProcess(encoding, fieldValues, matchers);
}
public void testProcessWithNestedField() {
String field = "categorical.child";
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = "encoded";
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap);
Map<String, Object> fieldValues = new HashMap<String, Object>() {{
put("categorical", new HashMap<String, Object>(){{
put("child", "farequote");
}});
}};
encoding.process(fieldValues);
assertThat(fieldValues.get("encoded"), equalTo(valueMap.get("farequote")));
}
}

View File

@ -67,19 +67,4 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
testProcess(encoding, fieldValues, matchers);
}
public void testProcessWithNestedField() {
String field = "categorical.child";
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString()));
OneHotEncoding encoding = new OneHotEncoding(field, valueMap);
Map<String, Object> fieldValues = new HashMap<String, Object>() {{
put("categorical", new HashMap<String, Object>(){{
put("child", "farequote");
}});
}};
encoding.process(fieldValues);
assertThat(fieldValues.get("Column_farequote"), equalTo(1));
}
}

View File

@ -68,24 +68,4 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
testProcess(encoding, fieldValues, matchers);
}
public void testProcessWithNestedField() {
String field = "categorical.child";
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = "encoded";
Double defaultvalue = randomDouble();
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue);
Map<String, Object> fieldValues = new HashMap<String, Object>() {{
put("categorical", new HashMap<String, Object>(){{
put("child", "farequote");
}});
}};
encoding.process(fieldValues);
assertThat(fieldValues.get("encoded"), equalTo(valueMap.get("farequote")));
}
}

View File

@ -8,9 +8,6 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.test.ESTestCase;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.CoreMatchers.equalTo;
@ -22,8 +19,7 @@ public class RawInferenceResultsTests extends ESTestCase {
for (int i = 0; i < n; i++) {
results[i] = randomDouble();
}
return new RawInferenceResults(results,
randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", new double[]{1.08}));
return new RawInferenceResults(results, randomBoolean() ? new double[0][] : new double[][]{{1.08}} );
}
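// Feature importance is presumably now carried as a double[][] indexed by (flattened) feature position
// rather than keyed by field name, in line with the refactored inference path.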
public void testEqualityAndHashcode() {
@ -32,11 +28,11 @@ public class RawInferenceResultsTests extends ESTestCase {
for (int i = 0; i < n; i++) {
results[i] = randomDouble();
}
Map<String, double[]> importance = randomBoolean() ?
Collections.emptyMap() :
Collections.singletonMap("foo", new double[]{1.08, 42.0});
RawInferenceResults lft = new RawInferenceResults(results, new HashMap<>(importance));
RawInferenceResults rgt = new RawInferenceResults(Arrays.copyOf(results, n), new HashMap<>(importance));
double[][] importance = randomBoolean() ?
new double[0][] :
new double[][]{{1.08, 42.0}};
RawInferenceResults lft = new RawInferenceResults(results, importance);
RawInferenceResults rgt = new RawInferenceResults(Arrays.copyOf(results, n), importance);
assertThat(lft, equalTo(rgt));
assertThat(lft.hashCode(), equalTo(rgt.hashCode()));
}

View File

@ -13,34 +13,24 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
private final double eps = 1.0E-8;
private boolean lenient;
@ -71,6 +61,10 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
public static Ensemble createRandom(TargetType targetType) {
int numberOfFeatures = randomIntBetween(1, 10);
List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList());
return createRandom(targetType, featureNames);
}
public static Ensemble createRandom(TargetType targetType, List<String> featureNames) {
int numberOfModels = randomIntBetween(1, 10);
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
.limit(numberOfModels)
@ -221,394 +215,6 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
assertThat(ex.getMessage(), equalTo("[trained_models] must not be empty"));
}
public void testClassificationProbability() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.build();
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 2))
.setClassificationWeights(Arrays.asList(0.7, 0.3))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
List<Double> expected = Arrays.asList(0.768524783, 0.231475216);
List<Double> scores = Arrays.asList(0.230557435, 0.162032651);
double eps = 0.000001;
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
expected = Arrays.asList(0.310025518, 0.6899744811);
scores = Arrays.asList(0.217017863, 0.2069923443);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
expected = Arrays.asList(0.768524783, 0.231475216);
scores = Arrays.asList(0.230557435, 0.162032651);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
expected = Arrays.asList(0.6899744811, 0.3100255188);
scores = Arrays.asList(0.482982136, 0.0930076556);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
}
public void testClassificationInference() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 2))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(0.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
}
public void testMultiClassClassificationInference() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(2.0))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(2.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(2.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 3))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(2.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.6);
put("bar", null);
}};
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
}
public void testRegressionInference() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.3))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.1))
.addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.5))
.addNode(TreeNode.builder(2).setLeafValue(0.9))
.build();
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2))
.setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5}))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
// Test with NO aggregator supplied, verifies default behavior of non-weighted sum
ensemble = Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2))
.build();
featureVector = Arrays.asList(0.4, 0.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
}
public void testInferNestedFields() {
List<String> featureNames = Arrays.asList("foo.baz", "bar.biz");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.3))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.1))
.addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.5))
.addNode(TreeNode.builder(2).setLeafValue(0.9))
.build();
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2))
.setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5}))
.build();
Map<String, Object> featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
put("baz", 0.4);
}});
put("bar", new HashMap<String, Object>(){{
put("biz", 0.0);
}});
}};
assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
put("baz", 2.0);
}});
put("bar", new HashMap<String, Object>(){{
put("biz", 0.7);
}});
}};
assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
}
public void testOperationsEstimations() {
Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2);
Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
@ -621,115 +227,4 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
assertThat(ensemble.estimatedNumOperations(), equalTo(9L));
}
public void testFeatureImportance() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.55)
.setNumberSamples(10L),
TreeNode.builder(1)
.setSplitFeature(0)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.41)
.setNumberSamples(6L),
TreeNode.builder(2)
.setSplitFeature(1)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.25)
.setNumberSamples(4L),
TreeNode.builder(3).setLeafValue(1.18230136).setNumberSamples(5L),
TreeNode.builder(4).setLeafValue(1.98006658).setNumberSamples(1L),
TreeNode.builder(5).setLeafValue(3.25350885).setNumberSamples(3L),
TreeNode.builder(6).setLeafValue(2.42384369).setNumberSamples(1L)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.45)
.setNumberSamples(10L),
TreeNode.builder(1)
.setSplitFeature(0)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.25)
.setNumberSamples(5L),
TreeNode.builder(2)
.setSplitFeature(0)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.59)
.setNumberSamples(5L),
TreeNode.builder(3).setLeafValue(1.04476388).setNumberSamples(3L),
TreeNode.builder(4).setLeafValue(1.52799228).setNumberSamples(2L),
TreeNode.builder(5).setLeafValue(1.98006658).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(2.950216).setNumberSamples(4L)).build();
Ensemble ensemble = Ensemble.builder().setOutputAggregator(new WeightedSum())
.setTrainedModels(Arrays.asList(tree1, tree2))
.setFeatureNames(featureNames)
.build();
Map<String, double[]> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.1, 0.8)));
assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.2, 0.7)));
assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.3, 0.6)));
assertThat(featureImportance.get("foo")[0], closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.4, 0.5)));
assertThat(featureImportance.get("foo")[0], closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.5, 0.4)));
assertThat(featureImportance.get("foo")[0], closeTo(0.0798679, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.6, 0.3)));
assertThat(featureImportance.get("foo")[0], closeTo(1.80491886, eps));
assertThat(featureImportance.get("bar")[0], closeTo(-0.4355742, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.7, 0.2)));
assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.8, 0.1)));
assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.9, 0.0)));
assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));
}
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}
}

View File

@ -0,0 +1,514 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
import static org.hamcrest.Matchers.closeTo;
public class EnsembleInferenceModelTests extends ESTestCase {
private final double eps = 1.0E-8;
public static EnsembleInferenceModel serializeFromTrainedModel(Ensemble ensemble) throws IOException {
NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return deserializeFromTrainedModel(ensemble,
registry,
EnsembleInferenceModel::fromXContent);
}
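// The tests below build the trained-model Ensemble configuration, round-trip it through XContent into the
// inference-only EnsembleInferenceModel, and assert the same expected values as the original Ensemble tests,
// so the optimized inference path is verified against the previous behaviour.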
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
public void testClassificationProbability() throws IOException {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.build();
Ensemble ensembleObject = Ensemble.builder()
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 2))
.setClassificationWeights(Arrays.asList(0.7, 0.3))
.build();
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
List<Double> expected = Arrays.asList(0.768524783, 0.231475216);
List<Double> scores = Arrays.asList(0.230557435, 0.162032651);
double eps = 0.000001;
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
expected = Arrays.asList(0.310025518, 0.6899744811);
scores = Arrays.asList(0.217017863, 0.2069923443);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
expected = Arrays.asList(0.768524783, 0.231475216);
scores = Arrays.asList(0.230557435, 0.162032651);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2, 1.0f) {{
put("foo", 0.3);
put("bar", null);
}};
expected = Arrays.asList(0.6899744811, 0.3100255188);
scores = Arrays.asList(0.482982136, 0.0930076556);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
}
public void testClassificationInference() throws IOException {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Ensemble ensembleObject = Ensemble.builder()
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 2))
.build();
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureMap = new HashMap<String, Object>(2, 1.0f) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(0.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
}
public void testMultiClassClassificationInference() throws IOException {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(2.0))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(2.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(2.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Ensemble ensembleObject = Ensemble.builder()
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 3))
.build();
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(2.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
featureMap = new HashMap<String, Object>(2, 1.0f) {{
put("foo", 0.6);
put("bar", null);
}};
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
0.00001));
}
public void testRegressionInference() throws IOException {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.3))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.1))
.addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.5))
.addNode(TreeNode.builder(2).setLeafValue(0.9))
.build();
Ensemble ensembleObject = Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2))
.setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5}))
.build();
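// WeightedSum blends the per-tree predictions: for (0.4, 0.0) the trees return 0.3 and 1.5, so 0.5 * 0.3 + 0.5 * 1.5 = 0.9.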
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
// Test with NO aggregator supplied; verifies that the default behavior is an unweighted sum of the tree outputs
ensembleObject = Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2))
.build();
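// With no aggregator the tree outputs are simply summed, e.g. 0.3 + 1.5 = 1.8 for the (0.4, 0.0) vector below.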
ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
featureVector = Arrays.asList(0.4, 0.0);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
featureMap = new HashMap<String, Object>(2, 1.0f) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
.value(),
0.00001));
}
public void testFeatureImportance() throws IOException {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.55)
.setNumberSamples(10L),
TreeNode.builder(1)
.setSplitFeature(0)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.41)
.setNumberSamples(6L),
TreeNode.builder(2)
.setSplitFeature(1)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.25)
.setNumberSamples(4L),
TreeNode.builder(3).setLeafValue(1.18230136).setNumberSamples(5L),
TreeNode.builder(4).setLeafValue(1.98006658).setNumberSamples(1L),
TreeNode.builder(5).setLeafValue(3.25350885).setNumberSamples(3L),
TreeNode.builder(6).setLeafValue(2.42384369).setNumberSamples(1L)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.45)
.setNumberSamples(10L),
TreeNode.builder(1)
.setSplitFeature(0)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.25)
.setNumberSamples(5L),
TreeNode.builder(2)
.setSplitFeature(0)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.59)
.setNumberSamples(5L),
TreeNode.builder(3).setLeafValue(1.04476388).setNumberSamples(3L),
TreeNode.builder(4).setLeafValue(1.52799228).setNumberSamples(2L),
TreeNode.builder(5).setLeafValue(1.98006658).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(2.950216).setNumberSamples(4L)).build();
Ensemble ensembleObject = Ensemble.builder().setOutputAggregator(new WeightedSum((double[])null))
.setTrainedModels(Arrays.asList(tree1, tree2))
.setFeatureNames(featureNames)
.build();
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
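// WeightedSum built with null weights falls back to an unweighted sum. featureImportance returns one row per
// feature (in feature_names order), with column 0 holding the importance for the single regression output.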
double[][] featureImportance = ensemble.featureImportance(new double[]{0.0, 0.9});
assertThat(featureImportance[0][0], closeTo(-1.653200025, eps));
assertThat(featureImportance[1][0], closeTo(-0.12444978, eps));
featureImportance = ensemble.featureImportance(new double[]{0.1, 0.8});
assertThat(featureImportance[0][0], closeTo(-1.653200025, eps));
assertThat(featureImportance[1][0], closeTo(-0.12444978, eps));
featureImportance = ensemble.featureImportance(new double[]{0.2, 0.7});
assertThat(featureImportance[0][0], closeTo(-1.653200025, eps));
assertThat(featureImportance[1][0], closeTo(-0.12444978, eps));
featureImportance = ensemble.featureImportance(new double[]{0.3, 0.6});
assertThat(featureImportance[0][0], closeTo(-1.16997162, eps));
assertThat(featureImportance[1][0], closeTo(-0.12444978, eps));
featureImportance = ensemble.featureImportance(new double[]{0.4, 0.5});
assertThat(featureImportance[0][0], closeTo(-1.16997162, eps));
assertThat(featureImportance[1][0], closeTo(-0.12444978, eps));
featureImportance = ensemble.featureImportance(new double[]{0.5, 0.4});
assertThat(featureImportance[0][0], closeTo(0.0798679, eps));
assertThat(featureImportance[1][0], closeTo(-0.12444978, eps));
featureImportance = ensemble.featureImportance(new double[]{0.6, 0.3});
assertThat(featureImportance[0][0], closeTo(1.80491886, eps));
assertThat(featureImportance[1][0], closeTo(-0.4355742, eps));
featureImportance = ensemble.featureImportance(new double[]{0.7, 0.2});
assertThat(featureImportance[0][0], closeTo(2.0538184, eps));
assertThat(featureImportance[1][0], closeTo(0.1451914, eps));
featureImportance = ensemble.featureImportance(new double[]{0.8, 0.1});
assertThat(featureImportance[0][0], closeTo(2.0538184, eps));
assertThat(featureImportance[1][0], closeTo(0.1451914, eps));
featureImportance = ensemble.featureImportance(new double[]{0.9, 0.0});
assertThat(featureImportance[0][0], closeTo(2.0538184, eps));
assertThat(featureImportance[1][0], closeTo(0.1451914, eps));
}
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}
}

View File

@ -0,0 +1,267 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests.ENSEMBLE_MODEL;
import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests.TREE_MODEL;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
public class InferenceDefinitionTests extends ESTestCase {
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
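// The schema tests check that the same ensemble / tree JSON used for the trained model classes parses into the
// inference-optimized counterparts (EnsembleInferenceModel, TreeInferenceModel).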
public void testEnsembleSchemaDeserialization() throws IOException {
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, ENSEMBLE_MODEL);
InferenceDefinition definition = InferenceDefinition.fromXContent(parser);
assertThat(definition.getTrainedModel().getClass(), equalTo(EnsembleInferenceModel.class));
}
public void testTreeSchemaDeserialization() throws IOException {
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, TREE_MODEL);
InferenceDefinition definition = InferenceDefinition.fromXContent(parser);
assertThat(definition.getTrainedModel().getClass(), equalTo(TreeInferenceModel.class));
}
public void testMultiClassIrisInference() throws IOException {
// A fairly simple random forest classification model built to fit our format
// Trained on the well known Iris dataset
String compressedDef = "H4sIAPbiMl4C/+1b246bMBD9lVWet8jjG3b/oN9QVYgmToLEkghIL6r23wukl90" +
"YxRMGlt2WPKwEC/gYe2bOnBl+rOoyzQq3SR4OG5ev3t/9WLmicg+fc9cd1Gm5c3VSfz+2x6t1nlZVts3Wa" +
"Z0ditX93Wrr0vpUuqRIH1zVXPJxVbljmie5K3b1vr3ifPw125wPj65+9u/z8fnfn+4vh0jy9LPLzw/+UGb" +
"Vu8rVhyptb+wOv7iyytaH/FD+PZWVu6xo7u8e92x+3XOaSZVurtm1QydVXZ7W7XPPcIoGWpIVG/etOWbNR" +
"Ru3zqp28r+B5bVrH5a7bZ2s91m+aU5Cc6LMdvu/Z3gL55hndfILdnNOtGPuS1ftD901LDKs+wFYziy3j/d" +
"3FwjgKoJ0m3xJ81N7kvn3cix64aEH1gOfX8CXkVEtemFAahvz2IcgsBCkB0GhEMTKH1Ri3xn49yosYO0Bj" +
"hErDpGy3Y9JLbjSRvoQNAF+jIVvPPi2Bz67gK8iK1v0ptmsWoHoWXFDQG+x9/IeQ8Hbqm+swBGT15dr1wM" +
"CKDNA2yv0GKxE7b4+cwFBWDKQ+BlfDSgsat43tH94xD49diMtoeEVhgaN2mi6iwzMKqFjKUDPEBqCrmq6O" +
"HHd0PViMreajEEFJxlaccAi4B4CgdhzHBHdOcFqCSYTI14g2WS2z0007DfAe4Hy7DdkrI2I+9yGIhitJhh" +
"tTBjXYN+axcX1Ab7Oom2P+RgAtffDLj/A0a5vfkAbL/jWCwJHj9jT3afMzSQtQJYEhR6ibQ984+McsYQqg" +
"m4baTBKMB6LHhDo/Aj8BInDcI6q0ePG/rgMx+57hkXnU+AnVGBxCWH3zq3ijclwI/tW3lC2jSVsWM4oN1O" +
"SIc4XkjRGXjGEosylOUkUQ7AhhkBgSXYc1YvAksw4PG1kGWsAT5tOxbruOKbTnwIkSYxD1MbXsWAIUwMKz" +
"eGUeDUbRwI9Fkek5CiwqAM3Bz6NUgdUt+vBslhIo8UM6kDQac4kDiicpHfe+FwY2SQI5q3oadvnoQ3hMHE" +
"pCaHUgkqoVcRCG5aiKzCUCN03cUtJ4ikJxZTVlcWvDvarL626DiiVLH71pf0qG1y9H7mEPSQBNoTtQpFba" +
"NzfDFfXSNJqPFJBkFb/1iiNLxhSAW3u4Ns7qHHi+i1F9fmyj1vV0sDIZonP0wh+waxjLr1vOPcmxORe7n3" +
"pKOKIhVp9Rtb4+Owa3xCX/TpFPnrig6nKTNisNl8aNEKQRfQITh9kG/NhTzcvpwRZoARZvkh8S6h7Oz1zI" +
"atZeuYWk5nvC4TJ2aFFJXBCTkcO9UuQQ0qb3FXdx4xTPH6dBeApP0CQ43QejN8kd7l64jI1krMVgJfPEf7" +
"h3uq3o/K/ztZqP1QKFagz/G+t1XxwjeIFuqkRbXoTdlOTGnwCIoKZ6ku1AbrBoN6oCdX56w3UEOO0y2B9g" +
"aLbAYWcAdpeweKa2IfIT2jz5QzXxD6AoP+DrdXtxeluV7pdWrvkcKqPp7rjS19d+wp/fff/5Ez3FPjzFNy" +
"fdpTi9JB0sDp2JR7b309mn5HuPkEAAA==";
InferenceDefinition definition = InferenceToXContentCompressor.inflate(compressedDef,
InferenceDefinition::fromXContent,
xContentRegistry());
Map<String, Object> fields = new HashMap<String, Object>(){{
put("sepal_length", 5.1);
put("sepal_width", 3.5);
put("petal_length", 1.4);
put("petal_width", 0.2);
}};
assertThat(
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-setosa"));
fields = new HashMap<String, Object>(){{
put("sepal_length", 7.0);
put("sepal_width", 3.2);
put("petal_length", 4.7);
put("petal_width", 1.4);
}};
assertThat(
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-versicolor"));
fields = new HashMap<String, Object>(){{
put("sepal_length", 6.5);
put("sepal_width", 3.0);
put("petal_length", 5.2);
put("petal_width", 2.0);
}};
assertThat(
((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
.getClassificationLabel(),
equalTo("Iris-virginica"));
}
public void testComplexInferenceDefinitionInfer() throws IOException {
XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(CLASSIFICATION_DEFINITION),
XContentType.JSON);
InferenceDefinition inferenceDefinition = InferenceDefinition.fromXContent(parser);
ClassificationConfig config = new ClassificationConfig(2, null, null, 2, null);
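// Requests the top 2 classes; the second integer argument presumably caps the returned feature importance
// values at 2, which the assertions below depend on.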
Map<String, Object> featureMap = new HashMap<>();
featureMap.put("col1", "female");
featureMap.put("col2", "M");
featureMap.put("col3", "none");
featureMap.put("col4", 10);
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
assertThat(results.valueAsString(), equalTo("second"));
assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1"));
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
}
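// A two-tree classification ensemble over one-hot, target-mean and frequency encoded inputs. The same constant is
// reused by InferenceIngestIT via a static import.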
public static final String CLASSIFICATION_DEFINITION = "{" +
" \"preprocessors\": [\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"col1\",\n" +
" \"hot_map\": {\n" +
" \"male\": \"col1_male\",\n" +
" \"female\": \"col1_female\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"target_mean_encoding\": {\n" +
" \"field\": \"col2\",\n" +
" \"feature_name\": \"col2_encoded\",\n" +
" \"target_map\": {\n" +
" \"S\": 5.0,\n" +
" \"M\": 10.0,\n" +
" \"L\": 20\n" +
" },\n" +
" \"default_value\": 5.0\n" +
" }\n" +
" },\n" +
" {\n" +
" \"frequency_encoding\": {\n" +
" \"field\": \"col3\",\n" +
" \"feature_name\": \"col3_encoded\",\n" +
" \"frequency_map\": {\n" +
" \"none\": 0.75,\n" +
" \"true\": 0.10,\n" +
" \"false\": 0.15\n" +
" }\n" +
" }\n" +
" }\n" +
" ],\n" +
" \"trained_model\": {\n" +
" \"ensemble\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"aggregate_output\": {\n" +
" \"weighted_mode\": {\n" +
" \"num_classes\": \"2\",\n" +
" \"weights\": [\n" +
" 0.5,\n" +
" 0.5\n" +
" ]\n" +
" }\n" +
" },\n" +
" \"target_type\": \"classification\",\n" +
" \"classification_labels\": [\"first\", \"second\"],\n" +
" \"trained_models\": [\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"number_samples\": 100,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 80,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 20,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" },\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"number_samples\": 180,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 10,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 170,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }\n" +
" ]\n" +
" }\n" +
" }\n" +
"}";
}

View File

@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import java.io.IOException;
import static org.elasticsearch.common.xcontent.ToXContent.EMPTY_PARAMS;
final class InferenceModelTestUtils {
static <T extends TrainedModel, U extends InferenceModel> U deserializeFromTrainedModel(
T trainedModel,
NamedXContentRegistry registry,
CheckedFunction<XContentParser, U, IOException> parser) throws IOException {
try(XContentBuilder builder = trainedModel.toXContent(XContentFactory.jsonBuilder(), EMPTY_PARAMS);
XContentParser xContentParser = XContentType.JSON
.xContent()
.createParser(registry,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
BytesReference.bytes(builder).streamInput())) {
return parser.apply(xContentParser);
}
}
}

View File

@ -0,0 +1,228 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
public class TreeInferenceModelTests extends ESTestCase {
private final double eps = 1.0E-8;
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
public void testInferWithStump() throws IOException {
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
builder.setRoot(TreeNode.builder(0).setLeafValue(Collections.singletonList(42.0)));
builder.setFeatureNames(Collections.emptyList());
Tree treeObject = builder.build();
TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject,
xContentRegistry(),
TreeInferenceModel::fromXContent);
List<String> featureNames = Arrays.asList("foo", "bar");
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump
assertThat(42.0,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}
public void testInfer() throws IOException {
// Build a tree with 2 nodes and 3 leaves using 2 features
// The leaves have unique values 0.1, 0.2, 0.3
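// Layout: the root splits on foo at 0.5 (right -> leaf 0.3); its left child splits on bar at 0.8
// (left -> leaf 0.1, right -> leaf 0.2)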
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), 0.3);
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
builder.addLeaf(leftChildNode.getRightChild(), 0.2);
List<String> featureNames = Arrays.asList("foo", "bar");
Tree treeObject = builder.setFeatureNames(featureNames).build();
TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject,
xContentRegistry(),
TreeInferenceModel::fromXContent);
// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
featureVector = Arrays.asList(0.3, 0.9);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should still work if the feature values are provided as strings
List<String> featureVectorStrings = Arrays.asList("0.3", "0.9");
featureMap = zipObjMap(featureNames, featureVectorStrings);
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should handle missing values and take the default_left path
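// bar is null, and the bar junction was created with default_left = true (the boolean passed to addJunction),
// so the sample falls through to the 0.1 leaf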
featureMap = new HashMap<String, Object>(2, 1.0f) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}
public void testTreeClassificationProbability() throws IOException {
// Build a tree with 2 nodes and 3 leaves using 2 features
// The leaves carry class values 1.0, 1.0 and 0.0
Tree.Builder builder = Tree.builder().setTargetType(TargetType.CLASSIFICATION);
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), 1.0);
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
builder.addLeaf(leftChildNode.getLeftChild(), 1.0);
builder.addLeaf(leftChildNode.getRightChild(), 0.0);
List<String> featureNames = Arrays.asList("foo", "bar");
Tree treeObject = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build();
TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject,
xContentRegistry(),
TreeInferenceModel::fromXContent);
double eps = 0.000001;
// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
List<Double> expectedProbs = Arrays.asList(1.0, 0.0);
List<String> expectedFields = Arrays.asList("dog", "cat");
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
}
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
}
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2, 1.0f) {{
put("foo", 0.3);
put("bar", null);
}};
probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
}
}
public void testFeatureImportance() throws IOException {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree treeObject = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.5)
.setNumberSamples(4L),
TreeNode.builder(1)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.5)
.setNumberSamples(2L),
TreeNode.builder(2)
.setSplitFeature(1)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.5)
.setNumberSamples(2L),
TreeNode.builder(3).setLeafValue(3.0).setNumberSamples(1L),
TreeNode.builder(4).setLeafValue(8.0).setNumberSamples(1L),
TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();
TreeInferenceModel tree = InferenceModelTestUtils.deserializeFromTrainedModel(treeObject,
xContentRegistry(),
TreeInferenceModel::fromXContent);
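// Every leaf holds one sample and the leaf values are 3, 8, 13 and 18, so the baseline expectation is 10.5; the
// deviation of each prediction from 10.5 is attributed to foo and bar, e.g. for (0.25, 0.25) the prediction 3.0
// deviates by -7.5 = -5.0 (foo) + -2.5 (bar).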
double[][] featureImportance = tree.featureImportance(new double[]{0.25, 0.25});
assertThat(featureImportance[0][0], closeTo(-5.0, eps));
assertThat(featureImportance[1][0], closeTo(-2.5, eps));
featureImportance = tree.featureImportance(new double[]{0.25, 0.75});
assertThat(featureImportance[0][0], closeTo(-5.0, eps));
assertThat(featureImportance[1][0], closeTo(2.5, eps));
featureImportance = tree.featureImportance(new double[]{0.75, 0.25});
assertThat(featureImportance[0][0], closeTo(5.0, eps));
assertThat(featureImportance[1][0], closeTo(-2.5, eps));
featureImportance = tree.featureImportance(new double[]{0.75, 0.75});
assertThat(featureImportance[0][0], closeTo(5.0, eps));
assertThat(featureImportance[1][0], closeTo(2.5, eps));
}
private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}
}

View File

@ -12,11 +12,7 @@ import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
@ -82,21 +78,4 @@ public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
protected Writeable.Reader<TreeNode> instanceReader() {
return TreeNode::new;
}
public void testCompare() {
expectThrows(IllegalArgumentException.class,
() -> createRandomLeafNode(randomDouble()).compare(Collections.singletonList(randomDouble())));
List<Double> featureValues = Arrays.asList(0.1, null);
assertThat(createRandom(0, 2, 3, 0.0, 0, null).build().compare(featureValues),
equalTo(3));
assertThat(createRandom(0, 2, 3, 0.0, 0, Operator.GT).build().compare(featureValues),
equalTo(2));
assertThat(createRandom(0, 2, 3, 0.2, 0, null).build().compare(featureValues),
equalTo(2));
assertThat(createRandom(0, 2, 3, 0.0, 1, null).setDefaultLeft(true).build().compare(featureValues),
equalTo(2));
assertThat(createRandom(0, 2, 3, 0.0, 1, null).setDefaultLeft(false).build().compare(featureValues),
equalTo(3));
}
}

View File

@ -10,33 +10,22 @@ import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
public class TreeTests extends AbstractSerializingTestCase<Tree> {
private final double eps = 1.0E-8;
private boolean lenient;
@Before
@ -117,181 +106,6 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
return Tree::new;
}
public void testInferWithStump() {
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
builder.setRoot(TreeNode.builder(0).setLeafValue(Collections.singletonList(42.0)));
builder.setFeatureNames(Collections.emptyList());
Tree tree = builder.build();
List<String> featureNames = Arrays.asList("foo", "bar");
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump
assertThat(42.0,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}
public void testInfer() {
// Build a tree with 2 nodes and 3 leaves using 2 features
// The leaves have unique values 0.1, 0.2, 0.3
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), 0.3);
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
builder.addLeaf(leftChildNode.getRightChild(), 0.2);
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree = builder.setFeatureNames(featureNames).build();
// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
featureVector = Arrays.asList(0.3, 0.9);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should still work if the internal values are strings
List<String> featureVectorStrings = Arrays.asList("0.3", "0.9");
featureMap = zipObjMap(featureNames, featureVectorStrings);
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}
public void testInferNestedFields() {
// Build a tree with 2 nodes and 3 leaves using 2 features
// The leaves have unique values 0.1, 0.2, 0.3
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), 0.3);
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
builder.addLeaf(leftChildNode.getRightChild(), 0.2);
List<String> featureNames = Arrays.asList("foo.baz", "bar.biz");
Tree tree = builder.setFeatureNames(featureNames).build();
// This feature vector should hit the right child of the root node
Map<String, Object> featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
put("baz", 0.6);
}});
put("bar", new HashMap<String, Object>(){{
put("biz", 0.0);
}});
}};
assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
put("baz", 0.3);
}});
put("bar", new HashMap<String, Object>(){{
put("biz", 0.7);
}});
}};
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
put("baz", 0.3);
}});
put("bar", new HashMap<String, Object>(){{
put("biz", 0.9);
}});
}};
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
0.00001));
}
public void testTreeClassificationProbability() {
// Build a tree with 2 nodes and 3 leaves using 2 features
// The leaves have unique values 0.1, 0.2, 0.3
Tree.Builder builder = Tree.builder().setTargetType(TargetType.CLASSIFICATION);
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), 1.0);
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
builder.addLeaf(leftChildNode.getLeftChild(), 1.0);
builder.addLeaf(leftChildNode.getRightChild(), 0.0);
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build();
double eps = 0.000001;
// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
List<Double> expectedProbs = Arrays.asList(1.0, 0.0);
List<String> expectedFields = Arrays.asList("dog", "cat");
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
}
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
}
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
.getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
}
}
public void testTreeWithNullRoot() {
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
() -> Tree.builder()
@ -366,55 +180,6 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
assertThat(tree.estimatedNumOperations(), equalTo(7L));
}
public void testFeatureImportance() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree = Tree.builder()
.setFeatureNames(featureNames)
.setNodes(
TreeNode.builder(0)
.setSplitFeature(0)
.setOperator(Operator.LT)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(0.5)
.setNumberSamples(4L),
TreeNode.builder(1)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4)
.setOperator(Operator.LT)
.setThreshold(0.5)
.setNumberSamples(2L),
TreeNode.builder(2)
.setSplitFeature(1)
.setLeftChild(5)
.setRightChild(6)
.setOperator(Operator.LT)
.setThreshold(0.5)
.setNumberSamples(2L),
TreeNode.builder(3).setLeafValue(3.0).setNumberSamples(1L),
TreeNode.builder(4).setLeafValue(8.0).setNumberSamples(1L),
TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();
Map<String, double[]> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
Collections.emptyMap());
assertThat(featureImportance.get("foo")[0], closeTo(-5.0, eps));
assertThat(featureImportance.get("bar")[0], closeTo(-2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.75)), Collections.emptyMap());
assertThat(featureImportance.get("foo")[0], closeTo(-5.0, eps));
assertThat(featureImportance.get("bar")[0], closeTo(2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.25)), Collections.emptyMap());
assertThat(featureImportance.get("foo")[0], closeTo(5.0, eps));
assertThat(featureImportance.get("bar")[0], closeTo(-2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.75)), Collections.emptyMap());
assertThat(featureImportance.get("foo")[0], closeTo(5.0, eps));
assertThat(featureImportance.get("bar")[0], closeTo(2.5, eps));
}
public void testMaxFeatureIndex() {
int numFeatures = randomIntBetween(1, 15);
@ -487,7 +252,5 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
.validate();
}
private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}
}

View File

@ -7,17 +7,39 @@ package org.elasticsearch.xpack.core.ml.utils;
import org.elasticsearch.test.ESTestCase;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
public class MapHelperTests extends ESTestCase {
public void testCollapseFields() {
Map<String, Object> map = new HashMap<>();
map.put("a", Collections.singletonMap("b.c", 1));
map.put("d", Collections.singletonMap("e", Collections.singletonMap("f", 2)));
map.put("g.h.i", 3);
{
assertThat(MapHelper.dotCollapse(map, Collections.emptyList()), is(anEmptyMap()));
}
{
Map<String, Object> collapsed = MapHelper.dotCollapse(map, Arrays.asList("a.b.c", "d.e.f", "g.h.i", "m.i.s.s.i.n.g"));
assertThat(collapsed, hasEntry("a.b.c", 1));
assertThat(collapsed, hasEntry("d.e.f", 2));
assertThat(collapsed, hasEntry("g.h.i", 3));
assertThat(collapsed, not(hasKey("m.i.s.s.i.n.g")));
}
}
public void testAbsolutePathStringAsKey() {
String path = "a.b.c.d";
Map<String, Object> map = Collections.singletonMap(path, 2);

View File

@ -34,6 +34,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinitionTests.CLASSIFICATION_DEFINITION;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.CoreMatchers.containsString;
@ -512,133 +513,6 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"definition\": " + REGRESSION_DEFINITION +
"}";
private static final String CLASSIFICATION_DEFINITION = "{" +
" \"preprocessors\": [\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"col1\",\n" +
" \"hot_map\": {\n" +
" \"male\": \"col1_male\",\n" +
" \"female\": \"col1_female\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"target_mean_encoding\": {\n" +
" \"field\": \"col2\",\n" +
" \"feature_name\": \"col2_encoded\",\n" +
" \"target_map\": {\n" +
" \"S\": 5.0,\n" +
" \"M\": 10.0,\n" +
" \"L\": 20\n" +
" },\n" +
" \"default_value\": 5.0\n" +
" }\n" +
" },\n" +
" {\n" +
" \"frequency_encoding\": {\n" +
" \"field\": \"col3\",\n" +
" \"feature_name\": \"col3_encoded\",\n" +
" \"frequency_map\": {\n" +
" \"none\": 0.75,\n" +
" \"true\": 0.10,\n" +
" \"false\": 0.15\n" +
" }\n" +
" }\n" +
" }\n" +
" ],\n" +
" \"trained_model\": {\n" +
" \"ensemble\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"aggregate_output\": {\n" +
" \"weighted_mode\": {\n" +
" \"num_classes\": \"2\",\n" +
" \"weights\": [\n" +
" 0.5,\n" +
" 0.5\n" +
" ]\n" +
" }\n" +
" },\n" +
" \"target_type\": \"classification\",\n" +
" \"classification_labels\": [\"first\", \"second\"],\n" +
" \"trained_models\": [\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"number_samples\": 100,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 80,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 20,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" },\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"number_samples\": 180,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 10,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 170,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }\n" +
" ]\n" +
" }\n" +
" }\n" +
"}";
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());

View File

@ -6,12 +6,12 @@
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
@ -30,7 +30,7 @@ import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WA
public class LocalModel implements Model {
private final TrainedModelDefinition trainedModelDefinition;
private final InferenceDefinition trainedModelDefinition;
private final String modelId;
private final String nodeId;
private final Set<String> fieldNames;
@ -43,7 +43,7 @@ public class LocalModel implements Model {
public LocalModel(String modelId,
String nodeId,
TrainedModelDefinition trainedModelDefinition,
InferenceDefinition trainedModelDefinition,
TrainedModelInput input,
Map<String, String> defaultFieldMap,
InferenceConfig modelInferenceConfig,
@ -75,7 +75,7 @@ public class LocalModel implements Model {
@Override
public String getResultsType() {
switch (trainedModelDefinition.getTrainedModel().targetType()) {
switch (trainedModelDefinition.getTargetType()) {
case CLASSIFICATION:
return ClassificationInferenceResults.NAME;
case REGRESSION:
@ -83,7 +83,7 @@ public class LocalModel implements Model {
default:
throw ExceptionsHelper.badRequestException("Model [{}] has unsupported target type [{}]",
modelId,
trainedModelDefinition.getTrainedModel().targetType());
trainedModelDefinition.getTargetType());
}
}
@ -111,10 +111,12 @@ public class LocalModel implements Model {
statsAccumulator.incInference();
currentInferenceCount.increment();
// Needs to happen before collapse as defaultFieldMap might resolve fields to their appropriate name
Model.mapFieldsIfNecessary(fields, defaultFieldMap);
Map<String, Object> flattenedFields = MapHelper.dotCollapse(fields, fieldNames);
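// The fields are mapped and flattened once here, so the missing-field check and the inference call below
// both operate on the same flat view.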
boolean shouldPersistStats = ((currentInferenceCount.sum() + 1) % persistenceQuotient == 0);
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
if (flattenedFields.isEmpty()) {
statsAccumulator.incMissingFields();
if (shouldPersistStats) {
persistStats(false);
@ -122,7 +124,7 @@ public class LocalModel implements Model {
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
return;
}
InferenceResults inferenceResults = trainedModelDefinition.infer(fields, update.apply(inferenceConfig));
InferenceResults inferenceResults = trainedModelDefinition.infer(flattenedFields, update.apply(inferenceConfig));
if (shouldPersistStats) {
persistStats(false);
}

View File

@ -32,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
@ -39,7 +40,6 @@ import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
@ -94,7 +94,6 @@ public class ModelLoadingService implements ClusterStateListener {
private final ThreadPool threadPool;
private final InferenceAuditor auditor;
private final ByteSizeValue maxCacheSize;
private final NamedXContentRegistry namedXContentRegistry;
private final String localNode;
public ModelLoadingService(TrainedModelProvider trainedModelProvider,
@ -111,7 +110,6 @@ public class ModelLoadingService implements ClusterStateListener {
this.auditor = auditor;
this.modelStatsService = modelStatsService;
this.shouldNotAudit = new HashSet<>();
this.namedXContentRegistry = namedXContentRegistry;
this.localModelCache = CacheBuilder.<String, LocalModel>builder()
.setMaximumWeight(this.maxCacheSize.getBytes())
.weigher((id, localModel) -> localModel.ramBytesUsed())
@ -151,16 +149,17 @@ public class ModelLoadingService implements ClusterStateListener {
// If the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
provider.getTrainedModel(modelId, true, ActionListener.wrap(
trainedModelConfig -> {
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry);
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
configAndInferenceDef -> {
TrainedModelConfig trainedModelConfig = configAndInferenceDef.v1();
InferenceDefinition inferenceDefinition = configAndInferenceDef.v2();
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) :
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
trainedModelConfig.getInferenceConfig();
modelActionListener.onResponse(new LocalModel(
trainedModelConfig.getModelId(),
localNode,
trainedModelConfig.getModelDefinition(),
inferenceDefinition,
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,
@ -206,10 +205,10 @@ public class ModelLoadingService implements ClusterStateListener {
}
private void loadModel(String modelId) {
provider.getTrainedModel(modelId, true, ActionListener.wrap(
trainedModelConfig -> {
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
configAndInferenceDef -> {
logger.debug(() -> new ParameterizedMessage("[{}] successfully loaded model", modelId));
handleLoadSuccess(modelId, trainedModelConfig);
handleLoadSuccess(modelId, configAndInferenceDef);
},
failure -> {
logger.warn(new ParameterizedMessage("[{}] failed to load model", modelId), failure);
@ -218,16 +217,17 @@ public class ModelLoadingService implements ClusterStateListener {
));
}
private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) throws IOException {
private void handleLoadSuccess(String modelId,
Tuple<TrainedModelConfig, InferenceDefinition> configAndInferenceDef) {
Queue<ActionListener<Model>> listeners;
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry);
TrainedModelConfig trainedModelConfig = configAndInferenceDef.v1();
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) :
trainedModelConfig.getInferenceConfig();
LocalModel loadedModel = new LocalModel(
trainedModelConfig.getModelId(),
localNode,
trainedModelConfig.getModelDefinition(),
configAndInferenceDef.v2(),
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,

View File

@ -34,7 +34,6 @@ import org.elasticsearch.common.Numbers;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
@ -43,9 +42,9 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.mapper.MapperService;
@ -65,10 +64,13 @@ import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
@ -230,6 +232,69 @@ public class TrainedModelProvider {
executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener);
}
public void getTrainedModelForInference(final String modelId,
final ActionListener<Tuple<TrainedModelConfig, InferenceDefinition>> listener) {
// TODO Change this when we get more than just langIdent stored
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
try {
TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry);
assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
listener.onResponse(Tuple.tuple(
config,
InferenceDefinition.builder()
.setPreProcessors(config.getModelDefinition().getPreProcessors())
.setTrainedModel((LangIdentNeuralNetwork)config.getModelDefinition().getTrainedModel())
.build()));
return;
} catch (ElasticsearchException|IOException ex) {
listener.onFailure(ex);
return;
}
}
getTrainedModel(modelId, false, ActionListener.wrap(
config -> {
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
.boolQuery()
.filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(),
TrainedModelDefinitionDoc.NAME))))
.setSize(MAX_NUM_DEFINITION_DOCS)
// First find the latest index
.addSort("_index", SortOrder.DESC)
// Then, sort by doc_num
.addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName())
.order(SortOrder.ASC)
.unmappedType("long"))
.request();
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
searchResponse -> {
List<TrainedModelDefinitionDoc> docs = handleHits(searchResponse.getHits().getHits(),
modelId,
this::parseModelDefinitionDocLenientlyFromSource);
String compressedString = docs.stream()
.map(TrainedModelDefinitionDoc::getCompressedString)
.collect(Collectors.joining());
if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
listener.onFailure(ExceptionsHelper.serverError(
Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
return;
}
InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate(
compressedString,
InferenceDefinition::fromXContent,
xContentRegistry);
listener.onResponse(Tuple.tuple(config, inferenceDefinition));
},
listener::onFailure
));
},
listener::onFailure
));
}
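The reassembly step above joins the stored definition chunks and checks the recorded total length before handing the compressed string to InferenceToXContentCompressor.inflate. A minimal, self-contained sketch of that stitching logic follows; the Chunk class and the length check are simplified stand-ins for TrainedModelDefinitionDoc, not the production types.
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
final class DefinitionChunkAssembler {
    // Simplified stand-in for TrainedModelDefinitionDoc: only the fields used here.
    static final class Chunk {
        final int docNum;
        final String compressedString;
        final int totalDefinitionLength;
        Chunk(int docNum, String compressedString, int totalDefinitionLength) {
            this.docNum = docNum;
            this.compressedString = compressedString;
            this.totalDefinitionLength = totalDefinitionLength;
        }
    }
    // Concatenates the chunks in doc_num order and verifies the recorded total length,
    // mirroring the truncation check in getTrainedModelForInference above.
    static String assemble(List<Chunk> docs) {
        String compressed = docs.stream()
            .sorted(Comparator.comparingInt(c -> c.docNum))
            .map(c -> c.compressedString)
            .collect(Collectors.joining());
        if (compressed.length() != docs.get(0).totalDefinitionLength) {
            throw new IllegalStateException("model definition appears truncated");
        }
        // The real code would now pass this string to InferenceToXContentCompressor.inflate(...)
        return compressed;
    }
    public static void main(String[] args) {
        String whole = assemble(List.of(
            new Chunk(0, "H4sIAAAA", 12),
            new Chunk(1, "AAAA", 12)));
        System.out.println(whole.length()); // 12
    }
}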
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
@ -633,27 +698,18 @@ public class TrainedModelProvider {
throw new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId));
}
try {
BytesReference bytes = Streams.readFully(getClass()
.getResourceAsStream(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT));
try (XContentParser parser =
XContentHelper.createParser(xContentRegistry,
LoggingDeprecationHandler.INSTANCE,
bytes,
XContentType.JSON)) {
TrainedModelConfig.Builder builder = TrainedModelConfig.fromXContent(parser, true);
if (nullOutDefinition) {
builder.clearDefinition();
}
return builder.build();
} catch (IOException ioEx) {
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
try (XContentParser parser = JsonXContent.jsonXContent.createParser(
xContentRegistry,
LoggingDeprecationHandler.INSTANCE,
getClass().getResourceAsStream(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT))) {
TrainedModelConfig.Builder builder = TrainedModelConfig.fromXContent(parser, true);
if (nullOutDefinition) {
builder.clearDefinition();
}
} catch (IOException ex) {
String msg = new ParameterizedMessage("[{}] failed to read model as resource", modelId).getFormattedMessage();
logger.error(msg, ex);
throw ExceptionsHelper.serverError(msg, ex);
return builder.build();
} catch (IOException ioEx) {
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
}
}
@ -724,9 +780,16 @@ public class TrainedModelProvider {
if (item.getResponse().getHits().getHits().length == 0) {
throw new ResourceNotFoundException(resourceId);
}
List<T> results = new ArrayList<>(item.getResponse().getHits().getHits().length);
String initialIndex = item.getResponse().getHits().getHits()[0].getIndex();
for (SearchHit hit : item.getResponse().getHits().getHits()) {
return handleHits(item.getResponse().getHits().getHits(), resourceId, parseLeniently);
}
private static <T> List<T> handleHits(SearchHit[] hits,
String resourceId,
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
List<T> results = new ArrayList<>(hits.length);
String initialIndex = hits[0].getIndex();
for (SearchHit hit : hits) {
// We don't want to spread across multiple backing indices
if (hit.getIndex().equals(initialIndex)) {
results.add(parseLeniently.apply(hit.getSourceRef(), resourceId));

View File

@ -8,7 +8,6 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
@ -17,7 +16,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceRes
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
@ -28,18 +26,22 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.mockito.ArgumentMatcher;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModelTests.serializeFromTrainedModel;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
@ -58,13 +60,13 @@ public class LocalModelTests extends ESTestCase {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
InferenceDefinition definition = InferenceDefinition.builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(false))
.setTrainedModel(buildClassificationInference(false))
.build();
Model model = new LocalModel(modelId,
LocalModel model = new LocalModel(modelId,
"test-node",
definition,
new TrainedModelInput(inputFields),
@ -73,7 +75,7 @@ public class LocalModelTests extends ESTestCase {
modelStatsService);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
put("field.bar", 0.5);
put("field", Collections.singletonMap("bar", 0.5));
put("categorical", "dog");
}};
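The map above mixes a dotted key ("field.foo") with a nested object ("field" containing "bar"); both must resolve to the model's dotted feature names now that feature extraction happens once up front. A minimal sketch of that kind of flattening, using a hypothetical helper rather than the utility class the codebase actually uses:
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
final class DotFlattener {
    // Recursively flattens nested maps into dotted keys, e.g. {field={bar=0.5}} -> {"field.bar"=0.5}.
    static Map<String, Object> flatten(Map<String, Object> source) {
        Map<String, Object> flat = new HashMap<>();
        flatten("", source, flat);
        return flat;
    }
    @SuppressWarnings("unchecked")
    private static void flatten(String prefix, Map<String, Object> source, Map<String, Object> flat) {
        for (Map.Entry<String, Object> entry : source.entrySet()) {
            String key = prefix.isEmpty() ? entry.getKey() : prefix + "." + entry.getKey();
            if (entry.getValue() instanceof Map) {
                flatten(key, (Map<String, Object>) entry.getValue(), flat);
            } else {
                flat.put(key, entry.getValue());
            }
        }
    }
    public static void main(String[] args) {
        Map<String, Object> doc = new HashMap<>();
        doc.put("field.foo", 1.0);
        doc.put("field", Collections.singletonMap("bar", 0.5));
        doc.put("categorical", "dog");
        // Both the dotted key and the nested object become addressable as "field.foo" and "field.bar"
        System.out.println(flatten(doc));
    }
}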
@ -89,9 +91,9 @@ public class LocalModelTests extends ESTestCase {
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(1L));
// Test with labels
definition = new TrainedModelDefinition.Builder()
definition = InferenceDefinition.builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(true))
.setTrainedModel(buildClassificationInference(true))
.build();
model = new LocalModel(modelId,
"test-node",
@ -130,9 +132,9 @@ public class LocalModelTests extends ESTestCase {
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
InferenceDefinition definition = InferenceDefinition.builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(true))
.setTrainedModel(buildClassificationInference(true))
.build();
Model model = new LocalModel(modelId,
@ -186,11 +188,11 @@ public class LocalModelTests extends ESTestCase {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
InferenceDefinition trainedModelDefinition = InferenceDefinition.builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildRegression())
.setTrainedModel(buildRegressionInference())
.build();
Model model = new LocalModel("regression_model",
LocalModel model = new LocalModel("regression_model",
"test-node",
trainedModelDefinition,
new TrainedModelInput(inputFields),
@ -212,9 +214,9 @@ public class LocalModelTests extends ESTestCase {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
InferenceDefinition trainedModelDefinition = InferenceDefinition.builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildRegression())
.setTrainedModel(buildRegressionInference())
.build();
Model model = new LocalModel(
"regression_model",
@ -242,9 +244,9 @@ public class LocalModelTests extends ESTestCase {
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
InferenceDefinition definition = InferenceDefinition.builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(false))
.setTrainedModel(buildClassificationInference(false))
.build();
Model model = new LocalModel(modelId,
@ -277,16 +279,16 @@ public class LocalModelTests extends ESTestCase {
}), anyBoolean());
}
private static <T extends InferenceConfig> SingleValueInferenceResults getSingleValue(Model model,
Map<String, Object> fields,
InferenceConfigUpdate config)
private static SingleValueInferenceResults getSingleValue(Model model,
Map<String, Object> fields,
InferenceConfigUpdate config)
throws Exception {
return (SingleValueInferenceResults)getInferenceResult(model, fields, config);
}
private static <T extends InferenceConfig> InferenceResults getInferenceResult(Model model,
Map<String, Object> fields,
InferenceConfigUpdate config) throws Exception {
private static InferenceResults getInferenceResult(Model model,
Map<String, Object> fields,
InferenceConfigUpdate config) throws Exception {
PlainActionFuture<InferenceResults> future = new PlainActionFuture<>();
model.infer(fields, config, future);
return future.get();
@ -299,6 +301,10 @@ public class LocalModelTests extends ESTestCase {
return oneHotEncoding;
}
public static InferenceModel buildClassificationInference(boolean includeLabels) throws IOException {
return serializeFromTrainedModel((Ensemble)buildClassification(includeLabels));
}
public static TrainedModel buildClassification(boolean includeLabels) {
List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
Tree tree1 = Tree.builder()
@ -345,6 +351,10 @@ public class LocalModelTests extends ESTestCase {
.build();
}
public static InferenceModel buildRegressionInference() throws IOException {
return serializeFromTrainedModel((Ensemble)buildRegression());
}
public static TrainedModel buildRegression() {
List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
Tree tree1 = Tree.builder()

View File

@ -19,6 +19,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
@ -34,10 +35,10 @@ import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
@ -130,9 +131,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), any());
// Test invalidate cache for model3
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
@ -143,10 +144,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any());
// model3 is no longer referenced by a pipeline, so each call loads it eagerly
verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, times(4)).getTrainedModelForInference(eq(model3), any());
}
public void testMaxCachedLimitReached() throws Exception {
@ -179,9 +180,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
// the loading occurred or which models are currently in the cache due to evictions.
// Verify that we have at least loaded all three
assertBusy(() -> {
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), any());
});
// all models have been loaded and put in the cache
@ -198,10 +199,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Depending on the order in which the models were first loaded in the first step,
// models 1 & 2 may have been evicted by model 3, in which case they have
// been loaded at most twice
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model1), any());
verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model2), any());
// model 3 was only requested once, on the initial load from the change event
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), any());
// model 3 has been loaded and evicted exactly once
verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@ -217,7 +218,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
modelLoadingService.getModel(model3, future3);
assertThat(future3.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(2)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, times(2)).getTrainedModelForInference(eq(model3), any());
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
@ -238,7 +239,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
modelLoadingService.getModel(model1, future1);
assertThat(future1.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model1), any());
verify(trainedModelStatsService, times(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
@ -252,7 +253,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
modelLoadingService.getModel(model2, future2);
assertThat(future2.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model2), any());
// Test invalidate cache for model3
// Now both model 1 and 2 should fit in cache without issues
@ -264,9 +265,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, times(5)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model1), any());
verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model2), any());
verify(trainedModelProvider, times(5)).getTrainedModelForInference(eq(model3), any());
}
@ -291,7 +292,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(10)).getTrainedModelForInference(eq(model1), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}
@ -319,7 +320,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
}
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(true), any());
verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}
@ -365,36 +366,34 @@ public class ModelLoadingServiceTests extends ESTestCase {
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(true), any());
verify(trainedModelProvider, times(3)).getTrainedModelForInference(eq(model), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}
@SuppressWarnings("unchecked")
private void withTrainedModel(String modelId, long size) throws IOException {
TrainedModelDefinition definition = mock(TrainedModelDefinition.class);
private void withTrainedModel(String modelId, long size) {
InferenceDefinition definition = mock(InferenceDefinition.class);
when(definition.ramBytesUsed()).thenReturn(size);
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getModelDefinition()).thenReturn(definition);
when(trainedModelConfig.getModelId()).thenReturn(modelId);
when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS);
when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz")));
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
listener.onResponse(trainedModelConfig);
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
listener.onResponse(Tuple.tuple(trainedModelConfig, definition));
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(true), any());
doAnswer(invocationOnMock -> trainedModelConfig).when(trainedModelConfig).ensureParsedDefinition(any());
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
}
private void withMissingModel(String modelId) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(true), any());
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
}
private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws IOException {

View File

@ -14,6 +14,8 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -36,6 +38,10 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
config.ensureParsedDefinition(xContentRegistry());
TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
InferenceDefinition inferenceDefinition = new InferenceDefinition(
(LangIdentNeuralNetwork)trainedModelDefinition.getTrainedModel(),
trainedModelDefinition.getPreProcessors()
);
List<LanguageExamples.LanguageExampleEntry> examples = new LanguageExamples().getLanguageExamples();
ClassificationConfig classificationConfig = new ClassificationConfig(1);
@ -47,7 +53,7 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
Map<String, Object> inferenceFields = new HashMap<>();
inferenceFields.put("text", text);
ClassificationInferenceResults singleValueInferenceResults =
(ClassificationInferenceResults) trainedModelDefinition.infer(inferenceFields, classificationConfig);
(ClassificationInferenceResults) inferenceDefinition.infer(inferenceFields, classificationConfig);
assertThat(singleValueInferenceResults.valueAsString(), equalTo(cld3Actual));
double eps = entry.getLanguage().equals("hr") ? 0.001 : 0.00001;