[ML][Inference] Adding classification_weights to ensemble models (#50874) (#50994)

* [ML][Inference] Adding classification_weights to ensemble models

classification_weights are a way to allow models to
prefer specific classification results over others.
This might be advantageous if classification value
probabilities are a known quantity, and can improve
model error rates.
This commit is contained in:
Benjamin Trent 2020-01-14 12:40:25 -05:00 committed by GitHub
parent de5713fa4b
commit 72c270946f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 190 additions and 156 deletions

View File

@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
@ -41,6 +42,7 @@ public class Ensemble implements TrainedModel {
public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
public static final ParseField TARGET_TYPE = new ParseField("target_type");
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights");
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
NAME,
@ -60,6 +62,7 @@ public class Ensemble implements TrainedModel {
AGGREGATE_OUTPUT);
PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
PARSER.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS);
}
public static Ensemble fromXContent(XContentParser parser) {
@ -71,17 +74,20 @@ public class Ensemble implements TrainedModel {
private final OutputAggregator outputAggregator;
private final TargetType targetType;
private final List<String> classificationLabels;
private final double[] classificationWeights;
Ensemble(List<String> featureNames,
List<TrainedModel> models,
@Nullable OutputAggregator outputAggregator,
TargetType targetType,
@Nullable List<String> classificationLabels) {
@Nullable List<String> classificationLabels,
@Nullable double[] classificationWeights) {
this.featureNames = featureNames;
this.models = models;
this.outputAggregator = outputAggregator;
this.targetType = targetType;
this.classificationLabels = classificationLabels;
this.classificationWeights = classificationWeights;
}
@Override
@ -116,6 +122,9 @@ public class Ensemble implements TrainedModel {
if (classificationLabels != null) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
}
if (classificationWeights != null) {
builder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), classificationWeights);
}
builder.endObject();
return builder;
}
@ -129,12 +138,18 @@ public class Ensemble implements TrainedModel {
&& Objects.equals(models, that.models)
&& Objects.equals(targetType, that.targetType)
&& Objects.equals(classificationLabels, that.classificationLabels)
&& Arrays.equals(classificationWeights, that.classificationWeights)
&& Objects.equals(outputAggregator, that.outputAggregator);
}
@Override
public int hashCode() {
return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType);
return Objects.hash(featureNames,
models,
outputAggregator,
classificationLabels,
targetType,
Arrays.hashCode(classificationWeights));
}
public static Builder builder() {
@ -147,6 +162,7 @@ public class Ensemble implements TrainedModel {
private OutputAggregator outputAggregator;
private TargetType targetType;
private List<String> classificationLabels;
private double[] classificationWeights;
public Builder setFeatureNames(List<String> featureNames) {
this.featureNames = featureNames;
@ -173,6 +189,11 @@ public class Ensemble implements TrainedModel {
return this;
}
/**
 * Sets the per-class weights, unboxing the given list into a primitive {@code double[]}.
 * NOTE(review): throws {@code NullPointerException} if {@code classificationWeights} is null.
 *
 * @param classificationWeights weights to apply per class — presumably in the same order
 *                              as the classification labels; TODO confirm against the parser
 * @return this builder, for chaining
 */
public Builder setClassificationWeights(List<Double> classificationWeights) {
this.classificationWeights = classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
return this;
}
/**
 * Parser callback: unwraps the single aggregator produced by the named-object parser.
 * NOTE(review): assumes a non-empty list — {@code get(0)} throws
 * {@code IndexOutOfBoundsException} otherwise.
 */
private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
this.setOutputAggregator(outputAggregators.get(0));
}
@ -182,7 +203,7 @@ public class Ensemble implements TrainedModel {
}
public Ensemble build() {
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels, classificationWeights);
}
}
}

View File

@ -80,11 +80,19 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
Stream.generate(ESTestCase::randomDouble)
.limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())
.mapToDouble(Double::valueOf)
.toArray() :
null;
return new Ensemble(featureNames,
models,
outputAggregator,
targetType,
categoryLabels);
categoryLabels,
thresholds);
}
@Override

View File

