* [ML][Inference] Adding classification_weights to ensemble models. classification_weights give a model a way to prefer specific classification results over others. This can be advantageous when the class probabilities are a known quantity, and weighting them can improve the model's error rates.
This commit is contained in:
parent de5713fa4b
commit 72c270946f
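Conceptually, inference multiplies each class probability by its configured weight and ranks classes by the weighted score, while still reporting the unweighted probabilities. A minimal standalone sketch of that scoring step (illustrative Java, not the commit's code; all names here are hypothetical):

    import java.util.Comparator;
    import java.util.stream.IntStream;

    final class WeightedScoringSketch {
        // Rank class indices by probability * weight, highest score first.
        static int[] rankClasses(double[] probabilities, double[] weights) {
            double[] scores = new double[probabilities.length];
            for (int i = 0; i < probabilities.length; i++) {
                scores[i] = probabilities[i] * weights[i];
            }
            return IntStream.range(0, scores.length)
                .boxed()
                .sorted(Comparator.<Integer>comparingDouble(i -> scores[i]).reversed())
                .mapToInt(Integer::intValue)
                .toArray();
        }

        public static void main(String[] args) {
            // Class 1 is most probable (0.69), but class 0 wins after weighting:
            // 0.31 * 0.7 = 0.217 beats 0.69 * 0.3 = 0.207.
            int[] ranked = rankClasses(new double[]{0.31, 0.69}, new double[]{0.7, 0.3});
            System.out.println(ranked[0]); // prints 0
        }
    }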
@@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
@@ -41,6 +42,7 @@ public class Ensemble implements TrainedModel {
     public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
     public static final ParseField TARGET_TYPE = new ParseField("target_type");
     public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
+    public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights");
 
     private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
         NAME,
@@ -60,6 +62,7 @@ public class Ensemble implements TrainedModel {
             AGGREGATE_OUTPUT);
         PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
         PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
+        PARSER.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS);
     }
 
     public static Ensemble fromXContent(XContentParser parser) {
@@ -71,17 +74,20 @@ public class Ensemble implements TrainedModel {
     private final OutputAggregator outputAggregator;
     private final TargetType targetType;
     private final List<String> classificationLabels;
+    private final double[] classificationWeights;
 
     Ensemble(List<String> featureNames,
              List<TrainedModel> models,
              @Nullable OutputAggregator outputAggregator,
              TargetType targetType,
-             @Nullable List<String> classificationLabels) {
+             @Nullable List<String> classificationLabels,
+             @Nullable double[] classificationWeights) {
         this.featureNames = featureNames;
         this.models = models;
         this.outputAggregator = outputAggregator;
         this.targetType = targetType;
         this.classificationLabels = classificationLabels;
+        this.classificationWeights = classificationWeights;
     }
 
     @Override
@@ -116,6 +122,9 @@ public class Ensemble implements TrainedModel {
         if (classificationLabels != null) {
             builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
         }
+        if (classificationWeights != null) {
+            builder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), classificationWeights);
+        }
         builder.endObject();
         return builder;
     }
@@ -129,12 +138,18 @@ public class Ensemble implements TrainedModel {
             && Objects.equals(models, that.models)
             && Objects.equals(targetType, that.targetType)
             && Objects.equals(classificationLabels, that.classificationLabels)
+            && Arrays.equals(classificationWeights, that.classificationWeights)
             && Objects.equals(outputAggregator, that.outputAggregator);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType);
+        return Objects.hash(featureNames,
+            models,
+            outputAggregator,
+            classificationLabels,
+            targetType,
+            Arrays.hashCode(classificationWeights));
     }
 
     public static Builder builder() {
@@ -147,6 +162,7 @@ public class Ensemble implements TrainedModel {
         private OutputAggregator outputAggregator;
         private TargetType targetType;
         private List<String> classificationLabels;
+        private double[] classificationWeights;
 
         public Builder setFeatureNames(List<String> featureNames) {
             this.featureNames = featureNames;
@@ -173,6 +189,11 @@ public class Ensemble implements TrainedModel {
             return this;
         }
 
+        public Builder setClassificationWeights(List<Double> classificationWeights) {
+            this.classificationWeights = classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
+            return this;
+        }
+
         private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
             this.setOutputAggregator(outputAggregators.get(0));
         }
@@ -182,7 +203,7 @@ public class Ensemble implements TrainedModel {
         }
 
         public Ensemble build() {
-            return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
+            return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels, classificationWeights);
         }
     }
 }
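For reference, a weighted ensemble could be assembled through the builder shown above along these lines (a sketch; the trainedModels list is a placeholder, and a setTargetType overload taking a TargetType is assumed):

    Ensemble ensemble = Ensemble.builder()
        .setFeatureNames(Arrays.asList("foo", "bar"))
        .setTrainedModels(trainedModels)                    // placeholder member models
        .setTargetType(TargetType.CLASSIFICATION)           // weights require classification
        .setClassificationLabels(Arrays.asList("no", "yes"))
        .setClassificationWeights(Arrays.asList(0.7, 0.3))  // one weight per label, by ordinal
        .build();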
@@ -80,11 +80,19 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
         if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
             categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
         }
+        double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
+            Stream.generate(ESTestCase::randomDouble)
+                .limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())
+                .mapToDouble(Double::valueOf)
+                .toArray() :
+            null;
+
         return new Ensemble(featureNames,
             models,
             outputAggregator,
             targetType,
-            categoryLabels);
+            categoryLabels,
+            thresholds);
     }
 
     @Override
@@ -112,18 +112,26 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults {
 
         public final ParseField CLASS_NAME = new ParseField("class_name");
         public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
+        public final ParseField CLASS_SCORE = new ParseField("class_score");
 
         private final String classification;
         private final double probability;
+        private final double score;
 
-        public TopClassEntry(String classification, Double probability) {
+        public TopClassEntry(String classification, double probability) {
+            this(classification, probability, probability);
+        }
+
+        public TopClassEntry(String classification, double probability, double score) {
             this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
-            this.probability = ExceptionsHelper.requireNonNull(probability, CLASS_PROBABILITY);
+            this.probability = probability;
+            this.score = score;
         }
 
         public TopClassEntry(StreamInput in) throws IOException {
             this.classification = in.readString();
             this.probability = in.readDouble();
+            this.score = in.readDouble();
         }
 
         public String getClassification() {
@@ -134,10 +142,15 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults {
             return probability;
         }
 
+        public double getScore() {
+            return score;
+        }
+
         public Map<String, Object> asValueMap() {
-            Map<String, Object> map = new HashMap<>(2);
+            Map<String, Object> map = new HashMap<>(3, 1.0f);
             map.put(CLASS_NAME.getPreferredName(), classification);
             map.put(CLASS_PROBABILITY.getPreferredName(), probability);
+            map.put(CLASS_SCORE.getPreferredName(), score);
             return map;
         }
 
@@ -145,6 +158,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults {
         public void writeTo(StreamOutput out) throws IOException {
             out.writeString(classification);
             out.writeDouble(probability);
+            out.writeDouble(score);
         }
 
         @Override
@@ -152,13 +166,12 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults {
             if (object == this) { return true; }
             if (object == null || getClass() != object.getClass()) { return false; }
             TopClassEntry that = (TopClassEntry) object;
-            return Objects.equals(classification, that.classification) &&
-                Objects.equals(probability, that.probability);
+            return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(classification, probability);
+            return Objects.hash(classification, probability, score);
         }
     }
 }
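Note the split in TopClassEntry: probability is what the model emitted, score is the value used for ranking, and the two-argument constructor keeps them equal for unweighted models. Illustrative use (hypothetical values):

    // Unweighted: the score defaults to the probability.
    TopClassEntry plain = new TopClassEntry("dog", 0.69);
    // plain.getProbability() == 0.69, plain.getScore() == 0.69

    // Weighted: the probability stays as emitted, the score carries probability * weight.
    TopClassEntry weighted = new TopClassEntry("dog", 0.69, 0.69 * 0.3);
    // weighted.getScore() is approximately 0.207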
@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -20,17 +21,13 @@ public final class InferenceHelpers {
 
     private InferenceHelpers() { }
 
-    public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List<Double> probabilities,
-                                                                                List<String> classificationLabels,
-                                                                                int numToInclude) {
-        if (numToInclude == 0) {
-            return Collections.emptyList();
-        }
-        int[] sortedIndices = IntStream.range(0, probabilities.size())
-            .boxed()
-            .sorted(Comparator.comparing(probabilities::get).reversed())
-            .mapToInt(i -> i)
-            .toArray();
+    /**
+     * @return Tuple of the highest scored index and the top classes
+     */
+    public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
+                                                                                                List<String> classificationLabels,
+                                                                                                @Nullable double[] classificationWeights,
+                                                                                                int numToInclude) {
 
         if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
             throw ExceptionsHelper
@@ -38,7 +35,24 @@ public final class InferenceHelpers {
                 "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
                 null,
                 probabilities.size(),
-                classificationLabels);
+                classificationLabels.size());
         }
 
+        List<Double> scores = classificationWeights == null ?
+            probabilities :
+            IntStream.range(0, probabilities.size())
+                .mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
+                .boxed()
+                .collect(Collectors.toList());
+
+        int[] sortedIndices = IntStream.range(0, probabilities.size())
+            .boxed()
+            .sorted(Comparator.comparing(scores::get).reversed())
+            .mapToInt(i -> i)
+            .toArray();
+
+        if (numToInclude == 0) {
+            return Tuple.tuple(sortedIndices[0], Collections.emptyList());
+        }
+
         List<String> labels = classificationLabels == null ?
@@ -50,26 +64,24 @@ public final class InferenceHelpers {
         List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
         for(int i = 0; i < count; i++) {
             int idx = sortedIndices[i];
-            topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx)));
+            topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
         }
 
-        return topClassEntries;
+        return Tuple.tuple(sortedIndices[0], topClassEntries);
     }
 
-    public static String classificationLabel(double inferenceValue, @Nullable List<String> classificationLabels) {
-        assert inferenceValue == Math.rint(inferenceValue);
+    public static String classificationLabel(Integer inferenceValue, @Nullable List<String> classificationLabels) {
         if (classificationLabels == null) {
             return String.valueOf(inferenceValue);
         }
-        int label = Double.valueOf(inferenceValue).intValue();
-        if (label < 0 || label >= classificationLabels.size()) {
+        if (inferenceValue < 0 || inferenceValue >= classificationLabels.size()) {
             throw ExceptionsHelper.serverError(
                 "model returned classification value of [{}] which is not a valid index in classification labels [{}]",
                 null,
-                label,
+                inferenceValue,
                 classificationLabels);
         }
-        return classificationLabels.get(label);
+        return classificationLabels.get(inferenceValue);
     }
 
     public static Double toDouble(Object value) {
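To make the reworked helper concrete, a hypothetical call with round numbers: given probabilities [0.31, 0.69] and classification_weights [0.7, 0.3], the scores are [0.217, 0.207], so index 0 is returned as the winner even though index 1 has the higher raw probability.

    Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> top =
        InferenceHelpers.topClasses(
            Arrays.asList(0.31, 0.69),   // model probabilities
            Arrays.asList("cat", "dog"), // classification_labels
            new double[]{0.7, 0.3},      // classification_weights
            2);                          // numToInclude
    // top.v1() == 0                     (0.31 * 0.7 = 0.217 is the best score)
    // top.v2().get(0): "cat", probability 0.31, score 0.217
    // top.v2().get(1): "dog", probability 0.69, score 0.207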
@@ -6,21 +6,14 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.apache.lucene.util.Accountable;
-import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
 
-import java.util.List;
 import java.util.Map;
 
 public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {
 
-    /**
-     * @return List of featureNames expected by the model. In the order that they are expected
-     */
-    List<String> getFeatureNames();
-
     /**
      * Infer against the provided fields
      *
@@ -36,12 +29,6 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {
      */
     TargetType targetType();
 
-    /**
-     * @return Ordinal encoded list of classification labels.
-     */
-    @Nullable
-    List<String> classificationLabels();
-
     /**
      * Runs validations against the model.
      *
@@ -10,6 +10,7 @@ import org.apache.lucene.util.Accountables;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ObjectParser;
@@ -33,6 +34,7 @@ import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
@@ -53,6 +55,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
     public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
     public static final ParseField TARGET_TYPE = new ParseField("target_type");
     public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
+    public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights");
 
     private static final ObjectParser<Ensemble.Builder, Void> LENIENT_PARSER = createParser(true);
     private static final ObjectParser<Ensemble.Builder, Void> STRICT_PARSER = createParser(false);
@@ -77,6 +80,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
             AGGREGATE_OUTPUT);
         parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
         parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
+        parser.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS);
         return parser;
     }
 
@@ -93,17 +97,22 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
     private final OutputAggregator outputAggregator;
     private final TargetType targetType;
     private final List<String> classificationLabels;
+    private final double[] classificationWeights;
 
     Ensemble(List<String> featureNames,
              List<TrainedModel> models,
              OutputAggregator outputAggregator,
              TargetType targetType,
-             @Nullable List<String> classificationLabels) {
+             @Nullable List<String> classificationLabels,
+             @Nullable double[] classificationWeights) {
         this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
         this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS));
         this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
         this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
         this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
+        this.classificationWeights = classificationWeights == null ?
+            null :
+            Arrays.copyOf(classificationWeights, classificationWeights.length);
     }
 
     public Ensemble(StreamInput in) throws IOException {
@@ -116,11 +125,11 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
         } else {
             this.classificationLabels = null;
         }
-    }
-
-    @Override
-    public List<String> getFeatureNames() {
-        return featureNames;
+        if (in.readBoolean()) {
+            this.classificationWeights = in.readDoubleArray();
+        } else {
+            this.classificationWeights = null;
+        }
     }
 
     @Override
@@ -153,25 +162,22 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
                 return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), config);
             case CLASSIFICATION:
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
-                List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
+                assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
+                // Adjust the probabilities according to the thresholds
+                Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
                     processedInferences,
                     classificationLabels,
+                    classificationWeights,
                     classificationConfig.getNumTopClasses());
-                double value = outputAggregator.aggregate(processedInferences);
-                return new ClassificationInferenceResults(outputAggregator.aggregate(processedInferences),
-                    classificationLabel(value, classificationLabels),
-                    topClasses,
+                return new ClassificationInferenceResults((double)topClasses.v1(),
+                    classificationLabel(topClasses.v1(), classificationLabels),
+                    topClasses.v2(),
                     config);
             default:
                 throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
         }
     }
 
-    @Override
-    public List<String> classificationLabels() {
-        return classificationLabels;
-    }
-
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -187,6 +193,10 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
         if (classificationLabels != null) {
             out.writeStringCollection(classificationLabels);
         }
+        out.writeBoolean(classificationWeights != null);
+        if (classificationWeights != null) {
+            out.writeDoubleArray(classificationWeights);
+        }
     }
 
     @Override
@@ -208,6 +218,9 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
         if (classificationLabels != null) {
             builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
         }
+        if (classificationWeights != null) {
+            builder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), classificationWeights);
+        }
         builder.endObject();
         return builder;
     }
@@ -221,12 +234,18 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
             && Objects.equals(models, that.models)
             && Objects.equals(targetType, that.targetType)
             && Objects.equals(classificationLabels, that.classificationLabels)
-            && Objects.equals(outputAggregator, that.outputAggregator);
+            && Objects.equals(outputAggregator, that.outputAggregator)
+            && Arrays.equals(classificationWeights, that.classificationWeights);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(featureNames, models, outputAggregator, targetType, classificationLabels);
+        return Objects.hash(featureNames,
+            models,
+            outputAggregator,
+            targetType,
+            classificationLabels,
+            Arrays.hashCode(classificationWeights));
    }
 
     @Override
@@ -246,9 +265,16 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
                 outputAggregator.expectedValueSize(),
                 models.size());
         }
-        if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
+        if ((this.classificationLabels != null || this.classificationWeights != null) && (this.targetType != TargetType.CLASSIFICATION)) {
             throw ExceptionsHelper.badRequestException(
-                "[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
+                "[target_type] should be [classification] if [classification_labels] or [classification_weights] are provided");
         }
+        if (classificationWeights != null &&
+            classificationLabels != null &&
+            classificationWeights.length != classificationLabels.size()) {
+            throw ExceptionsHelper.badRequestException(
+                "[classification_weights] and [classification_labels] should be the same length if both are provided"
+            );
+        }
         this.models.forEach(TrainedModel::validate);
     }
@@ -271,6 +297,9 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
         size += RamUsageEstimator.sizeOfCollection(featureNames);
         size += RamUsageEstimator.sizeOfCollection(classificationLabels);
         size += RamUsageEstimator.sizeOfCollection(models);
+        if (classificationWeights != null) {
+            size += RamUsageEstimator.sizeOf(classificationWeights);
+        }
         size += outputAggregator.ramBytesUsed();
         return size;
     }
@@ -291,6 +320,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
         private OutputAggregator outputAggregator = new WeightedSum();
         private TargetType targetType = TargetType.REGRESSION;
         private List<String> classificationLabels;
+        private double[] classificationWeights;
         private boolean modelsAreOrdered;
 
         private Builder (boolean modelsAreOrdered) {
@@ -330,6 +360,11 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
             return this;
         }
 
+        public Builder setClassificationWeights(List<Double> classificationWeights) {
+            this.classificationWeights = classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
+            return this;
+        }
+
         private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
             if (outputAggregators.size() != 1) {
                 throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.",
@@ -352,7 +387,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
             if (modelsAreOrdered == false && trainedModels != null && trainedModels.size() > 1) {
                 throw ExceptionsHelper.badRequestException("[trained_models] needs to be an array of objects");
             }
-            return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
+            return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels, classificationWeights);
         }
     }
 }
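The StreamInput/StreamOutput changes above follow the usual optional-field wire pattern: write a presence flag, then the payload only when present, and mirror that on read. A self-contained sketch of the pattern using plain java.io (not Elasticsearch's stream classes):

    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;

    final class OptionalDoubleArrayWire {
        // Presence flag first, then length-prefixed doubles only when non-null.
        static void write(DataOutputStream out, double[] values) throws IOException {
            out.writeBoolean(values != null);
            if (values != null) {
                out.writeInt(values.length);
                for (double v : values) {
                    out.writeDouble(v);
                }
            }
        }

        // Mirror of write(): returns null when the flag says the field was absent.
        static double[] read(DataInputStream in) throws IOException {
            if (in.readBoolean() == false) {
                return null;
            }
            double[] values = new double[in.readInt()];
            for (int i = 0; i < values.length; i++) {
                values[i] = in.readDouble();
            }
            return values;
        }
    }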
@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident;
 
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@@ -25,13 +26,10 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.stream.Collectors;
-import java.util.stream.IntStream;
 
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
@@ -105,11 +103,6 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel {
         this.softmaxLayer = new LangNetLayer(in);
     }
 
-    @Override
-    public List<String> getFeatureNames() {
-        return Collections.singletonList(embeddedVectorFeatureName);
-    }
-
     @Override
     public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
         if (config instanceof ClassificationConfig == false) {
@@ -134,20 +127,17 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel {
 
         List<Double> probabilities = softMax(Arrays.stream(scores).boxed().collect(Collectors.toList()));
 
-        int maxIndex = IntStream.range(0, probabilities.size())
-            .boxed()
-            .max(Comparator.comparing(probabilities::get))
-            .orElseThrow(() -> ExceptionsHelper.serverError("Unexpected null value while searching for max probability"));
-
-        assert maxIndex >= 0 && maxIndex < LANGUAGE_NAMES.size() : "Invalid language predicted. Predicted language index " + maxIndex;
         ClassificationConfig classificationConfig = (ClassificationConfig) config;
-        List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
+        Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
             probabilities,
             LANGUAGE_NAMES,
+            null,
             classificationConfig.getNumTopClasses());
-        return new ClassificationInferenceResults(maxIndex,
-            LANGUAGE_NAMES.get(maxIndex),
-            topClasses,
+        assert topClasses.v1() >= 0 && topClasses.v1() < LANGUAGE_NAMES.size() :
+            "Invalid language predicted. Predicted language index " + topClasses.v1();
+        return new ClassificationInferenceResults(topClasses.v1(),
+            LANGUAGE_NAMES.get(topClasses.v1()),
+            topClasses.v2(),
             classificationConfig);
     }
 
@@ -156,11 +146,6 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel {
         return TargetType.CLASSIFICATION;
     }
 
-    @Override
-    public List<String> classificationLabels() {
-        return LANGUAGE_NAMES;
-    }
-
     @Override
     public void validate() {
     }
@@ -10,6 +10,7 @@ import org.apache.lucene.util.Accountables;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.util.CachedSupplier;
@@ -114,11 +115,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
         return NAME.getPreferredName();
     }
 
-    @Override
-    public List<String> getFeatureNames() {
-        return featureNames;
-    }
-
     public List<TreeNode> getNodes() {
         return nodes;
     }
@@ -152,11 +148,15 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
         switch (targetType) {
             case CLASSIFICATION:
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
-                List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
+                Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
                     classificationProbability(value),
                     classificationLabels,
+                    null,
                     classificationConfig.getNumTopClasses());
-                return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses, config);
+                return new ClassificationInferenceResults(value,
+                    classificationLabel(topClasses.v1(), classificationLabels),
+                    topClasses.v2(),
+                    config);
             case REGRESSION:
                 return new RegressionInferenceResults(value, config);
             default:
@@ -197,11 +197,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
         return list;
     }
 
-    @Override
-    public List<String> classificationLabels() {
-        return classificationLabels;
-    }
-
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -270,9 +265,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
     }
 
     private void checkTargetType() {
-        if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
+        if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) {
             throw ExceptionsHelper.badRequestException(
-                "[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
+                "[target_type] should be [classification] if [classification_labels] are provided");
         }
     }
 
@@ -8,10 +8,8 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
@@ -28,7 +26,6 @@ import org.junit.Before;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -77,16 +74,24 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights),
             new WeightedSum(weights),
             new LogisticRegression(weights));
+        TargetType targetType = randomFrom(TargetType.values());
         List<String> categoryLabels = null;
-        if (randomBoolean()) {
+        if (randomBoolean() && targetType == TargetType.CLASSIFICATION) {
             categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
         }
+        double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
+            Stream.generate(ESTestCase::randomDouble)
+                .limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())
+                .mapToDouble(Double::valueOf)
+                .toArray() :
+            null;
 
         return new Ensemble(featureNames,
             models,
             outputAggregator,
-            randomFrom(TargetType.values()),
-            categoryLabels);
+            targetType,
+            categoryLabels,
+            thresholds);
     }
 
     @Override
@@ -101,17 +106,12 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
 
     @Override
     protected NamedXContentRegistry xContentRegistry() {
-        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
-        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
-        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
-        return new NamedXContentRegistry(namedXContent);
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
     }
 
     @Override
     protected NamedWriteableRegistry getNamedWriteableRegistry() {
-        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
-        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
-        return new NamedWriteableRegistry(entries);
+        return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
     }
 
     public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() {
@@ -184,16 +184,15 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
 
     public void testEnsembleWithTargetTypeAndLabelsMismatch() {
         List<String> featureNames = Arrays.asList("foo", "bar");
-        String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
+        String msg = "[target_type] should be [classification] if " +
+            "[classification_labels] or [classification_weights] are provided";
         ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
             Ensemble.builder()
                 .setFeatureNames(featureNames)
                 .setTrainedModels(Arrays.asList(
                     Tree.builder()
                         .setNodes(TreeNode.builder(0)
-                            .setLeftChild(1)
-                            .setSplitFeature(1)
-                            .setThreshold(randomDouble()))
+                            .setLeafValue(randomDouble()))
                         .setFeatureNames(featureNames)
                         .build()))
                 .setClassificationLabels(Arrays.asList("label1", "label2"))
@@ -201,23 +200,6 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
                 .validate();
         });
         assertThat(ex.getMessage(), equalTo(msg));
-        ex = expectThrows(ElasticsearchException.class, () -> {
-            Ensemble.builder()
-                .setFeatureNames(featureNames)
-                .setTrainedModels(Arrays.asList(
-                    Tree.builder()
-                        .setNodes(TreeNode.builder(0)
-                            .setLeftChild(1)
-                            .setSplitFeature(1)
-                            .setThreshold(randomDouble()))
-                        .setFeatureNames(featureNames)
-                        .build()))
-                .setTargetType(TargetType.CLASSIFICATION)
-                .setOutputAggregator(new WeightedMode())
-                .build()
-                .validate();
-        });
-        assertThat(ex.getMessage(), equalTo(msg));
     }
 
     public void testClassificationProbability() {
@@ -262,34 +244,41 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             .setFeatureNames(featureNames)
             .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
             .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
+            .setClassificationWeights(Arrays.asList(0.7, 0.3))
             .build();
 
         List<Double> featureVector = Arrays.asList(0.4, 0.0);
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
         List<Double> expected = Arrays.asList(0.768524783, 0.231475216);
+        List<Double> scores = Arrays.asList(0.230557435, 0.162032651);
         double eps = 0.000001;
         List<ClassificationInferenceResults.TopClassEntry> probabilities =
             ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
+            assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
         }
 
         featureVector = Arrays.asList(2.0, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
-        expected = Arrays.asList(0.689974481, 0.3100255188);
+        expected = Arrays.asList(0.310025518, 0.6899744811);
+        scores = Arrays.asList(0.217017863, 0.2069923443);
         probabilities =
             ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
+            assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
        }
 
         featureVector = Arrays.asList(0.0, 1.0);
         featureMap = zipObjMap(featureNames, featureVector);
         expected = Arrays.asList(0.768524783, 0.231475216);
+        scores = Arrays.asList(0.230557435, 0.162032651);
         probabilities =
             ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
+            assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
         }
 
         // This should handle missing values and take the default_left path
@@ -298,10 +287,12 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             put("bar", null);
         }};
         expected = Arrays.asList(0.6899744811, 0.3100255188);
+        scores = Arrays.asList(0.482982136, 0.0930076556);
         probabilities =
             ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
+            assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
         }
     }
 
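A sanity check on the expected values in testClassificationProbability above: each score is the class probability multiplied by its classification weight ([0.7, 0.3] by class ordinal), and entries come back ordered by score rather than probability, which is why the second block's expected order flips:

    // class 0: 0.310025518 * 0.7 = 0.217017863  (ranked first)
    // class 1: 0.689974481 * 0.3 = 0.206992344  (ranked second despite the higher probability)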
@@ -93,14 +93,13 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
             }
             childNodes = nextNodes;
         }
+        TargetType targetType = randomFrom(TargetType.values());
         List<String> categoryLabels = null;
-        if (randomBoolean()) {
+        if (randomBoolean() && targetType == TargetType.CLASSIFICATION) {
             categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
         }
 
-        return builder.setTargetType(randomFrom(TargetType.values()))
-            .setClassificationLabels(categoryLabels)
-            .build();
+        return builder.setTargetType(targetType).setClassificationLabels(categoryLabels).build();
     }
 
     @Override
@@ -325,7 +324,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
 
     public void testTreeWithTargetTypeAndLabelsMismatch() {
         List<String> featureNames = Arrays.asList("foo", "bar");
-        String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
+        String msg = "[target_type] should be [classification] if [classification_labels] are provided";
         ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
             Tree.builder()
                 .setRoot(TreeNode.builder(0)
@@ -338,18 +337,6 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
                 .validate();
         });
         assertThat(ex.getMessage(), equalTo(msg));
-        ex = expectThrows(ElasticsearchException.class, () -> {
-            Tree.builder()
-                .setRoot(TreeNode.builder(0)
-                    .setLeftChild(1)
-                    .setSplitFeature(1)
-                    .setThreshold(randomDouble()))
-                .setFeatureNames(featureNames)
-                .setTargetType(TargetType.CLASSIFICATION)
-                .build()
-                .validate();
-        });
-        assertThat(ex.getMessage(), equalTo(msg));
     }
 
     public void testOperationsEstimations() {
@@ -471,7 +471,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
             " },\n" +
             " {\n" +
             " \"node_index\": 2,\n" +
-            " \"leaf_value\": 2\n" +
+            " \"leaf_value\": 0\n" +
             " }\n" +
             " ],\n" +
             " \"target_type\": \"regression\"\n" +
@@ -501,7 +501,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
             " },\n" +
             " {\n" +
             " \"node_index\": 2,\n" +
-            " \"leaf_value\": 2\n" +
+            " \"leaf_value\": 0\n" +
             " }\n" +
             " ],\n" +
             " \"target_type\": \"regression\"\n" +
@@ -56,7 +56,7 @@ public class LocalModelTests extends ESTestCase {
 
         SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0));
         assertThat(result.value(), equalTo(0.0));
-        assertThat(result.valueAsString(), is("0.0"));
+        assertThat(result.valueAsString(), is("0"));
 
         ClassificationInferenceResults classificationResult =
             (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));