[7.x] [ML][Inference] Add support for multi-value leaves to the tree model (#52531) (#52901)

* [ML][Inference] Add support for multi-value leaves to the tree model (#52531)

This adds support for multi-value leaves. This is a prerequisite for multi-class boosted tree classification.
Benjamin Trent 2020-02-27 14:05:28 -05:00 committed by GitHub
parent 710a9ead69
commit 19a6c5d980
26 changed files with 576 additions and 198 deletions
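With this change a leaf node can carry one value per class instead of a single double. A minimal sketch of building such a tree with the builder API from this diff (the feature names, threshold, and leaf scores are illustrative; it assumes the x-pack core trainedmodel classes from this commit are on the classpath):

    import java.util.Arrays;
    import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
    import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
    import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;

    // A classification stump whose two leaves each hold one score per class (three classes here).
    Tree tree = Tree.builder()
        .setTargetType(TargetType.CLASSIFICATION)
        .setFeatureNames(Arrays.asList("field.foo", "field.bar"))
        .setRoot(TreeNode.builder(0)
            .setLeftChild(1)
            .setRightChild(2)
            .setSplitFeature(0)
            .setThreshold(0.5))
        .addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(1.0, 0.0, 2.0))) // per-class scores
        .addNode(TreeNode.builder(2).setLeafValue(Arrays.asList(0.0, 1.0, 0.0)))
        .build();

At inference time the per-class scores of the selected leaf are converted to probabilities with a softmax (see the Tree and Statistics changes below).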

View File

@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@ -225,7 +226,7 @@ public class Tree implements TrainedModel {
for (int i = nodes.size(); i < nodeIndex + 1; i++) {
nodes.add(null);
}
nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value));
nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(Collections.singletonList(value)));
return this;
}

View File

@ -27,6 +27,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
public class TreeNode implements ToXContentObject {
@ -61,7 +62,7 @@ public class TreeNode implements ToXContentObject {
PARSER.declareInt(Builder::setSplitFeature, SPLIT_FEATURE);
PARSER.declareInt(Builder::setNodeIndex, NODE_INDEX);
PARSER.declareDouble(Builder::setSplitGain, SPLIT_GAIN);
PARSER.declareDouble(Builder::setLeafValue, LEAF_VALUE);
PARSER.declareDoubleArray(Builder::setLeafValue, LEAF_VALUE);
PARSER.declareLong(Builder::setNumberSamples, NUMBER_SAMPLES);
}
@ -74,7 +75,7 @@ public class TreeNode implements ToXContentObject {
private final Integer splitFeature;
private final int nodeIndex;
private final Double splitGain;
private final Double leafValue;
private final List<Double> leafValue;
private final Boolean defaultLeft;
private final Integer leftChild;
private final Integer rightChild;
@ -86,7 +87,7 @@ public class TreeNode implements ToXContentObject {
Integer splitFeature,
int nodeIndex,
Double splitGain,
Double leafValue,
List<Double> leafValue,
Boolean defaultLeft,
Integer leftChild,
Integer rightChild,
@ -123,7 +124,7 @@ public class TreeNode implements ToXContentObject {
return splitGain;
}
public Double getLeafValue() {
public List<Double> getLeafValue() {
return leafValue;
}
@ -212,7 +213,7 @@ public class TreeNode implements ToXContentObject {
private Integer splitFeature;
private int nodeIndex;
private Double splitGain;
private Double leafValue;
private List<Double> leafValue;
private Boolean defaultLeft;
private Integer leftChild;
private Integer rightChild;
@ -250,7 +251,7 @@ public class TreeNode implements ToXContentObject {
return this;
}
public Builder setLeafValue(Double leafValue) {
public Builder setLeafValue(List<Double> leafValue) {
this.leafValue = leafValue;
return this;
}

View File

@ -23,6 +23,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.Collections;
public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
@ -48,7 +49,7 @@ public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
public static TreeNode createRandomLeafNode(double internalValue) {
return TreeNode.builder(randomInt(100))
.setDefaultLeft(randomBoolean() ? null : randomBoolean())
.setLeafValue(internalValue)
.setLeafValue(Collections.singletonList(internalValue))
.setNumberSamples(randomNonNegativeLong())
.build();
}
@ -60,7 +61,7 @@ public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
Integer featureIndex,
Operator operator) {
return TreeNode.builder(nodeIndex)
.setLeafValue(left == null ? randomDouble() : null)
.setLeafValue(left == null ? Collections.singletonList(randomDouble()) : null)
.setDefaultLeft(randomBoolean() ? null : randomBoolean())
.setLeftChild(left)
.setRightChild(right)

View File

@ -5,29 +5,37 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.ingest.IngestDocument;
import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
public class RawInferenceResults extends SingleValueInferenceResults {
public class RawInferenceResults implements InferenceResults {
public static final String NAME = "raw";
public RawInferenceResults(double value, Map<String, Double> featureImportance) {
super(value, featureImportance);
private final double[] value;
private final Map<String, Double> featureImportance;
public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
this.value = value;
this.featureImportance = featureImportance;
}
public RawInferenceResults(StreamInput in) throws IOException {
super(in);
public double[] getValue() {
return value;
}
public Map<String, Double> getFeatureImportance() {
return featureImportance;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
throw new UnsupportedOperationException("[raw] does not support wire serialization");
}
@Override
@ -35,13 +43,13 @@ public class RawInferenceResults extends SingleValueInferenceResults {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value())
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
return Arrays.equals(value, that.value)
&& Objects.equals(featureImportance, that.featureImportance);
}
@Override
public int hashCode() {
return Objects.hash(value(), getFeatureImportance());
return Objects.hash(Arrays.hashCode(value), featureImportance);
}
@Override

View File

@ -26,30 +26,29 @@ public final class InferenceHelpers {
/**
* @return Tuple of the highest scored index and the top classes
*/
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(double[] probabilities,
List<String> classificationLabels,
@Nullable double[] classificationWeights,
int numToInclude) {
if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
throw ExceptionsHelper
.serverError(
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
null,
probabilities.size(),
probabilities.length,
classificationLabels.size());
}
List<Double> scores = classificationWeights == null ?
double[] scores = classificationWeights == null ?
probabilities :
IntStream.range(0, probabilities.size())
.mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
.boxed()
.collect(Collectors.toList());
IntStream.range(0, probabilities.length)
.mapToDouble(i -> probabilities[i] * classificationWeights[i])
.toArray();
int[] sortedIndices = IntStream.range(0, probabilities.size())
int[] sortedIndices = IntStream.range(0, scores.length)
.boxed()
.sorted(Comparator.comparing(scores::get).reversed())
.sorted(Comparator.comparing(i -> scores[(Integer)i]).reversed())
.mapToInt(i -> i)
.toArray();
@ -59,14 +58,14 @@ public final class InferenceHelpers {
List<String> labels = classificationLabels == null ?
// If we don't have the labels we should return the top classification values anyways, they will just be numeric
IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) :
IntStream.range(0, probabilities.length).boxed().map(String::valueOf).collect(Collectors.toList()) :
classificationLabels;
int count = numToInclude < 0 ? probabilities.size() : Math.min(numToInclude, probabilities.size());
int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
for(int i = 0; i < count; i++) {
int idx = sortedIndices[i];
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities[idx], scores[idx]));
}
return Tuple.tuple(sortedIndices[0], topClassEntries);

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
@ -62,4 +63,8 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
* @return A {@code Map<String, Double>} mapping each featureName to its importance
*/
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
default Version getMinimalCompatibilityVersion() {
return Version.V_7_6_0;
}
}

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
@ -20,7 +21,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@ -139,19 +139,20 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
List<Double> inferenceResults = new ArrayList<>(this.models.size());
double[][] inferenceResults = new double[this.models.size()][];
List<Map<String, Double>> featureInfluence = new ArrayList<>();
int i = 0;
NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
this.models.forEach(model -> {
for (TrainedModel model : models) {
InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
assert result instanceof SingleValueInferenceResults;
SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result;
inferenceResults.add(inferenceResult.value());
assert result instanceof RawInferenceResults;
RawInferenceResults inferenceResult = (RawInferenceResults) result;
inferenceResults[i++] = inferenceResult.getValue();
if (config.requestingImportance()) {
featureInfluence.add(inferenceResult.getFeatureImportance());
}
});
List<Double> processed = outputAggregator.processValues(inferenceResults);
}
double[] processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed, featureInfluence, config, featureDecoderMap);
}
@ -160,13 +161,13 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return targetType;
}
private InferenceResults buildResults(List<Double> processedInferences,
private InferenceResults buildResults(double[] processedInferences,
List<Map<String, Double>> featureInfluence,
InferenceConfig config,
Map<String, String> featureDecoderMap) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(outputAggregator.aggregate(processedInferences),
return new RawInferenceResults(new double[] {outputAggregator.aggregate(processedInferences)},
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
}
switch(targetType) {
@ -176,7 +177,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
// Adjust the probabilities according to the thresholds
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
processedInferences,
@ -356,6 +357,11 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return Collections.unmodifiableCollection(accountables);
}
@Override
public Version getMinimalCompatibilityVersion() {
return models.stream().map(TrainedModel::getMinimalCompatibilityVersion).max(Version::compareTo).orElse(Version.V_7_6_0);
}
public static class Builder {
private List<String> featureNames;
private List<TrainedModel> trainedModels;

View File

@ -19,9 +19,9 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.sigmoid;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
@ -78,31 +78,39 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
}
@Override
public List<Double> processValues(List<Double> values) {
public double[] processValues(double[][] values) {
Objects.requireNonNull(values, "values must not be null");
if (weights != null && values.size() != weights.length) {
if (weights != null && values.length != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
double summation = weights == null ?
values.stream().mapToDouble(Double::valueOf).sum() :
IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).sum();
double probOfClassOne = sigmoid(summation);
double[] sumOnAxis1 = new double[values[0].length];
for (int j = 0; j < values.length; j++) {
double[] value = values[j];
double weight = weights == null ? 1.0 : weights[j];
for(int i = 0; i < value.length; i++) {
if (i >= sumOnAxis1.length) {
throw new IllegalArgumentException("value entries must have the same dimensions");
}
sumOnAxis1[i] += (value[i] * weight);
}
}
if (sumOnAxis1.length > 1) {
return softMax(sumOnAxis1);
}
double probOfClassOne = sigmoid(sumOnAxis1[0]);
assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0;
return Arrays.asList(1.0 - probOfClassOne, probOfClassOne);
return new double[] {1.0 - probOfClassOne, probOfClassOne};
}
@Override
public double aggregate(List<Double> values) {
public double aggregate(double[] values) {
Objects.requireNonNull(values, "values must not be null");
assert values.size() == 2;
int bestValue = 0;
double bestProb = Double.NEGATIVE_INFINITY;
for (int i = 0; i < values.size(); i++) {
if (values.get(i) == null) {
throw new IllegalArgumentException("values must not contain null values");
}
if (values.get(i) > bestProb) {
bestProb = values.get(i);
for (int i = 0; i < values.length; i++) {
if (values[i] > bestProb) {
bestProb = values[i];
bestValue = i;
}
}
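processValues now receives a matrix with one row per sub-model and one column per class: it computes a weighted column-wise sum and, when there is more than one column, applies softMax; the single-column case keeps the original sigmoid behaviour. A small sketch using the same numbers as LogisticRegressionTests further down in this diff (unit weights, illustrative values):

    double[] processed = new LogisticRegression(new double[]{1.0, 1.0, 1.0, 1.0, 1.0})
        .processValues(new double[][]{{1, 0, 1}, {2, 0, 0}, {2, 3, 1}, {3, 3, 1}, {1, 1, 5}});
    // column sums are {9, 7, 8}; softMax({9, 7, 8}) ≈ {0.665, 0.090, 0.245},
    // so aggregate(processed) picks index 0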

View File

@ -10,8 +10,6 @@ import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import java.util.List;
public interface OutputAggregator extends NamedXContentObject, NamedWriteable, Accountable {
/**
@ -20,15 +18,15 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable, A
Integer expectedValueSize();
/**
* This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(List)} method.
* This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(double[])} method.
*
* Two major types of pre-processed values could be returned:
* - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(List)}
* - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(List)}
* - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(double[][])}
* - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(double[][])}
* @param values the values to process
* @return A new list containing the processed values or the same list if no processing is required
*/
List<Double> processValues(List<Double> values);
double[] processValues(double[][] values);
/**
* Function to aggregate the processed values into a single double
@ -40,7 +38,7 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable, A
* @param processedValues The values to aggregate
* @return the aggregated value.
*/
double aggregate(List<Double> processedValues);
double aggregate(double[] processedValues);
/**
* @return The name of the output aggregator

View File

@ -89,21 +89,37 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
}
@Override
public List<Double> processValues(List<Double> values) {
public double[] processValues(double[][] values) {
Objects.requireNonNull(values, "values must not be null");
if (weights != null && values.size() != weights.length) {
if (weights != null && values.length != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
List<Integer> freqArray = new ArrayList<>();
Integer maxVal = 0;
for (Double value : values) {
if (value == null) {
throw new IllegalArgumentException("values must not contain null values");
// Multiple leaf values
if (values[0].length > 1) {
double[] sumOnAxis1 = new double[values[0].length];
for (int j = 0; j < values.length; j++) {
double[] value = values[j];
double weight = weights == null ? 1.0 : weights[j];
for(int i = 0; i < value.length; i++) {
if (i >= sumOnAxis1.length) {
throw new IllegalArgumentException("value entries must have the same dimensions");
}
sumOnAxis1[i] += (value[i] * weight);
}
}
if (Double.isNaN(value) || Double.isInfinite(value) || value < 0.0 || value != Math.rint(value)) {
return softMax(sumOnAxis1);
}
// Singular leaf values
List<Integer> freqArray = new ArrayList<>();
int maxVal = 0;
for (double[] value : values) {
if (value.length != 1) {
throw new IllegalArgumentException("value entries must have the same dimensions");
}
if (Double.isNaN(value[0]) || Double.isInfinite(value[0]) || value[0] < 0.0 || value[0] != Math.rint(value[0])) {
throw new IllegalArgumentException("values must be whole, non-infinite, and positive");
}
Integer integerValue = value.intValue();
int integerValue = Double.valueOf(value[0]).intValue();
freqArray.add(integerValue);
if (integerValue > maxVal) {
maxVal = integerValue;
@ -112,27 +128,27 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
if (maxVal >= numClasses) {
throw new IllegalArgumentException("values contain entries larger than expected max of [" + (numClasses - 1) + "]");
}
List<Double> frequencies = new ArrayList<>(Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY));
double[] frequencies = Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY)
.stream()
.mapToDouble(Double::doubleValue)
.toArray();
for (int i = 0; i < freqArray.size(); i++) {
Double weight = weights == null ? 1.0 : weights[i];
Integer value = freqArray.get(i);
Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
frequencies.set(value, frequency);
double weight = weights == null ? 1.0 : weights[i];
int value = freqArray.get(i);
double frequency = frequencies[value] == Double.NEGATIVE_INFINITY ? weight : frequencies[value] + weight;
frequencies[value] = frequency;
}
return softMax(frequencies);
}
@Override
public double aggregate(List<Double> values) {
public double aggregate(double[] values) {
Objects.requireNonNull(values, "values must not be null");
int bestValue = 0;
double bestFreq = Double.NEGATIVE_INFINITY;
for (int i = 0; i < values.size(); i++) {
if (values.get(i) == null) {
throw new IllegalArgumentException("values must not contain null values");
}
if (values.get(i) > bestFreq) {
bestFreq = values.get(i);
for (int i = 0; i < values.length; i++) {
if (values[i] > bestFreq) {
bestFreq = values[i];
bestValue = i;
}
}
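WeightedMode now covers both shapes. When the rows have more than one column (multi-value leaves) it performs the same weighted column-wise sum followed by softMax, so the {1,0,1}, {2,0,0}, {2,3,1}, {3,3,1}, {1,1,5} example above again yields roughly {0.665, 0.090, 0.245}. When each row holds a single value it keeps the original voting behaviour: votes {1, 1, 1, 1, 2} with unit weights and numClasses = 6 accumulate a frequency of 4 for class 1 and 1 for class 2, and softMax over those frequencies gives roughly {0.0, 0.953, 0.047, 0.0, 0.0, 0.0}, which is what WeightedModeTests asserts later in this diff.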

View File

@ -19,8 +19,6 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
@ -73,28 +71,25 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
}
@Override
public List<Double> processValues(List<Double> values) {
public double[] processValues(double[][] values) {
Objects.requireNonNull(values, "values must not be null");
assert values[0].length == 1;
if (weights == null) {
return values;
return Arrays.stream(values).mapToDouble(v -> v[0]).toArray();
}
if (values.size() != weights.length) {
if (values.length != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
return IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).boxed().collect(Collectors.toList());
return IntStream.range(0, weights.length).mapToDouble(i -> values[i][0] * weights[i]).toArray();
}
@Override
public double aggregate(List<Double> values) {
public double aggregate(double[] values) {
Objects.requireNonNull(values, "values must not be null");
if (values.isEmpty()) {
if (values.length == 0) {
throw new IllegalArgumentException("values must not be empty");
}
Optional<Double> summation = values.stream().reduce(Double::sum);
if (summation.isPresent()) {
return summation.get();
}
throw new IllegalArgumentException("values must not contain null values");
return Arrays.stream(values).sum();
}
@Override

View File

@ -30,7 +30,6 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
@ -130,7 +129,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector);
double[] scores = softmaxLayer.productPlusBias(true, h0);
List<Double> probabilities = softMax(Arrays.stream(scores).boxed().collect(Collectors.toList()));
double[] probabilities = softMax(scores);
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
@ -29,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfi
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
@ -100,7 +102,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
this.nodes = Collections.unmodifiableList(nodes);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
}
public Tree(StreamInput in) throws IOException {
@ -112,7 +114,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
} else {
this.classificationLabels = null;
}
this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
}
@Override
@ -147,7 +149,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return buildResult(node.getLeafValue(), featureImportance, config);
}
private InferenceResults buildResult(Double value, Map<String, Double> featureImportance, InferenceConfig config) {
private InferenceResults buildResult(double[] value, Map<String, Double> featureImportance, InferenceConfig config) {
assert value != null && value.length > 0;
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(value, featureImportance);
@ -160,13 +163,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
classificationLabels,
null,
classificationConfig.getNumTopClasses());
return new ClassificationInferenceResults(value,
return new ClassificationInferenceResults(topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
featureImportance,
config);
case REGRESSION:
return new RegressionInferenceResults(value, config, featureImportance);
return new RegressionInferenceResults(value[0], config, featureImportance);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}
@ -193,14 +196,22 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return targetType;
}
private List<Double> classificationProbability(double inferenceValue) {
private double[] classificationProbability(double[] inferenceValue) {
// Multi-value leaves, indicates that the leaves contain an array of values.
// The index of which corresponds to classification values
if (inferenceValue.length > 1) {
return Statistics.softMax(inferenceValue);
}
// If we are classification, we should assume that the inference return value is whole.
assert inferenceValue == Math.rint(inferenceValue);
assert inferenceValue[0] == Math.rint(inferenceValue[0]);
double maxCategory = this.highestOrderCategory.get();
// If we are classification, we should assume that the largest leaf value is whole.
assert maxCategory == Math.rint(maxCategory);
List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
.stream()
.mapToDouble(Double::doubleValue)
.toArray();
list[Double.valueOf(inferenceValue[0]).intValue()] = 1.0;
return list;
}
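In short, a multi-value leaf is interpreted as per-class scores and run through softMax, while a single-value leaf is still treated as a class index and expanded into a one-hot vector over highestOrderCategory + 1 classes. For example (illustrative numbers), a leaf of {1.0, 0.0, 2.0} becomes softMax({1.0, 0.0, 2.0}) ≈ {0.245, 0.090, 0.665}, whereas a single leaf value of 1.0 with a highest category of 2.0 becomes {0.0, 1.0, 0.0}.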
@ -268,6 +279,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
checkTargetType();
detectMissingNodes();
detectCycle();
verifyLeafNodeUniformity();
}
@Override
@ -331,7 +343,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
TreeNode currNode = nodes.get(nodeIndex);
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
if (currNode.isLeaf()) {
// TODO multi-value????
double leafValue = nodeValues[nodeIndex];
for (int i = 1; i < nextIndex; ++i) {
double scale = splitPath.sumUnwoundPath(i, nextIndex);
@ -375,7 +386,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
TreeNode node = nodes.get(nodeIndex);
if (node.isLeaf()) {
nodeEstimates[nodeIndex] = node.getLeafValue();
// TODO multi-value????
nodeEstimates[nodeIndex] = node.getLeafValue()[0];
return 0;
}
@ -424,6 +436,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
throw ExceptionsHelper.badRequestException(
"[target_type] should be [classification] if [classification_labels] are provided");
}
if (this.targetType != TargetType.CLASSIFICATION && this.nodes.stream().anyMatch(n -> n.getLeafValue().length > 1)) {
throw ExceptionsHelper.badRequestException(
"[target_type] should be [classification] if leaf nodes have multiple values");
}
}
private void detectCycle() {
@ -465,14 +481,39 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}
}
private void verifyLeafNodeUniformity() {
Integer leafValueLengths = null;
for (TreeNode node : nodes) {
if (node.isLeaf()) {
if (leafValueLengths == null) {
leafValueLengths = node.getLeafValue().length;
} else if (leafValueLengths != node.getLeafValue().length) {
throw ExceptionsHelper.badRequestException(
"[tree.tree_structure] all leaf nodes must have the same number of values");
}
}
}
}
private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
return nodeIdx >= nodes.size();
}
private Double maxLeafValue() {
return targetType == TargetType.CLASSIFICATION ?
this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() :
null;
if (targetType != TargetType.CLASSIFICATION) {
return null;
}
double max = 0.0;
for (TreeNode node : this.nodes) {
if (node.isLeaf()) {
if (node.getLeafValue().length > 1) {
return (double)node.getLeafValue().length;
} else {
max = Math.max(node.getLeafValue()[0], max);
}
}
}
return max;
}
@Override
@ -493,6 +534,14 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return Collections.unmodifiableCollection(accountables);
}
@Override
public Version getMinimalCompatibilityVersion() {
if (nodes.stream().filter(TreeNode::isLeaf).anyMatch(t -> t.getLeafValue().length > 1)) {
return Version.V_7_7_0;
}
return Version.V_7_6_0;
}
public static class Builder {
private List<String> featureNames;
private ArrayList<TreeNode.Builder> nodes;
@ -586,6 +635,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
* @return this
*/
Tree.Builder addLeaf(int nodeIndex, double value) {
return addLeaf(nodeIndex, Arrays.asList(value));
}
Tree.Builder addLeaf(int nodeIndex, List<Double> value) {
for (int i = nodes.size(); i < nodeIndex + 1; i++) {
nodes.add(null);
}

View File

@ -21,6 +21,8 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
@ -60,7 +62,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
parser.declareInt(TreeNode.Builder::setSplitFeature, SPLIT_FEATURE);
parser.declareInt(TreeNode.Builder::setNodeIndex, NODE_INDEX);
parser.declareDouble(TreeNode.Builder::setSplitGain, SPLIT_GAIN);
parser.declareDouble(TreeNode.Builder::setLeafValue, LEAF_VALUE);
parser.declareDoubleArray(TreeNode.Builder::setLeafValue, LEAF_VALUE);
parser.declareLong(TreeNode.Builder::setNumberSamples, NUMBER_SAMPLES);
return parser;
}
@ -74,7 +76,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
private final int splitFeature;
private final int nodeIndex;
private final double splitGain;
private final double leafValue;
private final double[] leafValue;
private final boolean defaultLeft;
private final int leftChild;
private final int rightChild;
@ -86,7 +88,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
Integer splitFeature,
int nodeIndex,
Double splitGain,
Double leafValue,
List<Double> leafValue,
Boolean defaultLeft,
Integer leftChild,
Integer rightChild,
@ -96,7 +98,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
this.splitFeature = splitFeature == null ? -1 : splitFeature;
this.nodeIndex = nodeIndex;
this.splitGain = splitGain == null ? Double.NaN : splitGain;
this.leafValue = leafValue == null ? Double.NaN : leafValue;
this.leafValue = leafValue == null ? new double[0] : leafValue.stream().mapToDouble(Double::doubleValue).toArray();
this.defaultLeft = defaultLeft == null ? false : defaultLeft;
this.leftChild = leftChild == null ? -1 : leftChild;
this.rightChild = rightChild == null ? -1 : rightChild;
@ -112,7 +114,11 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
splitFeature = in.readInt();
splitGain = in.readDouble();
nodeIndex = in.readVInt();
leafValue = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
leafValue = in.readDoubleArray();
} else {
leafValue = new double[]{in.readDouble()};
}
defaultLeft = in.readBoolean();
leftChild = in.readInt();
rightChild = in.readInt();
@ -144,7 +150,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
return splitGain;
}
public double getLeafValue() {
public double[] getLeafValue() {
return leafValue;
}
@ -190,7 +196,18 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
out.writeInt(splitFeature);
out.writeDouble(splitGain);
out.writeVInt(nodeIndex);
out.writeDouble(leafValue);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeDoubleArray(leafValue);
} else {
if (leafValue.length > 1) {
throw new IOException("Multi-class classification models require that all nodes are at least version 7.7.0.");
}
if (leafValue.length == 0) {
out.writeDouble(Double.NaN);
} else {
out.writeDouble(leafValue[0]);
}
}
out.writeBoolean(defaultLeft);
out.writeInt(leftChild);
out.writeInt(rightChild);
@ -209,7 +226,9 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
}
addOptionalDouble(builder, SPLIT_GAIN, splitGain);
builder.field(NODE_INDEX.getPreferredName(), nodeIndex);
addOptionalDouble(builder, LEAF_VALUE, leafValue);
if (leafValue.length > 0) {
builder.field(LEAF_VALUE.getPreferredName(), leafValue);
}
builder.field(DEFAULT_LEFT.getPreferredName(), defaultLeft);
if (leftChild >= 0) {
builder.field(LEFT_CHILD.getPreferredName(), leftChild);
@ -238,7 +257,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
&& Objects.equals(splitFeature, that.splitFeature)
&& Objects.equals(nodeIndex, that.nodeIndex)
&& Objects.equals(splitGain, that.splitGain)
&& Objects.equals(leafValue, that.leafValue)
&& Arrays.equals(leafValue, that.leafValue)
&& Objects.equals(defaultLeft, that.defaultLeft)
&& Objects.equals(leftChild, that.leftChild)
&& Objects.equals(rightChild, that.rightChild)
@ -252,7 +271,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
splitFeature,
splitGain,
nodeIndex,
leafValue,
Arrays.hashCode(leafValue),
defaultLeft,
leftChild,
rightChild,
@ -270,7 +289,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE;
return SHALLOW_SIZE + this.leafValue.length * Double.BYTES;
}
public static class Builder {
@ -279,7 +298,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
private Integer splitFeature;
private int nodeIndex;
private Double splitGain;
private Double leafValue;
private List<Double> leafValue;
private Boolean defaultLeft;
private Integer leftChild;
private Integer rightChild;
@ -317,11 +336,19 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
return this;
}
public Builder setLeafValue(Double leafValue) {
public Builder setLeafValue(double leafValue) {
return this.setLeafValue(Collections.singletonList(leafValue));
}
public Builder setLeafValue(List<Double> leafValue) {
this.leafValue = leafValue;
return this;
}
List<Double> getLeafValue() {
return this.leafValue;
}
public Builder setDefaultLeft(Boolean defaultLeft) {
this.defaultLeft = defaultLeft;
return this;
@ -358,6 +385,9 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
if (leafValue == null) {
throw new IllegalArgumentException("[leaf_value] is required for a leaf node.");
}
if (leafValue.stream().anyMatch(Objects::isNull)) {
throw new IllegalArgumentException("[leaf_value] cannot have null values.");
}
} else {
if (leftChild < 0) {
throw new IllegalArgumentException("[left_child] must be a non-negative integer.");

View File

@ -7,8 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.utils;
import org.elasticsearch.common.Numbers;
import java.util.List;
import java.util.stream.Collectors;
import java.util.Arrays;
public final class Statistics {
@ -20,28 +19,29 @@ public final class Statistics {
* Any {@link Double#isInfinite()}, {@link Double#NaN}, or `null` values are ignored in calculation and returned as 0.0 in the
* softMax.
* @param values Values on which to run SoftMax.
* @return A new list containing the softmax of the passed values
* @return A new array containing the softmax of the passed values
*/
public static List<Double> softMax(List<Double> values) {
Double expSum = 0.0;
Double max = values.stream().filter(Statistics::isValid).max(Double::compareTo).orElse(null);
if (max == null) {
public static double[] softMax(double[] values) {
double expSum = 0.0;
double max = Arrays.stream(values).filter(Statistics::isValid).max().orElse(Double.NaN);
if (isValid(max) == false) {
throw new IllegalArgumentException("no valid values present");
}
List<Double> exps = values.stream().map(v -> isValid(v) ? v - max : Double.NEGATIVE_INFINITY)
.collect(Collectors.toList());
for (int i = 0; i < exps.size(); i++) {
if (isValid(exps.get(i))) {
Double exp = Math.exp(exps.get(i));
double[] exps = new double[values.length];
for (int i = 0; i < exps.length; i++) {
if (isValid(values[i])) {
double exp = Math.exp(values[i] - max);
expSum += exp;
exps.set(i, exp);
exps[i] = exp;
} else {
exps[i] = Double.NaN;
}
}
for (int i = 0; i < exps.size(); i++) {
if (isValid(exps.get(i))) {
exps.set(i, exps.get(i)/expSum);
for (int i = 0; i < exps.length; i++) {
if (isValid(exps[i])) {
exps[i] /= expSum;
} else {
exps.set(i, 0.0);
exps[i] = 0.0;
}
}
return exps;
@ -51,8 +51,8 @@ public final class Statistics {
return 1/(1 + Math.exp(-value));
}
private static boolean isValid(Double v) {
return v != null && Numbers.isValidDouble(v);
private static boolean isValid(double v) {
return Numbers.isValidDouble(v);
}
}
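The array-based softMax subtracts the largest valid input before exponentiating (the standard numerically stable formulation) and maps NaN or infinite entries to 0.0 in the output; if no entry is valid it throws. A small usage sketch (illustrative values, outputs rounded):

    double[] probs = Statistics.softMax(new double[]{1.0, 2.0, 3.0});
    // probs ≈ {0.0900, 0.2447, 0.6652}, summing to 1.0
    Statistics.softMax(new double[]{Double.NaN, Double.POSITIVE_INFINITY}); // throws IllegalArgumentException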

View File

@ -5,24 +5,37 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {
import static org.hamcrest.CoreMatchers.equalTo;
public class RawInferenceResultsTests extends ESTestCase {
public static RawInferenceResults createRandomResults() {
return new RawInferenceResults(randomDouble(), randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
int n = randomIntBetween(1, 10);
double[] results = new double[n];
for (int i = 0; i < n; i++) {
results[i] = randomDouble();
}
return new RawInferenceResults(results, randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
}
@Override
protected RawInferenceResults createTestInstance() {
return createRandomResults();
public void testEqualityAndHashcode() {
int n = randomIntBetween(1, 10);
double[] results = new double[n];
for (int i = 0; i < n; i++) {
results[i] = randomDouble();
}
Map<String, Double> importance = randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08);
RawInferenceResults lft = new RawInferenceResults(results, new HashMap<>(importance));
RawInferenceResults rgt = new RawInferenceResults(Arrays.copyOf(results, n), new HashMap<>(importance));
assertThat(lft, equalTo(rgt));
assertThat(lft.hashCode(), equalTo(rgt.hashCode()));
}
@Override
protected Writeable.Reader<RawInferenceResults> instanceReader() {
return RawInferenceResults::new;
}
}

View File

@ -11,11 +11,10 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticRegression> {
@ -43,7 +42,13 @@ public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticReg
public void testAggregate() {
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
double[][] values = new double[][]{
new double[] {1.0},
new double[] {2.0},
new double[] {2.0},
new double[] {3.0},
new double[] {5.0}
};
LogisticRegression logisticRegression = new LogisticRegression(ones);
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));
@ -57,6 +62,36 @@ public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticReg
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));
}
public void testAggregateMultiValueArrays() {
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
double[][] values = new double[][]{
new double[] {1.0, 0.0, 1.0},
new double[] {2.0, 0.0, 0.0},
new double[] {2.0, 3.0, 1.0},
new double[] {3.0, 3.0, 1.0},
new double[] {1.0, 1.0, 5.0}
};
LogisticRegression logisticRegression = new LogisticRegression(ones);
double[] processedValues = logisticRegression.processValues(values);
assertThat(processedValues.length, equalTo(3));
assertThat(processedValues[0], closeTo(0.665240955, 0.00001));
assertThat(processedValues[1], closeTo(0.090030573, 0.00001));
assertThat(processedValues[2], closeTo(0.244728471, 0.00001));
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(0.0));
double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};
logisticRegression = new LogisticRegression(variedWeights);
processedValues = logisticRegression.processValues(values);
assertThat(processedValues.length, equalTo(3));
assertThat(processedValues[0], closeTo(0.0, 0.00001));
assertThat(processedValues[1], closeTo(0.0, 0.00001));
assertThat(processedValues[2], closeTo(0.9999999, 0.00001));
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(2.0));
}
public void testCompatibleWith() {
LogisticRegression logisticRegression = createTestInstance();
assertThat(logisticRegression.compatibleWith(TargetType.CLASSIFICATION), is(true));

View File

@ -8,9 +8,6 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.junit.Before;
import java.util.ArrayList;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
public abstract class WeightedAggregatorTests<T extends OutputAggregator> extends AbstractSerializingTestCase<T> {
@ -35,9 +32,9 @@ public abstract class WeightedAggregatorTests<T extends OutputAggregator> extend
public void testWithValuesOfWrongLength() {
int numberOfValues = randomIntBetween(5, 10);
List<Double> values = new ArrayList<>(numberOfValues);
double[][] values = new double[numberOfValues][];
for (int i = 0; i < numberOfValues; i++) {
values.add(randomDouble());
values[i] = new double[] {randomDouble()};
}
OutputAggregator outputAggregatorWithTooFewWeights = createTestInstance(randomIntBetween(1, numberOfValues - 1));

View File

@ -11,8 +11,6 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;
import static org.hamcrest.CoreMatchers.is;
@ -44,7 +42,13 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
public void testAggregate() {
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
double[][] values = new double[][]{
new double[] {1.0},
new double[] {2.0},
new double[] {2.0},
new double[] {3.0},
new double[] {5.0}
};
WeightedMode weightedMode = new WeightedMode(ones, 6);
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
@ -57,19 +61,55 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
weightedMode = new WeightedMode(6);
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
values = Arrays.asList(1.0, 1.0, 1.0, 1.0, 2.0);
values = new double[][]{
new double[] {1.0},
new double[] {1.0},
new double[] {1.0},
new double[] {1.0},
new double[] {2.0}
};
weightedMode = new WeightedMode(6);
List<Double> processedValues = weightedMode.processValues(values);
assertThat(processedValues.size(), equalTo(6));
assertThat(processedValues.get(0), equalTo(0.0));
assertThat(processedValues.get(1), closeTo(0.95257412, 0.00001));
assertThat(processedValues.get(2), closeTo((1.0 - 0.95257412), 0.00001));
assertThat(processedValues.get(3), equalTo(0.0));
assertThat(processedValues.get(4), equalTo(0.0));
assertThat(processedValues.get(5), equalTo(0.0));
double[] processedValues = weightedMode.processValues(values);
assertThat(processedValues.length, equalTo(6));
assertThat(processedValues[0], equalTo(0.0));
assertThat(processedValues[1], closeTo(0.95257412, 0.00001));
assertThat(processedValues[2], closeTo((1.0 - 0.95257412), 0.00001));
assertThat(processedValues[3], equalTo(0.0));
assertThat(processedValues[4], equalTo(0.0));
assertThat(processedValues[5], equalTo(0.0));
assertThat(weightedMode.aggregate(processedValues), equalTo(1.0));
}
public void testAggregateMultiValueArrays() {
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
double[][] values = new double[][]{
new double[] {1.0, 0.0, 1.0},
new double[] {2.0, 0.0, 0.0},
new double[] {2.0, 3.0, 1.0},
new double[] {3.0, 3.0, 1.0},
new double[] {1.0, 1.0, 5.0}
};
WeightedMode weightedMode = new WeightedMode(ones, 3);
double[] processedValues = weightedMode.processValues(values);
assertThat(processedValues.length, equalTo(3));
assertThat(processedValues[0], closeTo(0.665240955, 0.00001));
assertThat(processedValues[1], closeTo(0.090030573, 0.00001));
assertThat(processedValues[2], closeTo(0.244728471, 0.00001));
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(0.0));
double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};
weightedMode = new WeightedMode(variedWeights, 3);
processedValues = weightedMode.processValues(values);
assertThat(processedValues.length, equalTo(3));
assertThat(processedValues[0], closeTo(0.0, 0.00001));
assertThat(processedValues[1], closeTo(0.0, 0.00001));
assertThat(processedValues[2], closeTo(0.9999999, 0.00001));
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
}
public void testCompatibleWith() {
WeightedMode weightedMode = createTestInstance();
assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true));

View File

@ -11,8 +11,6 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;
import static org.hamcrest.CoreMatchers.is;
@ -43,7 +41,13 @@ public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {
public void testAggregate() {
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
double[][] values = new double[][]{
new double[] {1.0},
new double[] {2.0},
new double[] {2.0},
new double[] {3.0},
new double[] {5.0}
};
WeightedSum weightedSum = new WeightedSum(ones);
assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));

View File

@ -55,7 +55,7 @@ public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
return TreeNode.builder(randomInt(100))
.setDefaultLeft(randomBoolean() ? null : randomBoolean())
.setNumberSamples(randomNonNegativeLong())
.setLeafValue(internalValue)
.setLeafValue(Collections.singletonList(internalValue))
.build();
}
@ -66,7 +66,7 @@ public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
Integer featureIndex,
Operator operator) {
return TreeNode.builder(nodeId)
.setLeafValue(left == null ? randomDouble() : null)
.setLeafValue(left == null ? Collections.singletonList(randomDouble()) : null)
.setDefaultLeft(randomBoolean() ? null : randomBoolean())
.setLeftChild(left)
.setRightChild(right)

View File

@ -112,7 +112,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
public void testInferWithStump() {
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
builder.setRoot(TreeNode.builder(0).setLeafValue(42.0));
builder.setRoot(TreeNode.builder(0).setLeafValue(Collections.singletonList(42.0)));
builder.setFeatureNames(Collections.emptyList());
Tree tree = builder.build();

View File

@ -16,18 +16,18 @@ import static org.hamcrest.Matchers.closeTo;
public class StatisticsTests extends ESTestCase {
public void testSoftMax() {
List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, 1.0, -0.5, null, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0);
List<Double> softMax = Statistics.softMax(values);
double[] values = new double[] {Double.NEGATIVE_INFINITY, 1.0, -0.5, Double.NaN, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0};
double[] softMax = Statistics.softMax(values);
List<Double> expected = Arrays.asList(0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042);
double[] expected = new double[] {0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042};
for(int i = 0; i < expected.size(); i++) {
assertThat(softMax.get(i), closeTo(expected.get(i), 0.000001));
for(int i = 0; i < expected.length; i++) {
assertThat(softMax[i], closeTo(expected[i], 0.000001));
}
}
public void testSoftMaxWithNoValidValues() {
List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, null, Double.NaN, Double.POSITIVE_INFINITY);
double[] values = new double[] {Double.NEGATIVE_INFINITY, Double.NaN, Double.POSITIVE_INFINITY};
expectThrows(IllegalArgumentException.class, () -> Statistics.softMax(values));
}

View File

@ -211,14 +211,14 @@ public class TrainedModelIT extends ESRestTestCase {
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5),
TreeNode.builder(1).setLeafValue(0.3),
TreeNode.builder(1).setLeafValue(Collections.singletonList(0.3)),
TreeNode.builder(2)
.setThreshold(0.0)
.setSplitFeature(3)
.setLeftChild(3)
.setRightChild(4),
TreeNode.builder(3).setLeafValue(0.1),
TreeNode.builder(4).setLeafValue(0.2))
TreeNode.builder(3).setLeafValue(Collections.singletonList(0.1)),
TreeNode.builder(4).setLeafValue(Collections.singletonList(0.2)))
.build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
@ -227,8 +227,8 @@ public class TrainedModelIT extends ESRestTestCase {
.setRightChild(2)
.setSplitFeature(2)
.setThreshold(1.0),
TreeNode.builder(1).setLeafValue(1.5),
TreeNode.builder(2).setLeafValue(0.9))
TreeNode.builder(1).setLeafValue(Collections.singletonList(1.5)),
TreeNode.builder(2).setLeafValue(Collections.singletonList(0.9)))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
@ -237,8 +237,8 @@ public class TrainedModelIT extends ESRestTestCase {
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(0.2),
TreeNode.builder(1).setLeafValue(1.5),
TreeNode.builder(2).setLeafValue(0.9))
TreeNode.builder(1).setLeafValue(Collections.singletonList(1.5)),
TreeNode.builder(2).setLeafValue(Collections.singletonList(0.9)))
.build();
return Ensemble.builder()
.setTargetType(TargetType.REGRESSION)

View File

@ -94,6 +94,18 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
return;
}
Version minCompatibilityVersion = request.getTrainedModelConfig()
.getModelDefinition()
.getTrainedModel()
.getMinimalCompatibilityVersion();
if (state.nodes().getMinNodeVersion().before(minCompatibilityVersion)) {
listener.onFailure(ExceptionsHelper.badRequestException(
"Definition for [{}] requires that all nodes are at least version [{}]",
request.getTrainedModelConfig().getModelId(),
minCompatibilityVersion.toString()));
return;
}
TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())

View File

@ -22,6 +22,12 @@ import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceRes
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
@ -189,6 +195,109 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
}
public void testInferModelMultiClassModel() throws Exception {
String modelId = "test-load-models-classification-multi";
Map<String, String> oneHotEncoding = new HashMap<>();
oneHotEncoding.put("cat", "animal_cat");
oneHotEncoding.put("dog", "animal_dog");
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId)
.setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
.setParsedDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
.setTrainedModel(buildMultiClassClassification()))
.setVersion(Version.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setCreateTime(Instant.now())
.setEstimatedOperations(0)
.setEstimatedHeapMemory(0)
.build();
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
assertThat(putConfigHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
List<Map<String, Object>> toInfer = new ArrayList<>();
toInfer.add(new HashMap<String, Object>() {{
put("field", new HashMap<String, Object>(){{
put("foo", 1.0);
put("bar", 0.5);
}});
put("other", new HashMap<String, Object>(){{
put("categorical", "dog");
}});
}});
toInfer.add(new HashMap<String, Object>() {{
put("field", new HashMap<String, Object>(){{
put("foo", 0.9);
put("bar", 1.5);
}});
put("other", new HashMap<String, Object>(){{
put("categorical", "cat");
}});
}});
List<Map<String, Object>> toInfer2 = new ArrayList<>();
toInfer2.add(new HashMap<String, Object>() {{
put("field", new HashMap<String, Object>(){{
put("foo", 0.0);
put("bar", 0.01);
}});
put("other", new HashMap<String, Object>(){{
put("categorical", "dog");
}});
}});
toInfer2.add(new HashMap<String, Object>() {{
put("field", new HashMap<String, Object>(){{
put("foo", 1.0);
put("bar", 0.0);
}});
put("other", new HashMap<String, Object>(){{
put("categorical", "cat");
}});
}});
// Test classification
InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId,
toInfer,
ClassificationConfig.EMPTY_PARAMS,
true);
InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults()
.stream()
.map(i -> ((SingleValueInferenceResults)i).valueAsString())
.collect(Collectors.toList()),
contains("option_0", "option_2"));
request = new InternalInferModelAction.Request(modelId, toInfer2, ClassificationConfig.EMPTY_PARAMS, true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults()
.stream()
.map(i -> ((SingleValueInferenceResults)i).valueAsString())
.collect(Collectors.toList()),
contains("option_2", "option_0"));
// Get top classes
request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfig(3, null, null), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
ClassificationInferenceResults classificationInferenceResults =
(ClassificationInferenceResults)response.getInferenceResults().get(0);
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("option_0"));
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("option_2"));
assertThat(classificationInferenceResults.getTopClasses().get(2).getClassification(), equalTo("option_1"));
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("option_2"));
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("option_0"));
assertThat(classificationInferenceResults.getTopClasses().get(2).getClassification(), equalTo("option_1"));
}
public void testInferMissingModel() {
String model = "test-infer-missing-model";
InternalInferModelAction.Request request = new InternalInferModelAction.Request(
@ -256,6 +365,54 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
.setModelId(modelId);
}
public static TrainedModel buildMultiClassClassification() {
List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(1.0, 0.0, 2.0)))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(Arrays.asList(0.0, 1.0, 0.0)))
.addNode(TreeNode.builder(4).setLeafValue(Arrays.asList(0.0, 0.0, 1.0))).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(3)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(2.0, 0.0, 0.0)))
.addNode(TreeNode.builder(2).setLeafValue(Arrays.asList(0.0, 2.0, 0.0)))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(0.0, 0.0, 1.0)))
.addNode(TreeNode.builder(2).setLeafValue(Arrays.asList(0.0, 1.0, 0.0)))
.build();
return Ensemble.builder()
.setClassificationLabels(Arrays.asList("option_0", "option_1", "option_2"))
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 3))
.build();
}
@Override
public NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();