@ -112,18 +112,26 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
public final ParseField CLASS_NAME = new ParseField("class_name");
public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
public final ParseField CLASS_SCORE = new ParseField("class_score");
private final String classification;
private final double probability;
private final double score;
public TopClassEntry(String classification, Double probability) {
public TopClassEntry(String classification, double probability) {
this(classification, probability, probability);
}
public TopClassEntry(String classification, double probability, double score) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
this.probability = ExceptionsHelper.requireNonNull(probability, CLASS_PROBABILITY);
this.probability = probability;
this.score = score;
}
/**
 * Wire constructor: reads classification name, probability, then score,
 * in that order — must mirror {@code writeTo}.
 *
 * @param in stream positioned at a serialized {@code TopClassEntry}
 * @throws IOException if reading from the stream fails
 */
public TopClassEntry(StreamInput in) throws IOException {
this.classification = in.readString();
this.probability = in.readDouble();
this.score = in.readDouble();
}
public String getClassification() {
@ -134,10 +142,15 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
return probability;
}
/**
 * @return the score for this class entry; equal to the probability when the
 *         two-argument constructor was used (no separate score supplied)
 */
public double getScore() {
return score;
}
public Map<String, Object> asValueMap() {
Map<String, Object> map = new HashMap<>(2);
Map<String, Object> map = new HashMap<>(3, 1.0f);
map.put(CLASS_NAME.getPreferredName(), classification);
map.put(CLASS_PROBABILITY.getPreferredName(), probability);
map.put(CLASS_SCORE.getPreferredName(), score);
return map;
}
@ -145,6 +158,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
/**
 * Serializes this entry: classification name, probability, then score.
 * Field order must match the {@code StreamInput} constructor.
 */
