* [ML][Inference] Add support for multi-value leaves to the tree model (#52531)

This adds support for multi-value leaves. This is a prerequisite for multi-class boosted tree classification.
parent 710a9ead69
commit 19a6c5d980
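To illustrate the feature at a glance, here is a minimal, hypothetical sketch (not part of this commit) of a classification stump whose single leaf carries one value per class; multi-value leaves are turned into class probabilities with the array-based softmax reworked below. The builder calls and the Statistics.softMax signature are taken from the diff; the leaf values themselves are invented.

import java.util.Arrays;
import java.util.Collections;

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;

public class MultiValueLeafExample {
    public static void main(String[] args) {
        // A stump whose single leaf holds one value per class (three classes here, values made up).
        Tree tree = Tree.builder()
            .setTargetType(TargetType.CLASSIFICATION)
            .setFeatureNames(Collections.emptyList())
            .setRoot(TreeNode.builder(0).setLeafValue(Arrays.asList(1.0, 2.0, 0.5)))
            .build();

        // Once any leaf is multi-valued, the model requires 7.7.0 on the wire (see getMinimalCompatibilityVersion below).
        System.out.println(tree.getMinimalCompatibilityVersion());

        // Multi-value leaves are mapped to class probabilities via softmax, mirroring the
        // classificationProbability change further down in this diff.
        double[] probabilities = Statistics.softMax(new double[] {1.0, 2.0, 0.5});
        System.out.println(Arrays.toString(probabilities));
    }
}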
@@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

@@ -225,7 +226,7 @@ public class Tree implements TrainedModel {
for (int i = nodes.size(); i < nodeIndex + 1; i++) {
nodes.add(null);
}
nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value));
nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(Collections.singletonList(value)));
return this;
}

@@ -27,6 +27,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public class TreeNode implements ToXContentObject {

@@ -61,7 +62,7 @@ public class TreeNode implements ToXContentObject {
PARSER.declareInt(Builder::setSplitFeature, SPLIT_FEATURE);
PARSER.declareInt(Builder::setNodeIndex, NODE_INDEX);
PARSER.declareDouble(Builder::setSplitGain, SPLIT_GAIN);
PARSER.declareDouble(Builder::setLeafValue, LEAF_VALUE);
PARSER.declareDoubleArray(Builder::setLeafValue, LEAF_VALUE);
PARSER.declareLong(Builder::setNumberSamples, NUMBER_SAMPLES);
}

@@ -74,7 +75,7 @@ public class TreeNode implements ToXContentObject {
private final Integer splitFeature;
private final int nodeIndex;
private final Double splitGain;
private final Double leafValue;
private final List<Double> leafValue;
private final Boolean defaultLeft;
private final Integer leftChild;
private final Integer rightChild;

@@ -86,7 +87,7 @@ public class TreeNode implements ToXContentObject {
Integer splitFeature,
int nodeIndex,
Double splitGain,
Double leafValue,
List<Double> leafValue,
Boolean defaultLeft,
Integer leftChild,
Integer rightChild,

@@ -123,7 +124,7 @@ public class TreeNode implements ToXContentObject {
return splitGain;
}

public Double getLeafValue() {
public List<Double> getLeafValue() {
return leafValue;
}

@@ -212,7 +213,7 @@ public class TreeNode implements ToXContentObject {
private Integer splitFeature;
private int nodeIndex;
private Double splitGain;
private Double leafValue;
private List<Double> leafValue;
private Boolean defaultLeft;
private Integer leftChild;
private Integer rightChild;

@@ -250,7 +251,7 @@ public class TreeNode implements ToXContentObject {
return this;
}

public Builder setLeafValue(Double leafValue) {
public Builder setLeafValue(List<Double> leafValue) {
this.leafValue = leafValue;
return this;
}

@@ -23,6 +23,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.util.Collections;

public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {

@@ -48,7 +49,7 @@ public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
public static TreeNode createRandomLeafNode(double internalValue) {
return TreeNode.builder(randomInt(100))
.setDefaultLeft(randomBoolean() ? null : randomBoolean())
.setLeafValue(internalValue)
.setLeafValue(Collections.singletonList(internalValue))
.setNumberSamples(randomNonNegativeLong())
.build();
}

@@ -60,7 +61,7 @@ public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
Integer featureIndex,
Operator operator) {
return TreeNode.builder(nodeIndex)
.setLeafValue(left == null ? randomDouble() : null)
.setLeafValue(left == null ? Collections.singletonList(randomDouble()) : null)
.setDefaultLeft(randomBoolean() ? null : randomBoolean())
.setLeftChild(left)
.setRightChild(right)

@@ -5,29 +5,37 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.ingest.IngestDocument;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;

public class RawInferenceResults extends SingleValueInferenceResults {
public class RawInferenceResults implements InferenceResults {

public static final String NAME = "raw";

public RawInferenceResults(double value, Map<String, Double> featureImportance) {
super(value, featureImportance);
private final double[] value;
private final Map<String, Double> featureImportance;

public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
this.value = value;
this.featureImportance = featureImportance;
}

public RawInferenceResults(StreamInput in) throws IOException {
super(in);
public double[] getValue() {
return value;
}

public Map<String, Double> getFeatureImportance() {
return featureImportance;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
throw new UnsupportedOperationException("[raw] does not support wire serialization");
}

@Override

@@ -35,13 +43,13 @@ public class RawInferenceResults extends SingleValueInferenceResults {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value())
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
return Arrays.equals(value, that.value)
&& Objects.equals(featureImportance, that.featureImportance);
}

@Override
public int hashCode() {
return Objects.hash(value(), getFeatureImportance());
return Objects.hash(Arrays.hashCode(value), featureImportance);
}

@Override

@@ -26,30 +26,29 @@ public final class InferenceHelpers {
/**
* @return Tuple of the highest scored index and the top classes
*/
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(double[] probabilities,
List<String> classificationLabels,
@Nullable double[] classificationWeights,
int numToInclude) {

if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
throw ExceptionsHelper
.serverError(
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
null,
probabilities.size(),
probabilities.length,
classificationLabels.size());
}

List<Double> scores = classificationWeights == null ?
double[] scores = classificationWeights == null ?
probabilities :
IntStream.range(0, probabilities.size())
.mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
.boxed()
.collect(Collectors.toList());
IntStream.range(0, probabilities.length)
.mapToDouble(i -> probabilities[i] * classificationWeights[i])
.toArray();

int[] sortedIndices = IntStream.range(0, probabilities.size())
int[] sortedIndices = IntStream.range(0, scores.length)
.boxed()
.sorted(Comparator.comparing(scores::get).reversed())
.sorted(Comparator.comparing(i -> scores[(Integer)i]).reversed())
.mapToInt(i -> i)
.toArray();

@@ -59,14 +58,14 @@ public final class InferenceHelpers {

List<String> labels = classificationLabels == null ?
// If we don't have the labels we should return the top classification values anyways, they will just be numeric
IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) :
IntStream.range(0, probabilities.length).boxed().map(String::valueOf).collect(Collectors.toList()) :
classificationLabels;

int count = numToInclude < 0 ? probabilities.size() : Math.min(numToInclude, probabilities.size());
int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
for(int i = 0; i < count; i++) {
int idx = sortedIndices[i];
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities[idx], scores[idx]));
}

return Tuple.tuple(sortedIndices[0], topClassEntries);

@@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.apache.lucene.util.Accountable;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;

@@ -62,4 +63,8 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
* @return A {@code Map<String, Double>} mapping each featureName to its importance
*/
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);

default Version getMinimalCompatibilityVersion() {
return Version.V_7_6_0;
}
}

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;

@@ -20,7 +21,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;

@@ -139,19 +139,20 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
List<Double> inferenceResults = new ArrayList<>(this.models.size());
double[][] inferenceResults = new double[this.models.size()][];
List<Map<String, Double>> featureInfluence = new ArrayList<>();
int i = 0;
NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
this.models.forEach(model -> {
for (TrainedModel model : models) {
InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
assert result instanceof SingleValueInferenceResults;
SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result;
inferenceResults.add(inferenceResult.value());
assert result instanceof RawInferenceResults;
RawInferenceResults inferenceResult = (RawInferenceResults) result;
inferenceResults[i++] = inferenceResult.getValue();
if (config.requestingImportance()) {
featureInfluence.add(inferenceResult.getFeatureImportance());
}
});
List<Double> processed = outputAggregator.processValues(inferenceResults);
}
double[] processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed, featureInfluence, config, featureDecoderMap);
}

@@ -160,13 +161,13 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return targetType;
}

private InferenceResults buildResults(List<Double> processedInferences,
private InferenceResults buildResults(double[] processedInferences,
List<Map<String, Double>> featureInfluence,
InferenceConfig config,
Map<String, String> featureDecoderMap) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(outputAggregator.aggregate(processedInferences),
return new RawInferenceResults(new double[] {outputAggregator.aggregate(processedInferences)},
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
}
switch(targetType) {

@@ -176,7 +177,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
// Adjust the probabilities according to the thresholds
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
processedInferences,

@@ -356,6 +357,11 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return Collections.unmodifiableCollection(accountables);
}

@Override
public Version getMinimalCompatibilityVersion() {
return models.stream().map(TrainedModel::getMinimalCompatibilityVersion).max(Version::compareTo).orElse(Version.V_7_6_0);
}

public static class Builder {
private List<String> featureNames;
private List<TrainedModel> trainedModels;

@@ -19,9 +19,9 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;

import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.sigmoid;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {

@@ -78,31 +78,39 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
}

@Override
public List<Double> processValues(List<Double> values) {
public double[] processValues(double[][] values) {
Objects.requireNonNull(values, "values must not be null");
if (weights != null && values.size() != weights.length) {
if (weights != null && values.length != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
double summation = weights == null ?
values.stream().mapToDouble(Double::valueOf).sum() :
IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).sum();
double probOfClassOne = sigmoid(summation);
double[] sumOnAxis1 = new double[values[0].length];
for (int j = 0; j < values.length; j++) {
double[] value = values[j];
double weight = weights == null ? 1.0 : weights[j];
for(int i = 0; i < value.length; i++) {
if (i >= sumOnAxis1.length) {
throw new IllegalArgumentException("value entries must have the same dimensions");
}
sumOnAxis1[i] += (value[i] * weight);
}
}
if (sumOnAxis1.length > 1) {
return softMax(sumOnAxis1);
}

double probOfClassOne = sigmoid(sumOnAxis1[0]);
assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0;
return Arrays.asList(1.0 - probOfClassOne, probOfClassOne);
return new double[] {1.0 - probOfClassOne, probOfClassOne};
}

@Override
public double aggregate(List<Double> values) {
public double aggregate(double[] values) {
Objects.requireNonNull(values, "values must not be null");
assert values.size() == 2;
int bestValue = 0;
double bestProb = Double.NEGATIVE_INFINITY;
for (int i = 0; i < values.size(); i++) {
if (values.get(i) == null) {
throw new IllegalArgumentException("values must not contain null values");
}
if (values.get(i) > bestProb) {
bestProb = values.get(i);
for (int i = 0; i < values.length; i++) {
if (values[i] > bestProb) {
bestProb = values[i];
bestValue = i;
}
}

@@ -10,8 +10,6 @@ import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

import java.util.List;

public interface OutputAggregator extends NamedXContentObject, NamedWriteable, Accountable {

/**

@@ -20,15 +18,15 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable, A
Integer expectedValueSize();

/**
* This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(List)} method.
* This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(double[])} method.
*
* Two major types of pre-processed values could be returned:
* - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(List)}
* - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(List)}
* - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(double[][])}
* - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(double[][])}
* @param values the values to process
* @return A new list containing the processed values or the same list if no processing is required
*/
List<Double> processValues(List<Double> values);
double[] processValues(double[][] values);

/**
* Function to aggregate the processed values into a single double

@@ -40,7 +38,7 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable, A
* @param processedValues The values to aggregate
* @return the aggregated value.
*/
double aggregate(List<Double> processedValues);
double aggregate(double[] processedValues);

/**
* @return The name of the output aggregator

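With this interface change each sub-model contributes a full double[] of leaf values rather than a single scalar; the aggregator collapses the resulting double[][] into per-class scores and then into one value. A minimal, hypothetical usage sketch (the values are invented; the WeightedMode constructor and method signatures follow the hunks and tests further below):

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;

public class AggregatorContractExample {
    public static void main(String[] args) {
        // Two sub-models, each emitting three per-class leaf values.
        double[][] perModelValues = new double[][] {
            new double[] {1.0, 0.0, 2.0},
            new double[] {0.0, 3.0, 2.0}
        };
        // Equal weights for the two sub-models, three classes in total.
        WeightedMode weightedMode = new WeightedMode(new double[] {1.0, 1.0}, 3);
        double[] classProbabilities = weightedMode.processValues(perModelValues); // softmax over the weighted per-class sums
        double predictedClass = weightedMode.aggregate(classProbabilities);       // index of the most probable class, as a double
        System.out.println(predictedClass);
    }
}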
@@ -89,21 +89,37 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
}

@Override
public List<Double> processValues(List<Double> values) {
public double[] processValues(double[][] values) {
Objects.requireNonNull(values, "values must not be null");
if (weights != null && values.size() != weights.length) {
if (weights != null && values.length != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
List<Integer> freqArray = new ArrayList<>();
Integer maxVal = 0;
for (Double value : values) {
if (value == null) {
throw new IllegalArgumentException("values must not contain null values");
// Multiple leaf values
if (values[0].length > 1) {
double[] sumOnAxis1 = new double[values[0].length];
for (int j = 0; j < values.length; j++) {
double[] value = values[j];
double weight = weights == null ? 1.0 : weights[j];
for(int i = 0; i < value.length; i++) {
if (i >= sumOnAxis1.length) {
throw new IllegalArgumentException("value entries must have the same dimensions");
}
sumOnAxis1[i] += (value[i] * weight);
}
}
if (Double.isNaN(value) || Double.isInfinite(value) || value < 0.0 || value != Math.rint(value)) {
return softMax(sumOnAxis1);
}
// Singular leaf values
List<Integer> freqArray = new ArrayList<>();
int maxVal = 0;
for (double[] value : values) {
if (value.length != 1) {
throw new IllegalArgumentException("value entries must have the same dimensions");
}
if (Double.isNaN(value[0]) || Double.isInfinite(value[0]) || value[0] < 0.0 || value[0] != Math.rint(value[0])) {
throw new IllegalArgumentException("values must be whole, non-infinite, and positive");
}
Integer integerValue = value.intValue();
int integerValue = Double.valueOf(value[0]).intValue();
freqArray.add(integerValue);
if (integerValue > maxVal) {
maxVal = integerValue;

@@ -112,27 +128,27 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
if (maxVal >= numClasses) {
throw new IllegalArgumentException("values contain entries larger than expected max of [" + (numClasses - 1) + "]");
}
List<Double> frequencies = new ArrayList<>(Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY));
double[] frequencies = Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY)
.stream()
.mapToDouble(Double::doubleValue)
.toArray();
for (int i = 0; i < freqArray.size(); i++) {
Double weight = weights == null ? 1.0 : weights[i];
Integer value = freqArray.get(i);
Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
frequencies.set(value, frequency);
double weight = weights == null ? 1.0 : weights[i];
int value = freqArray.get(i);
double frequency = frequencies[value] == Double.NEGATIVE_INFINITY ? weight : frequencies[value] + weight;
frequencies[value] = frequency;
}
return softMax(frequencies);
}

@Override
public double aggregate(List<Double> values) {
public double aggregate(double[] values) {
Objects.requireNonNull(values, "values must not be null");
int bestValue = 0;
double bestFreq = Double.NEGATIVE_INFINITY;
for (int i = 0; i < values.size(); i++) {
if (values.get(i) == null) {
throw new IllegalArgumentException("values must not contain null values");
}
if (values.get(i) > bestFreq) {
bestFreq = values.get(i);
for (int i = 0; i < values.length; i++) {
if (values[i] > bestFreq) {
bestFreq = values[i];
bestValue = i;
}
}

@@ -19,8 +19,6 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {

@@ -73,28 +71,25 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
}

@Override
public List<Double> processValues(List<Double> values) {
public double[] processValues(double[][] values) {
Objects.requireNonNull(values, "values must not be null");
assert values[0].length == 1;
if (weights == null) {
return values;
return Arrays.stream(values).mapToDouble(v -> v[0]).toArray();
}
if (values.size() != weights.length) {
if (values.length != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
return IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).boxed().collect(Collectors.toList());
return IntStream.range(0, weights.length).mapToDouble(i -> values[i][0] * weights[i]).toArray();
}

@Override
public double aggregate(List<Double> values) {
public double aggregate(double[] values) {
Objects.requireNonNull(values, "values must not be null");
if (values.isEmpty()) {
if (values.length == 0) {
throw new IllegalArgumentException("values must not be empty");
}
Optional<Double> summation = values.stream().reduce(Double::sum);
if (summation.isPresent()) {
return summation.get();
}
throw new IllegalArgumentException("values must not contain null values");
return Arrays.stream(values).sum();
}

@Override

@@ -30,7 +30,6 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

@@ -130,7 +129,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector);
double[] scores = softmaxLayer.productPlusBias(true, h0);

List<Double> probabilities = softMax(Arrays.stream(scores).boxed().collect(Collectors.toList()));
double[] probabilities = softMax(scores);

ClassificationConfig classificationConfig = (ClassificationConfig) config;
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;

@@ -29,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfi
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;

@@ -100,7 +102,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
this.nodes = Collections.unmodifiableList(nodes);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
}

public Tree(StreamInput in) throws IOException {

@@ -112,7 +114,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
} else {
this.classificationLabels = null;
}
this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
}

@Override

@@ -147,7 +149,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return buildResult(node.getLeafValue(), featureImportance, config);
}

private InferenceResults buildResult(Double value, Map<String, Double> featureImportance, InferenceConfig config) {
private InferenceResults buildResult(double[] value, Map<String, Double> featureImportance, InferenceConfig config) {
assert value != null && value.length > 0;
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(value, featureImportance);

@@ -160,13 +163,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
classificationLabels,
null,
classificationConfig.getNumTopClasses());
return new ClassificationInferenceResults(value,
return new ClassificationInferenceResults(topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
featureImportance,
config);
case REGRESSION:
return new RegressionInferenceResults(value, config, featureImportance);
return new RegressionInferenceResults(value[0], config, featureImportance);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}

@@ -193,14 +196,22 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return targetType;
}

private List<Double> classificationProbability(double inferenceValue) {
private double[] classificationProbability(double[] inferenceValue) {
// Multi-value leaves, indicates that the leaves contain an array of values.
// The index of which corresponds to classification values
if (inferenceValue.length > 1) {
return Statistics.softMax(inferenceValue);
}
// If we are classification, we should assume that the inference return value is whole.
assert inferenceValue == Math.rint(inferenceValue);
assert inferenceValue[0] == Math.rint(inferenceValue[0]);
double maxCategory = this.highestOrderCategory.get();
// If we are classification, we should assume that the largest leaf value is whole.
assert maxCategory == Math.rint(maxCategory);
List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
.stream()
.mapToDouble(Double::doubleValue)
.toArray();
list[Double.valueOf(inferenceValue[0]).intValue()] = 1.0;
return list;
}

@@ -268,6 +279,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
checkTargetType();
detectMissingNodes();
detectCycle();
verifyLeafNodeUniformity();
}

@Override

@@ -331,7 +343,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
TreeNode currNode = nodes.get(nodeIndex);
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
if (currNode.isLeaf()) {
// TODO multi-value????
double leafValue = nodeValues[nodeIndex];
for (int i = 1; i < nextIndex; ++i) {
double scale = splitPath.sumUnwoundPath(i, nextIndex);

@@ -375,7 +386,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
TreeNode node = nodes.get(nodeIndex);
if (node.isLeaf()) {
nodeEstimates[nodeIndex] = node.getLeafValue();
// TODO multi-value????
nodeEstimates[nodeIndex] = node.getLeafValue()[0];
return 0;
}

@@ -424,6 +436,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
throw ExceptionsHelper.badRequestException(
"[target_type] should be [classification] if [classification_labels] are provided");
}
if (this.targetType != TargetType.CLASSIFICATION && this.nodes.stream().anyMatch(n -> n.getLeafValue().length > 1)) {
throw ExceptionsHelper.badRequestException(
"[target_type] should be [classification] if leaf nodes have multiple values");
}
}

private void detectCycle() {

@@ -465,14 +481,39 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}
}

private void verifyLeafNodeUniformity() {
Integer leafValueLengths = null;
for (TreeNode node : nodes) {
if (node.isLeaf()) {
if (leafValueLengths == null) {
leafValueLengths = node.getLeafValue().length;
} else if (leafValueLengths != node.getLeafValue().length) {
throw ExceptionsHelper.badRequestException(
"[tree.tree_structure] all leaf nodes must have the same number of values");
}
}
}
}

private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
return nodeIdx >= nodes.size();
}

private Double maxLeafValue() {
return targetType == TargetType.CLASSIFICATION ?
this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() :
null;
if (targetType != TargetType.CLASSIFICATION) {
return null;
}
double max = 0.0;
for (TreeNode node : this.nodes) {
if (node.isLeaf()) {
if (node.getLeafValue().length > 1) {
return (double)node.getLeafValue().length;
} else {
max = Math.max(node.getLeafValue()[0], max);
}
}
}
return max;
}

@Override

@@ -493,6 +534,14 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return Collections.unmodifiableCollection(accountables);
}

@Override
public Version getMinimalCompatibilityVersion() {
if (nodes.stream().filter(TreeNode::isLeaf).anyMatch(t -> t.getLeafValue().length > 1)) {
return Version.V_7_7_0;
}
return Version.V_7_6_0;
}

public static class Builder {
private List<String> featureNames;
private ArrayList<TreeNode.Builder> nodes;

@@ -586,6 +635,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
* @return this
*/
Tree.Builder addLeaf(int nodeIndex, double value) {
return addLeaf(nodeIndex, Arrays.asList(value));
}

Tree.Builder addLeaf(int nodeIndex, List<Double> value) {
for (int i = nodes.size(); i < nodeIndex + 1; i++) {
nodes.add(null);
}

@@ -21,6 +21,8 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.job.config.Operator;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

@@ -60,7 +62,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
parser.declareInt(TreeNode.Builder::setSplitFeature, SPLIT_FEATURE);
parser.declareInt(TreeNode.Builder::setNodeIndex, NODE_INDEX);
parser.declareDouble(TreeNode.Builder::setSplitGain, SPLIT_GAIN);
parser.declareDouble(TreeNode.Builder::setLeafValue, LEAF_VALUE);
parser.declareDoubleArray(TreeNode.Builder::setLeafValue, LEAF_VALUE);
parser.declareLong(TreeNode.Builder::setNumberSamples, NUMBER_SAMPLES);
return parser;
}

@@ -74,7 +76,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
private final int splitFeature;
private final int nodeIndex;
private final double splitGain;
private final double leafValue;
private final double[] leafValue;
private final boolean defaultLeft;
private final int leftChild;
private final int rightChild;

@@ -86,7 +88,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
Integer splitFeature,
int nodeIndex,
Double splitGain,
Double leafValue,
List<Double> leafValue,
Boolean defaultLeft,
Integer leftChild,
Integer rightChild,

@@ -96,7 +98,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
this.splitFeature = splitFeature == null ? -1 : splitFeature;
this.nodeIndex = nodeIndex;
this.splitGain = splitGain == null ? Double.NaN : splitGain;
this.leafValue = leafValue == null ? Double.NaN : leafValue;
this.leafValue = leafValue == null ? new double[0] : leafValue.stream().mapToDouble(Double::doubleValue).toArray();
this.defaultLeft = defaultLeft == null ? false : defaultLeft;
this.leftChild = leftChild == null ? -1 : leftChild;
this.rightChild = rightChild == null ? -1 : rightChild;

@@ -112,7 +114,11 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
splitFeature = in.readInt();
splitGain = in.readDouble();
nodeIndex = in.readVInt();
leafValue = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
leafValue = in.readDoubleArray();
} else {
leafValue = new double[]{in.readDouble()};
}
defaultLeft = in.readBoolean();
leftChild = in.readInt();
rightChild = in.readInt();

@@ -144,7 +150,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
return splitGain;
}

public double getLeafValue() {
public double[] getLeafValue() {
return leafValue;
}

@@ -190,7 +196,18 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
out.writeInt(splitFeature);
out.writeDouble(splitGain);
out.writeVInt(nodeIndex);
out.writeDouble(leafValue);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeDoubleArray(leafValue);
} else {
if (leafValue.length > 1) {
throw new IOException("Multi-class classification models require that all nodes are at least version 7.7.0.");
}
if (leafValue.length == 0) {
out.writeDouble(Double.NaN);
} else {
out.writeDouble(leafValue[0]);
}
}
out.writeBoolean(defaultLeft);
out.writeInt(leftChild);
out.writeInt(rightChild);

@@ -209,7 +226,9 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
}
addOptionalDouble(builder, SPLIT_GAIN, splitGain);
builder.field(NODE_INDEX.getPreferredName(), nodeIndex);
addOptionalDouble(builder, LEAF_VALUE, leafValue);
if (leafValue.length > 0) {
builder.field(LEAF_VALUE.getPreferredName(), leafValue);
}
builder.field(DEFAULT_LEFT.getPreferredName(), defaultLeft);
if (leftChild >= 0) {
builder.field(LEFT_CHILD.getPreferredName(), leftChild);

@@ -238,7 +257,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
&& Objects.equals(splitFeature, that.splitFeature)
&& Objects.equals(nodeIndex, that.nodeIndex)
&& Objects.equals(splitGain, that.splitGain)
&& Objects.equals(leafValue, that.leafValue)
&& Arrays.equals(leafValue, that.leafValue)
&& Objects.equals(defaultLeft, that.defaultLeft)
&& Objects.equals(leftChild, that.leftChild)
&& Objects.equals(rightChild, that.rightChild)

@@ -252,7 +271,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
splitFeature,
splitGain,
nodeIndex,
leafValue,
Arrays.hashCode(leafValue),
defaultLeft,
leftChild,
rightChild,

@@ -270,7 +289,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {

@Override
public long ramBytesUsed() {
return SHALLOW_SIZE;
return SHALLOW_SIZE + this.leafValue.length * Double.BYTES;
}

public static class Builder {

@@ -279,7 +298,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
private Integer splitFeature;
private int nodeIndex;
private Double splitGain;
private Double leafValue;
private List<Double> leafValue;
private Boolean defaultLeft;
private Integer leftChild;
private Integer rightChild;

@@ -317,11 +336,19 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
return this;
}

public Builder setLeafValue(Double leafValue) {
public Builder setLeafValue(double leafValue) {
return this.setLeafValue(Collections.singletonList(leafValue));
}

public Builder setLeafValue(List<Double> leafValue) {
this.leafValue = leafValue;
return this;
}

List<Double> getLeafValue() {
return this.leafValue;
}

public Builder setDefaultLeft(Boolean defaultLeft) {
this.defaultLeft = defaultLeft;
return this;

@@ -358,6 +385,9 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
if (leafValue == null) {
throw new IllegalArgumentException("[leaf_value] is required for a leaf node.");
}
if (leafValue.stream().anyMatch(Objects::isNull)) {
throw new IllegalArgumentException("[leaf_value] cannot have null values.");
}
} else {
if (leftChild < 0) {
throw new IllegalArgumentException("[left_child] must be a non-negative integer.");

@@ -7,8 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.utils;

import org.elasticsearch.common.Numbers;

import java.util.List;
import java.util.stream.Collectors;
import java.util.Arrays;

public final class Statistics {

@@ -20,28 +19,29 @@ public final class Statistics {
* Any {@link Double#isInfinite()}, {@link Double#NaN}, or `null` values are ignored in calculation and returned as 0.0 in the
* softMax.
* @param values Values on which to run SoftMax.
* @return A new list containing the softmax of the passed values
* @return A new array containing the softmax of the passed values
*/
public static List<Double> softMax(List<Double> values) {
Double expSum = 0.0;
Double max = values.stream().filter(Statistics::isValid).max(Double::compareTo).orElse(null);
if (max == null) {
public static double[] softMax(double[] values) {
double expSum = 0.0;
double max = Arrays.stream(values).filter(Statistics::isValid).max().orElse(Double.NaN);
if (isValid(max) == false) {
throw new IllegalArgumentException("no valid values present");
}
List<Double> exps = values.stream().map(v -> isValid(v) ? v - max : Double.NEGATIVE_INFINITY)
.collect(Collectors.toList());
for (int i = 0; i < exps.size(); i++) {
if (isValid(exps.get(i))) {
Double exp = Math.exp(exps.get(i));
double[] exps = new double[values.length];
for (int i = 0; i < exps.length; i++) {
if (isValid(values[i])) {
double exp = Math.exp(values[i] - max);
expSum += exp;
exps.set(i, exp);
exps[i] = exp;
} else {
exps[i] = Double.NaN;
}
}
for (int i = 0; i < exps.size(); i++) {
if (isValid(exps.get(i))) {
exps.set(i, exps.get(i)/expSum);
for (int i = 0; i < exps.length; i++) {
if (isValid(exps[i])) {
exps[i] /= expSum;
} else {
exps.set(i, 0.0);
exps[i] = 0.0;
}
}
return exps;

@@ -51,8 +51,8 @@ public final class Statistics {
return 1/(1 + Math.exp(-value));
}

private static boolean isValid(Double v) {
return v != null && Numbers.isValidDouble(v);
private static boolean isValid(double v) {
return Numbers.isValidDouble(v);
}

}

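For reference, a small, hypothetical usage of the reworked array-based softMax above (behaviour per the javadoc: NaN and infinite entries are ignored and come back as 0.0; the concrete numbers are only illustrative):

import java.util.Arrays;

import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;

public class SoftMaxExample {
    public static void main(String[] args) {
        // Invalid entries (NaN, +/-Infinity) are skipped in the computation and returned as 0.0.
        double[] values = new double[] {Double.NEGATIVE_INFINITY, 1.0, Double.NaN, 5.0};
        double[] softMax = Statistics.softMax(values);
        System.out.println(Arrays.toString(softMax)); // approximately [0.0, 0.018, 0.0, 0.982]
    }
}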
@@ -5,24 +5,37 @@
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.test.ESTestCase;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {
import static org.hamcrest.CoreMatchers.equalTo;

public class RawInferenceResultsTests extends ESTestCase {

public static RawInferenceResults createRandomResults() {
return new RawInferenceResults(randomDouble(), randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
int n = randomIntBetween(1, 10);
double[] results = new double[n];
for (int i = 0; i < n; i++) {
results[i] = randomDouble();
}
return new RawInferenceResults(results, randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
}

@Override
protected RawInferenceResults createTestInstance() {
return createRandomResults();
public void testEqualityAndHashcode() {
int n = randomIntBetween(1, 10);
double[] results = new double[n];
for (int i = 0; i < n; i++) {
results[i] = randomDouble();
}
Map<String, Double> importance = randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08);
RawInferenceResults lft = new RawInferenceResults(results, new HashMap<>(importance));
RawInferenceResults rgt = new RawInferenceResults(Arrays.copyOf(results, n), new HashMap<>(importance));
assertThat(lft, equalTo(rgt));
assertThat(lft.hashCode(), equalTo(rgt.hashCode()));
}

@Override
protected Writeable.Reader<RawInferenceResults> instanceReader() {
return RawInferenceResults::new;
}
}

@@ -11,11 +11,10 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;

public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticRegression> {

@@ -43,7 +42,13 @@ public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticReg

public void testAggregate() {
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
double[][] values = new double[][]{
new double[] {1.0},
new double[] {2.0},
new double[] {2.0},
new double[] {3.0},
new double[] {5.0}
};

LogisticRegression logisticRegression = new LogisticRegression(ones);
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));

@@ -57,6 +62,36 @@ public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticReg
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));
}

public void testAggregateMultiValueArrays() {
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
double[][] values = new double[][]{
new double[] {1.0, 0.0, 1.0},
new double[] {2.0, 0.0, 0.0},
new double[] {2.0, 3.0, 1.0},
new double[] {3.0, 3.0, 1.0},
new double[] {1.0, 1.0, 5.0}
};

LogisticRegression logisticRegression = new LogisticRegression(ones);
double[] processedValues = logisticRegression.processValues(values);
assertThat(processedValues.length, equalTo(3));
assertThat(processedValues[0], closeTo(0.665240955, 0.00001));
assertThat(processedValues[1], closeTo(0.090030573, 0.00001));
assertThat(processedValues[2], closeTo(0.244728471, 0.00001));
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(0.0));

double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};

logisticRegression = new LogisticRegression(variedWeights);
processedValues = logisticRegression.processValues(values);
assertThat(processedValues.length, equalTo(3));
assertThat(processedValues[0], closeTo(0.0, 0.00001));
assertThat(processedValues[1], closeTo(0.0, 0.00001));
assertThat(processedValues[2], closeTo(0.9999999, 0.00001));
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(2.0));

}

public void testCompatibleWith() {
LogisticRegression logisticRegression = createTestInstance();
assertThat(logisticRegression.compatibleWith(TargetType.CLASSIFICATION), is(true));

@@ -8,9 +8,6 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.junit.Before;

import java.util.ArrayList;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;

public abstract class WeightedAggregatorTests<T extends OutputAggregator> extends AbstractSerializingTestCase<T> {

@@ -35,9 +32,9 @@ public abstract class WeightedAggregatorTests<T extends OutputAggregator> extend

public void testWithValuesOfWrongLength() {
int numberOfValues = randomIntBetween(5, 10);
List<Double> values = new ArrayList<>(numberOfValues);
double[][] values = new double[numberOfValues][];
for (int i = 0; i < numberOfValues; i++) {
values.add(randomDouble());
values[i] = new double[] {randomDouble()};
}

OutputAggregator outputAggregatorWithTooFewWeights = createTestInstance(randomIntBetween(1, numberOfValues - 1));

@ -11,8 +11,6 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
|
@ -44,7 +42,13 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
|
|||
|
||||
public void testAggregate() {
|
||||
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
|
||||
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
|
||||
double[][] values = new double[][]{
|
||||
new double[] {1.0},
|
||||
new double[] {2.0},
|
||||
new double[] {2.0},
|
||||
new double[] {3.0},
|
||||
new double[] {5.0}
|
||||
};
|
||||
|
||||
WeightedMode weightedMode = new WeightedMode(ones, 6);
|
||||
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
|
||||
|
@ -57,19 +61,55 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
|
|||
weightedMode = new WeightedMode(6);
|
||||
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
|
||||
|
||||
values = Arrays.asList(1.0, 1.0, 1.0, 1.0, 2.0);
|
||||
values = new double[][]{
|
||||
new double[] {1.0},
|
||||
new double[] {1.0},
|
||||
new double[] {1.0},
|
||||
new double[] {1.0},
|
||||
new double[] {2.0}
|
||||
};
|
||||
weightedMode = new WeightedMode(6);
|
||||
List<Double> processedValues = weightedMode.processValues(values);
|
||||
assertThat(processedValues.size(), equalTo(6));
|
||||
assertThat(processedValues.get(0), equalTo(0.0));
|
||||
assertThat(processedValues.get(1), closeTo(0.95257412, 0.00001));
|
||||
assertThat(processedValues.get(2), closeTo((1.0 - 0.95257412), 0.00001));
|
||||
assertThat(processedValues.get(3), equalTo(0.0));
|
||||
assertThat(processedValues.get(4), equalTo(0.0));
|
||||
assertThat(processedValues.get(5), equalTo(0.0));
|
||||
double[] processedValues = weightedMode.processValues(values);
|
||||
assertThat(processedValues.length, equalTo(6));
|
||||
assertThat(processedValues[0], equalTo(0.0));
|
||||
assertThat(processedValues[1], closeTo(0.95257412, 0.00001));
|
||||
assertThat(processedValues[2], closeTo((1.0 - 0.95257412), 0.00001));
|
||||
assertThat(processedValues[3], equalTo(0.0));
|
||||
assertThat(processedValues[4], equalTo(0.0));
|
||||
assertThat(processedValues[5], equalTo(0.0));
|
||||
assertThat(weightedMode.aggregate(processedValues), equalTo(1.0));
|
||||
}
|
||||
|
||||
public void testAggregateMultiValueArrays() {
|
||||
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
|
||||
double[][] values = new double[][]{
|
||||
new double[] {1.0, 0.0, 1.0},
|
||||
new double[] {2.0, 0.0, 0.0},
|
||||
new double[] {2.0, 3.0, 1.0},
|
||||
new double[] {3.0, 3.0, 1.0},
|
||||
new double[] {1.0, 1.0, 5.0}
|
||||
};
|
||||
|
||||
WeightedMode weightedMode = new WeightedMode(ones, 3);
|
||||
double[] processedValues = weightedMode.processValues(values);
|
||||
assertThat(processedValues.length, equalTo(3));
|
||||
assertThat(processedValues[0], closeTo(0.665240955, 0.00001));
|
||||
assertThat(processedValues[1], closeTo(0.090030573, 0.00001));
|
||||
assertThat(processedValues[2], closeTo(0.244728471, 0.00001));
|
||||
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(0.0));
|
||||
|
||||
double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};
|
||||
|
||||
weightedMode = new WeightedMode(variedWeights, 3);
|
||||
processedValues = weightedMode.processValues(values);
|
||||
assertThat(processedValues.length, equalTo(3));
|
||||
assertThat(processedValues[0], closeTo(0.0, 0.00001));
|
||||
assertThat(processedValues[1], closeTo(0.0, 0.00001));
|
||||
assertThat(processedValues[2], closeTo(0.9999999, 0.00001));
|
||||
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
|
||||
|
||||
}
|
||||
|
    public void testCompatibleWith() {
        WeightedMode weightedMode = createTestInstance();
        assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true));

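The expected values in the WeightedModeTests hunk above are consistent with one simple reading: accumulate weighted per-class scores, then softmax the class totals, reporting 0.0 for classes that receive nothing. In the single-value case, class 1 collects four unit votes and class 2 one, giving e^4 / (e^4 + e^1) ≈ 0.95257412 and its complement; in testAggregateMultiValueArrays the unit-weight class totals are {9.0, 7.0, 8.0}, whose softmax is roughly {0.6652, 0.0900, 0.2447}. The standalone sketch below reproduces the multi-value numbers. It is a minimal illustration assuming that reading, not the production WeightedMode code, and its class and variable names are hypothetical.

    // Sketch of an aggregation consistent with the expected values in testAggregateMultiValueArrays.
    public class WeightedModeSketch {
        public static void main(String[] args) {
            double[] weights = {1.0, 1.0, 1.0, 1.0, 1.0};
            double[][] values = {
                {1.0, 0.0, 1.0},
                {2.0, 0.0, 0.0},
                {2.0, 3.0, 1.0},
                {3.0, 3.0, 1.0},
                {1.0, 1.0, 5.0}
            };

            // Weighted per-class totals: {9.0, 7.0, 8.0}
            double[] classTotals = new double[3];
            for (int model = 0; model < values.length; model++) {
                for (int c = 0; c < classTotals.length; c++) {
                    classTotals[c] += weights[model] * values[model][c];
                }
            }

            // Softmax over the totals: roughly {0.6652, 0.0900, 0.2447}; the argmax (class 0)
            // matches the expected aggregate() result of 0.0.
            double[] probabilities = new double[classTotals.length];
            double norm = 0.0;
            for (int c = 0; c < classTotals.length; c++) {
                probabilities[c] = Math.exp(classTotals[c]);
                norm += probabilities[c];
            }
            for (int c = 0; c < classTotals.length; c++) {
                probabilities[c] /= norm;
            }
            System.out.println(java.util.Arrays.toString(probabilities));
        }
    }
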
@@ -11,8 +11,6 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

import static org.hamcrest.CoreMatchers.is;
@@ -43,7 +41,13 @@ public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {

    public void testAggregate() {
        double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
        double[][] values = new double[][]{
            new double[] {1.0},
            new double[] {2.0},
            new double[] {2.0},
            new double[] {3.0},
            new double[] {5.0}
        };

        WeightedSum weightedSum = new WeightedSum(ones);
        assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));

@@ -55,7 +55,7 @@ public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
        return TreeNode.builder(randomInt(100))
            .setDefaultLeft(randomBoolean() ? null : randomBoolean())
            .setNumberSamples(randomNonNegativeLong())
            .setLeafValue(internalValue)
            .setLeafValue(Collections.singletonList(internalValue))
            .build();
    }

@@ -66,7 +66,7 @@ public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
            Integer featureIndex,
            Operator operator) {
        return TreeNode.builder(nodeId)
            .setLeafValue(left == null ? randomDouble() : null)
            .setLeafValue(left == null ? Collections.singletonList(randomDouble()) : null)
            .setDefaultLeft(randomBoolean() ? null : randomBoolean())
            .setLeftChild(left)
            .setRightChild(right)

@@ -112,7 +112,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {

    public void testInferWithStump() {
        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
        builder.setRoot(TreeNode.builder(0).setLeafValue(42.0));
        builder.setRoot(TreeNode.builder(0).setLeafValue(Collections.singletonList(42.0)));
        builder.setFeatureNames(Collections.emptyList());

        Tree tree = builder.build();

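In the regression stump above, the single leaf value simply becomes a one-element list. For contrast, a classification stump could carry one score per class through the same builder API. The snippet below is a hypothetical sketch, not a test from this change; it assumes TargetType.CLASSIFICATION is accepted for a standalone Tree and says nothing about whether classification labels are also required.

    // Hypothetical sketch: a single-node classification tree whose leaf holds one score per class.
    Tree.Builder stump = Tree.builder().setTargetType(TargetType.CLASSIFICATION);
    stump.setRoot(TreeNode.builder(0).setLeafValue(Arrays.asList(0.2, 0.5, 0.3)));
    stump.setFeatureNames(Collections.emptyList());
    Tree classificationStump = stump.build();
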
@@ -16,18 +16,18 @@ import static org.hamcrest.Matchers.closeTo;
public class StatisticsTests extends ESTestCase {

    public void testSoftMax() {
        List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, 1.0, -0.5, null, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0);
        List<Double> softMax = Statistics.softMax(values);
        double[] values = new double[] {Double.NEGATIVE_INFINITY, 1.0, -0.5, Double.NaN, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0};
        double[] softMax = Statistics.softMax(values);

        List<Double> expected = Arrays.asList(0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042);
        double[] expected = new double[] {0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042};

        for(int i = 0; i < expected.size(); i++) {
            assertThat(softMax.get(i), closeTo(expected.get(i), 0.000001));
        for(int i = 0; i < expected.length; i++) {
            assertThat(softMax[i], closeTo(expected[i], 0.000001));
        }
    }

    public void testSoftMaxWithNoValidValues() {
        List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, null, Double.NaN, Double.POSITIVE_INFINITY);
        double[] values = new double[] {Double.NEGATIVE_INFINITY, Double.NaN, Double.POSITIVE_INFINITY};
        expectThrows(IllegalArgumentException.class, () -> Statistics.softMax(values));
    }

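These assertions pin down the behaviour: non-finite entries (NaN and infinities) contribute nothing and come back as 0.0, the remaining entries are normalised among themselves, and an input with no finite entry is rejected. Below is a minimal standalone sketch that is consistent with those expectations; it assumes this is the intended contract and is not the actual Statistics.softMax source (the exception message and class name are hypothetical).

    public class SoftMaxSketch {

        // Softmax that skips non-finite entries and leaves them at probability 0.0.
        public static double[] softMax(double[] values) {
            double max = Double.NEGATIVE_INFINITY;
            for (double v : values) {
                if (Double.isFinite(v) && v > max) {
                    max = v;
                }
            }
            if (Double.isFinite(max) == false) {
                throw new IllegalArgumentException("no valid values present for softMax");
            }
            double[] result = new double[values.length];
            double sum = 0.0;
            for (int i = 0; i < values.length; i++) {
                if (Double.isFinite(values[i])) {
                    result[i] = Math.exp(values[i] - max); // subtract max for numerical stability
                    sum += result[i];
                }
            }
            for (int i = 0; i < values.length; i++) {
                result[i] /= sum; // non-finite entries stay at 0.0
            }
            return result;
        }

        public static void main(String[] args) {
            double[] values = {Double.NEGATIVE_INFINITY, 1.0, -0.5, Double.NaN,
                Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0};
            // Prints approximately [0.0, 0.0176, 0.0039, 0.0, 0.0, 0.0, 0.0176, 0.9609]
            System.out.println(java.util.Arrays.toString(softMax(values)));
        }
    }
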
@@ -211,14 +211,14 @@ public class TrainedModelIT extends ESRestTestCase {
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5),
            TreeNode.builder(1).setLeafValue(0.3),
            TreeNode.builder(1).setLeafValue(Collections.singletonList(0.3)),
            TreeNode.builder(2)
                .setThreshold(0.0)
                .setSplitFeature(3)
                .setLeftChild(3)
                .setRightChild(4),
            TreeNode.builder(3).setLeafValue(0.1),
            TreeNode.builder(4).setLeafValue(0.2))
            TreeNode.builder(3).setLeafValue(Collections.singletonList(0.1)),
            TreeNode.builder(4).setLeafValue(Collections.singletonList(0.2)))
            .build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
@@ -227,8 +227,8 @@ public class TrainedModelIT extends ESRestTestCase {
                .setRightChild(2)
                .setSplitFeature(2)
                .setThreshold(1.0),
            TreeNode.builder(1).setLeafValue(1.5),
            TreeNode.builder(2).setLeafValue(0.9))
            TreeNode.builder(1).setLeafValue(Collections.singletonList(1.5)),
            TreeNode.builder(2).setLeafValue(Collections.singletonList(0.9)))
            .build();
        Tree tree3 = Tree.builder()
            .setFeatureNames(featureNames)
@@ -237,8 +237,8 @@ public class TrainedModelIT extends ESRestTestCase {
                .setRightChild(2)
                .setSplitFeature(1)
                .setThreshold(0.2),
            TreeNode.builder(1).setLeafValue(1.5),
            TreeNode.builder(2).setLeafValue(0.9))
            TreeNode.builder(1).setLeafValue(Collections.singletonList(1.5)),
            TreeNode.builder(2).setLeafValue(Collections.singletonList(0.9)))
            .build();
        return Ensemble.builder()
            .setTargetType(TargetType.REGRESSION)

@@ -94,6 +94,18 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
            return;
        }

        Version minCompatibilityVersion = request.getTrainedModelConfig()
            .getModelDefinition()
            .getTrainedModel()
            .getMinimalCompatibilityVersion();
        if (state.nodes().getMinNodeVersion().before(minCompatibilityVersion)) {
            listener.onFailure(ExceptionsHelper.badRequestException(
                "Definition for [{}] requires that all nodes are at least version [{}]",
                request.getTrainedModelConfig().getModelId(),
                minCompatibilityVersion.toString()));
            return;
        }

        TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
            .setVersion(Version.CURRENT)
            .setCreateTime(Instant.now())

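This gate rejects a PUT whenever some node in the cluster is older than the model definition's minimal compatibility version. The motivation is presumably that definitions using newer features, such as the multi-value leaves added in this change, could not be deserialised on those older nodes, so failing the request up front is preferable to failing later when the model is loaded.
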
@@ -22,6 +22,12 @@ import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceRes
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
@@ -189,6 +195,109 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
    }

    public void testInferModelMultiClassModel() throws Exception {
        String modelId = "test-load-models-classification-multi";
        Map<String, String> oneHotEncoding = new HashMap<>();
        oneHotEncoding.put("cat", "animal_cat");
        oneHotEncoding.put("dog", "animal_dog");
        TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId)
            .setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
            .setParsedDefinition(new TrainedModelDefinition.Builder()
                .setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
                .setTrainedModel(buildMultiClassClassification()))
            .setVersion(Version.CURRENT)
            .setLicenseLevel(License.OperationMode.PLATINUM.description())
            .setCreateTime(Instant.now())
            .setEstimatedOperations(0)
            .setEstimatedHeapMemory(0)
            .build();
        AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

        blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
        assertThat(putConfigHolder.get(), is(true));
        assertThat(exceptionHolder.get(), is(nullValue()));


        List<Map<String, Object>> toInfer = new ArrayList<>();
        toInfer.add(new HashMap<String, Object>() {{
            put("field", new HashMap<String, Object>(){{
                put("foo", 1.0);
                put("bar", 0.5);
            }});
            put("other", new HashMap<String, Object>(){{
                put("categorical", "dog");
            }});
        }});
        toInfer.add(new HashMap<String, Object>() {{
            put("field", new HashMap<String, Object>(){{
                put("foo", 0.9);
                put("bar", 1.5);
            }});
            put("other", new HashMap<String, Object>(){{
                put("categorical", "cat");
            }});
        }});

        List<Map<String, Object>> toInfer2 = new ArrayList<>();
        toInfer2.add(new HashMap<String, Object>() {{
            put("field", new HashMap<String, Object>(){{
                put("foo", 0.0);
                put("bar", 0.01);
            }});
            put("other", new HashMap<String, Object>(){{
                put("categorical", "dog");
            }});
        }});
        toInfer2.add(new HashMap<String, Object>() {{
            put("field", new HashMap<String, Object>(){{
                put("foo", 1.0);
                put("bar", 0.0);
            }});
            put("other", new HashMap<String, Object>(){{
                put("categorical", "cat");
            }});
        }});

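Given the one-hot mapping above (cat becomes animal_cat, dog becomes animal_dog) and the feature names used by the trees in buildMultiClassClassification() further down, the first document would plausibly reach the trees with feature values like the hypothetical snippet below; it assumes the unused animal_cat column is simply absent for this document rather than defaulted.

    // Hypothetical post-OneHotEncoding view of the first toInfer document above.
    Map<String, Object> preprocessed = new HashMap<>();
    preprocessed.put("field.foo", 1.0);
    preprocessed.put("field.bar", 0.5);
    preprocessed.put("animal_dog", 1.0); // one-hot column produced from "other.categorical" == "dog"
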
        // Test classification
        InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId,
            toInfer,
            ClassificationConfig.EMPTY_PARAMS,
            true);
        InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
        assertThat(response.getInferenceResults()
                .stream()
                .map(i -> ((SingleValueInferenceResults)i).valueAsString())
                .collect(Collectors.toList()),
            contains("option_0", "option_2"));

        request = new InternalInferModelAction.Request(modelId, toInfer2, ClassificationConfig.EMPTY_PARAMS, true);
        response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
        assertThat(response.getInferenceResults()
                .stream()
                .map(i -> ((SingleValueInferenceResults)i).valueAsString())
                .collect(Collectors.toList()),
            contains("option_2", "option_0"));


        // Get top classes
        request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfig(3, null, null), true);
        response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();

        ClassificationInferenceResults classificationInferenceResults =
            (ClassificationInferenceResults)response.getInferenceResults().get(0);

        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("option_0"));
        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("option_2"));
        assertThat(classificationInferenceResults.getTopClasses().get(2).getClassification(), equalTo("option_1"));

        classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("option_2"));
        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("option_0"));
        assertThat(classificationInferenceResults.getTopClasses().get(2).getClassification(), equalTo("option_1"));
    }


    public void testInferMissingModel() {
        String model = "test-infer-missing-model";
        InternalInferModelAction.Request request = new InternalInferModelAction.Request(
@@ -256,6 +365,54 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
            .setModelId(modelId);
    }

    public static TrainedModel buildMultiClassClassification() {
        List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");

        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(1.0, 0.0, 2.0)))
            .addNode(TreeNode.builder(2)
                .setThreshold(0.8)
                .setSplitFeature(1)
                .setLeftChild(3)
                .setRightChild(4))
            .addNode(TreeNode.builder(3).setLeafValue(Arrays.asList(0.0, 1.0, 0.0)))
            .addNode(TreeNode.builder(4).setLeafValue(Arrays.asList(0.0, 0.0, 1.0))).build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(3)
                .setThreshold(1.0))
            .addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(2.0, 0.0, 0.0)))
            .addNode(TreeNode.builder(2).setLeafValue(Arrays.asList(0.0, 2.0, 0.0)))
            .build();
        Tree tree3 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(1.0))
            .addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(0.0, 0.0, 1.0)))
            .addNode(TreeNode.builder(2).setLeafValue(Arrays.asList(0.0, 1.0, 0.0)))
            .build();
        return Ensemble.builder()
            .setClassificationLabels(Arrays.asList("option_0", "option_1", "option_2"))
            .setTargetType(TargetType.CLASSIFICATION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
            .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 3))
            .build();
    }


    @Override
    public NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();