public void writeTo(StreamOutput out) throws IOException {
out.writeString(classification);
out.writeDouble(probability);
out.writeDouble(score);
}
@Override
@ -152,13 +166,12 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
TopClassEntry that = (TopClassEntry) object;
return Objects.equals(classification, that.classification) &&
Objects.equals(probability, that.probability);
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
}
@Override
public int hashCode() {
return Objects.hash(classification, probability);
return Objects.hash(classification, probability, score);
}
}
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -20,17 +21,13 @@ public final class InferenceHelpers {
private InferenceHelpers() { }
public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List<Double> probabilities,
List<String> classificationLabels,
int numToInclude) {
if (numToInclude == 0) {
return Collections.emptyList();
}
int[] sortedIndices = IntStream.range(0, probabilities.size())
.boxed()
.sorted(Comparator.comparing(probabilities::get).reversed())
.mapToInt(i -> i)
.toArray();
/**
* @return Tuple of the highest scored index and the top classes
*/
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
List<String> classificationLabels,
@Nullable double[] classificationWeights,
int numToInclude) {
if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
throw ExceptionsHelper
@ -38,7 +35,24 @@ public final class InferenceHelpers {
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
null,
probabilities.size(),
classificationLabels);
classificationLabels.size());
}
List<Double> scores = classificationWeights == null ?
probabilities :
IntStream.range(0, probabilities.size())
.mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
.boxed()
.collect(Collectors.toList());
int[] sortedIndices = IntStream.range(0, probabilities.size())
.boxed()
.sorted(Comparator.comparing(scores::get).reversed())
.mapToInt(i -> i)
.toArray();
if (numToInclude == 0) {
return Tuple.tuple(sortedIndices[0], Collections.emptyList());
}
List<String> labels = classificationLabels == null ?
@ -50,26 +64,24 @@ public final class InferenceHelpers {
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
for(int i = 0; i < count; i++) {
int idx = sortedIndices[i];
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx)));
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
}
return topClassEntries;
return Tuple.tuple(sortedIndices[0], topClassEntries);
}
public static String classificationLabel(double inferenceValue, @Nullable List<String> classificationLabels) {
assert inferenceValue == Math.rint(inferenceValue);
public static String classificationLabel(Integer inferenceValue, @Nullable List<String> classificationLabels) {
if (classificationLabels == null) {
return String.valueOf(inferenceValue);
}
int label = Double.valueOf(inferenceValue).intValue();
if (label < 0 || label >= classificationLabels.size()) {
if (inferenceValue < 0 || inferenceValue >= classificationLabels.size()) {
throw ExceptionsHelper.serverError(
"model returned classification value of [{}] which is not a valid index in classification labels [{}]",
null,
label,
inferenceValue,
classificationLabels);
}
return classificationLabels.get(label);
return classificationLabels.get(inferenceValue);
}
public static Double toDouble(Object value) {

View File

@ -6,21 +6,14 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import java.util.List;
import java.util.Map;
public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {
/**
* @return List of featureNames expected by the model. In the order that they are expected
*/
List<String> getFeatureNames();
/**
* Infer against the provided fields
*
@ -36,12 +29,6 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
*/
TargetType targetType();
/**
* @return Ordinal encoded list of classification labels.
*/
@Nullable
List<String> classificationLabels();
/**
* Runs validations against the model.
*

View File

@ -10,6 +10,7 @@ import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
@ -33,6 +34,7 @@ import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
@ -53,6 +55,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
public static final ParseField TARGET_TYPE = new ParseField("target_type");
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights");
private static final ObjectParser<Ensemble.Builder, Void> LENIENT_PARSER = createParser(true);
private static final ObjectParser<Ensemble.Builder, Void> STRICT_PARSER = createParser(false);
@ -77,6 +80,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
AGGREGATE_OUTPUT);
parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
parser.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS);
return parser;
}
@ -93,17 +97,22 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
private final OutputAggregator outputAggregator;
private final TargetType targetType;
private final List<String> classificationLabels;
private final double[] classificationWeights;
Ensemble(List<String> featureNames,
List<TrainedModel> models,
OutputAggregator outputAggregator,
TargetType targetType,
@Nullable List<String> classificationLabels) {
@Nullable List<String> classificationLabels,
@Nullable double[] classificationWeights) {
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS));
this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
this.classificationWeights = classificationWeights == null ?
null :
Arrays.copyOf(classificationWeights, classificationWeights.length);
}
public Ensemble(StreamInput in) throws IOException {
@ -116,11 +125,11 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
} else {
this.classificationLabels = null;
}
}
@Override
public List<String> getFeatureNames() {
return featureNames;
if (in.readBoolean()) {
this.classificationWeights = in.readDoubleArray();
} else {
this.classificationWeights = null;
}
}
@Override
@ -153,25 +162,22 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), config);
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
// Adjust the probabilities according to the thresholds
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
processedInferences,
classificationLabels,
classificationWeights,
classificationConfig.getNumTopClasses());
double value = outputAggregator.aggregate(processedInferences);
return new ClassificationInferenceResults(outputAggregator.aggregate(processedInferences),
classificationLabel(value, classificationLabels),
topClasses,
return new ClassificationInferenceResults((double)topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
config);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
}
}
@Override
public List<String> classificationLabels() {
return classificationLabels;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -187,6 +193,10 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
if (classificationLabels != null) {
out.writeStringCollection(classificationLabels);
}
out.writeBoolean(classificationWeights != null);
if (classificationWeights != null) {
out.writeDoubleArray(classificationWeights);
}
}
@Override
@ -208,6 +218,9 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
if (classificationLabels != null) {
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
}
if (classificationWeights != null) {
builder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), classificationWeights);
}
builder.endObject();
return builder;
}
@ -221,12 +234,18 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
&& Objects.equals(models, that.models)
&& Objects.equals(targetType, that.targetType)
&& Objects.equals(classificationLabels, that.classificationLabels)
&& Objects.equals(outputAggregator, that.outputAggregator);
&& Objects.equals(outputAggregator, that.outputAggregator)
&& Arrays.equals(classificationWeights, that.classificationWeights);
}
@Override
public int hashCode() {
return Objects.hash(featureNames, models, outputAggregator, targetType, classificationLabels);
return Objects.hash(featureNames,
models,
outputAggregator,
targetType,
classificationLabels,
Arrays.hashCode(classificationWeights));
}
@Override
@ -246,9 +265,16 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
outputAggregator.expectedValueSize(),
models.size());
}
if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
if ((this.classificationLabels != null || this.classificationWeights != null) && (this.targetType != TargetType.CLASSIFICATION)) {
throw ExceptionsHelper.badRequestException(
"[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
"[target_type] should be [classification] if [classification_labels] or [classification_weights] are provided");
}
if (classificationWeights != null &&
classificationLabels != null &&
classificationWeights.length != classificationLabels.size()) {
throw ExceptionsHelper.badRequestException(
"[classification_weights] and [classification_labels] should be the same length if both are provided"
);
}
this.models.forEach(TrainedModel::validate);
}
@ -271,6 +297,9 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
size += RamUsageEstimator.sizeOfCollection(featureNames);
size += RamUsageEstimator.sizeOfCollection(classificationLabels);
size += RamUsageEstimator.sizeOfCollection(models);
if (classificationWeights != null) {
size += RamUsageEstimator.sizeOf(classificationWeights);
}
size += outputAggregator.ramBytesUsed();
return size;
}
@ -291,6 +320,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
private OutputAggregator outputAggregator = new WeightedSum();
private TargetType targetType = TargetType.REGRESSION;
private List<String> classificationLabels;
private double[] classificationWeights;
private boolean modelsAreOrdered;
private Builder (boolean modelsAreOrdered) {
@ -330,6 +360,11 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return this;
}
/**
 * Sets the per-class weights, unboxing the given list into a primitive {@code double[]}.
 * NOTE(review): throws {@code NullPointerException} if {@code classificationWeights} is null.
 *
 * @param classificationWeights weights to apply per class — presumably in the same order
 *                              as the classification labels; length vs. labels is checked in validate()
 * @return this builder, for chaining
 */
public Builder setClassificationWeights(List<Double> classificationWeights) {
this.classificationWeights = classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
return this;
}
private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
if (outputAggregators.size() != 1) {
throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.",
@ -352,7 +387,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
if (modelsAreOrdered == false && trainedModels != null && trainedModels.size() > 1) {
throw ExceptionsHelper.badRequestException("[trained_models] needs to be an array of objects");
}
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels, classificationWeights);
}
}
}

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -25,13 +26,10 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
@ -105,11 +103,6 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
this.softmaxLayer = new LangNetLayer(in);
}
@Override
public List<String> getFeatureNames() {
return Collections.singletonList(embeddedVectorFeatureName);
}
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
if (config instanceof ClassificationConfig == false) {
@ -134,20 +127,17 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
List<Double> probabilities = softMax(Arrays.stream(scores).boxed().collect(Collectors.toList()));
int maxIndex = IntStream.range(0, probabilities.size())
.boxed()
.max(Comparator.comparing(probabilities::get))
.orElseThrow(() -> ExceptionsHelper.serverError("Unexpected null value while searching for max probability"));
assert maxIndex >= 0 && maxIndex < LANGUAGE_NAMES.size() : "Invalid language predicted. Predicted language index " + maxIndex;
ClassificationConfig classificationConfig = (ClassificationConfig) config;
List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
probabilities,
LANGUAGE_NAMES,
null,
classificationConfig.getNumTopClasses());
return new ClassificationInferenceResults(maxIndex,
LANGUAGE_NAMES.get(maxIndex),
topClasses,
assert topClasses.v1() >= 0 && topClasses.v1() < LANGUAGE_NAMES.size() :
"Invalid language predicted. Predicted language index " + topClasses.v1();
return new ClassificationInferenceResults(topClasses.v1(),
LANGUAGE_NAMES.get(topClasses.v1()),
topClasses.v2(),
classificationConfig);
}
@ -156,11 +146,6 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
return TargetType.CLASSIFICATION;
}
@Override
public List<String> classificationLabels() {
return LANGUAGE_NAMES;
}
@Override
public void validate() {
}

View File

@ -10,6 +10,7 @@ import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.CachedSupplier;
@ -114,11 +115,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return NAME.getPreferredName();
}
@Override
public List<String> getFeatureNames() {
return featureNames;
}
/** @return the nodes backing this tree */
public List<TreeNode> getNodes() {
return nodes;
}
@ -152,11 +148,15 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
switch (targetType) {
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
classificationProbability(value),
classificationLabels,
null,
classificationConfig.getNumTopClasses());
return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses, config);
return new ClassificationInferenceResults(value,
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
config);
case REGRESSION:
return new RegressionInferenceResults(value, config);
default:
@ -197,11 +197,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return list;
}
@Override
public List<String> classificationLabels() {
return classificationLabels;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -270,9 +265,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}
private void checkTargetType() {
if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) {
throw ExceptionsHelper.badRequestException(
"[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
"[target_type] should be [classification] if [classification_labels] are provided");
}
}

View File

@ -8,10 +8,8 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
@ -28,7 +26,6 @@ import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -77,16 +74,24 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights),
new WeightedSum(weights),
new LogisticRegression(weights));
TargetType targetType = randomFrom(TargetType.values());
List<String> categoryLabels = null;
if (randomBoolean()) {
if (randomBoolean() && targetType == TargetType.CLASSIFICATION) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
Stream.generate(ESTestCase::randomDouble)
.limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())
.mapToDouble(Double::valueOf)
.toArray() :
null;
return new Ensemble(featureNames,
models,
outputAggregator,
randomFrom(TargetType.values()),
categoryLabels);
targetType,
categoryLabels,
thresholds);
}
@Override
@ -101,17 +106,12 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
}
public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() {
@ -184,16 +184,15 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
public void testEnsembleWithTargetTypeAndLabelsMismatch() {
List<String> featureNames = Arrays.asList("foo", "bar");
String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
String msg = "[target_type] should be [classification] if " +
"[classification_labels] or [classification_weights] are provided";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder()
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(
Tree.builder()
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setLeafValue(randomDouble()))
.setFeatureNames(featureNames)
.build()))
.setClassificationLabels(Arrays.asList("label1", "label2"))
@ -201,23 +200,6 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
ex = expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder()
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(
Tree.builder()
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setFeatureNames(featureNames)
.build()))
.setTargetType(TargetType.CLASSIFICATION)
.setOutputAggregator(new WeightedMode())
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
}
public void testClassificationProbability() {
@ -262,34 +244,41 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
.setClassificationWeights(Arrays.asList(0.7, 0.3))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
List<Double> expected = Arrays.asList(0.768524783, 0.231475216);
List<Double> scores = Arrays.asList(0.230557435, 0.162032651);
double eps = 0.000001;
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
expected = Arrays.asList(0.689974481, 0.3100255188);
expected = Arrays.asList(0.310025518, 0.6899744811);
scores = Arrays.asList(0.217017863, 0.2069923443);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
expected = Arrays.asList(0.768524783, 0.231475216);
scores = Arrays.asList(0.230557435, 0.162032651);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
// This should handle missing values and take the default_left path
@ -298,10 +287,12 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
put("bar", null);
}};
expected = Arrays.asList(0.6899744811, 0.3100255188);
scores = Arrays.asList(0.482982136, 0.0930076556);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
}
}

View File

@ -93,14 +93,13 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}
childNodes = nextNodes;
}
TargetType targetType = randomFrom(TargetType.values());
List<String> categoryLabels = null;
if (randomBoolean()) {
if (randomBoolean() && targetType == TargetType.CLASSIFICATION) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
return builder.setTargetType(randomFrom(TargetType.values()))
.setClassificationLabels(categoryLabels)
.build();
return builder.setTargetType(targetType).setClassificationLabels(categoryLabels).build();
}
@Override
@ -325,7 +324,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
public void testTreeWithTargetTypeAndLabelsMismatch() {
List<String> featureNames = Arrays.asList("foo", "bar");
String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
String msg = "[target_type] should be [classification] if [classification_labels] are provided";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Tree.builder()
.setRoot(TreeNode.builder(0)
@ -338,18 +337,6 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
ex = expectThrows(ElasticsearchException.class, () -> {
Tree.builder()
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setFeatureNames(featureNames)
.setTargetType(TargetType.CLASSIFICATION)
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
}
public void testOperationsEstimations() {

View File

@ -471,7 +471,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
@ -501,7 +501,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +

View File

@ -56,7 +56,7 @@ public class LocalModelTests extends ESTestCase {
SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0));
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), is("0.0"));
assertThat(result.valueAsString(), is("0"));
ClassificationInferenceResults classificationResult =
(ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));