* [ML][Inference] Fix weighted mode definition (#51648) Weighted mode incorrectly assumed that the "max value" of the input values would be the maximum class value. This does not make sense: weighted mode should know how many classes there are. Hence the new parameter `num_classes`, which determines the maximum class value to be expected (`num_classes - 1`).
This commit is contained in:
parent
69ef9b05cd
commit
1380dd439a
|
@ -34,13 +34,15 @@ public class WeightedMode implements OutputAggregator {
|
|||
|
||||
public static final String NAME = "weighted_mode";
|
||||
public static final ParseField WEIGHTS = new ParseField("weights");
|
||||
public static final ParseField NUM_CLASSES = new ParseField("num_classes");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<WeightedMode, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
true,
|
||||
a -> new WeightedMode((List<Double>)a[0]));
|
||||
a -> new WeightedMode((Integer)a[0], (List<Double>)a[1]));
|
||||
static {
|
||||
PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_CLASSES);
|
||||
PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
|
||||
}
|
||||
|
||||
|
@ -49,9 +51,11 @@ public class WeightedMode implements OutputAggregator {
|
|||
}
|
||||
|
||||
private final List<Double> weights;
|
||||
private final int numClasses;
|
||||
|
||||
public WeightedMode(List<Double> weights) {
|
||||
public WeightedMode(int numClasses, List<Double> weights) {
|
||||
this.weights = weights;
|
||||
this.numClasses = numClasses;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -65,6 +69,7 @@ public class WeightedMode implements OutputAggregator {
|
|||
if (weights != null) {
|
||||
builder.field(WEIGHTS.getPreferredName(), weights);
|
||||
}
|
||||
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -74,11 +79,11 @@ public class WeightedMode implements OutputAggregator {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
WeightedMode that = (WeightedMode) o;
|
||||
return Objects.equals(weights, that.weights);
|
||||
return Objects.equals(weights, that.weights) && numClasses == that.numClasses;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(weights);
|
||||
return Objects.hash(weights, numClasses);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,6 @@ import org.elasticsearch.test.ESTestCase;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.function.Predicate;
|
||||
|
@ -69,17 +68,17 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
|
|||
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType))
|
||||
.limit(numberOfModels)
|
||||
.collect(Collectors.toList());
|
||||
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
|
||||
List<OutputAggregator> possibleAggregators = new ArrayList<>(Arrays.asList(new WeightedMode(weights),
|
||||
new LogisticRegression(weights)));
|
||||
if (targetType.equals(TargetType.REGRESSION)) {
|
||||
possibleAggregators.add(new WeightedSum(weights));
|
||||
}
|
||||
OutputAggregator outputAggregator = randomFrom(possibleAggregators.toArray(new OutputAggregator[0]));
|
||||
List<String> categoryLabels = null;
|
||||
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
|
||||
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
|
||||
categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));
|
||||
}
|
||||
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
|
||||
OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) :
|
||||
randomFrom(
|
||||
new WeightedMode(
|
||||
categoryLabels != null ? categoryLabels.size() : randomIntBetween(2, 10),
|
||||
weights),
|
||||
new LogisticRegression(weights));
|
||||
double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
|
||||
Stream.generate(ESTestCase::randomDouble)
|
||||
.limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())
|
||||
|
|
|
@ -30,7 +30,9 @@ import java.util.stream.Stream;
|
|||
public class WeightedModeTests extends AbstractXContentTestCase<WeightedMode> {
|
||||
|
||||
WeightedMode createTestInstance(int numberOfWeights) {
|
||||
return new WeightedMode(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
|
||||
return new WeightedMode(
|
||||
randomIntBetween(2, 10),
|
||||
Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -45,7 +47,7 @@ public class WeightedModeTests extends AbstractXContentTestCase<WeightedMode> {
|
|||
|
||||
@Override
|
||||
protected WeightedMode createTestInstance() {
|
||||
return randomBoolean() ? new WeightedMode(null) : createTestInstance(randomIntBetween(1, 100));
|
||||
return randomBoolean() ? new WeightedMode(randomIntBetween(2, 10), null) : createTestInstance(randomIntBetween(1, 100));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
|
@ -29,6 +30,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedMode.class);
|
||||
public static final ParseField NAME = new ParseField("weighted_mode");
|
||||
public static final ParseField WEIGHTS = new ParseField("weights");
|
||||
public static final ParseField NUM_CLASSES = new ParseField("num_classes");
|
||||
|
||||
private static final ConstructingObjectParser<WeightedMode, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ConstructingObjectParser<WeightedMode, Void> STRICT_PARSER = createParser(false);
|
||||
|
@ -38,7 +40,8 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
ConstructingObjectParser<WeightedMode, Void> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
a -> new WeightedMode((List<Double>)a[0]));
|
||||
a -> new WeightedMode((Integer) a[0], (List<Double>)a[1]));
|
||||
parser.declareInt(ConstructingObjectParser.constructorArg(), NUM_CLASSES);
|
||||
parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
|
||||
return parser;
|
||||
}
|
||||
|
@ -52,17 +55,23 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
}
|
||||
|
||||
private final double[] weights;
|
||||
private final int numClasses;
|
||||
|
||||
WeightedMode() {
|
||||
this((List<Double>) null);
|
||||
WeightedMode(int numClasses) {
|
||||
this(numClasses, null);
|
||||
}
|
||||
|
||||
private WeightedMode(List<Double> weights) {
|
||||
this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
|
||||
private WeightedMode(Integer numClasses, List<Double> weights) {
|
||||
this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray(), numClasses);
|
||||
}
|
||||
|
||||
public WeightedMode(double[] weights) {
|
||||
public WeightedMode(double[] weights, Integer numClasses) {
|
||||
this.weights = weights;
|
||||
this.numClasses = ExceptionsHelper.requireNonNull(numClasses, NUM_CLASSES);
|
||||
if (this.numClasses <= 1) {
|
||||
throw new IllegalArgumentException("[" + NUM_CLASSES.getPreferredName() + "] must be greater than 1.");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public WeightedMode(StreamInput in) throws IOException {
|
||||
|
@ -71,6 +80,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
} else {
|
||||
this.weights = null;
|
||||
}
|
||||
this.numClasses = in.readVInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -99,7 +109,10 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
maxVal = integerValue;
|
||||
}
|
||||
}
|
||||
List<Double> frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY));
|
||||
if (maxVal >= numClasses) {
|
||||
throw new IllegalArgumentException("values contain entries larger than expected max of [" + (numClasses - 1) + "]");
|
||||
}
|
||||
List<Double> frequencies = new ArrayList<>(Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY));
|
||||
for (int i = 0; i < freqArray.size(); i++) {
|
||||
Double weight = weights == null ? 1.0 : weights[i];
|
||||
Integer value = freqArray.get(i);
|
||||
|
@ -133,7 +146,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
|
||||
@Override
|
||||
public boolean compatibleWith(TargetType targetType) {
|
||||
return true;
|
||||
return targetType.equals(TargetType.CLASSIFICATION);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -147,6 +160,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
if (weights != null) {
|
||||
out.writeDoubleArray(weights);
|
||||
}
|
||||
out.writeVInt(numClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -155,6 +169,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
if (weights != null) {
|
||||
builder.field(WEIGHTS.getPreferredName(), weights);
|
||||
}
|
||||
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -164,12 +179,12 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
WeightedMode that = (WeightedMode) o;
|
||||
return Arrays.equals(weights, that.weights);
|
||||
return Arrays.equals(weights, that.weights) && numClasses == that.numClasses;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Arrays.hashCode(weights);
|
||||
return Objects.hash(Arrays.hashCode(weights), numClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -18,6 +18,8 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
|
|||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
|
||||
|
@ -26,7 +28,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
|
|||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
@ -302,4 +306,70 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
|
|||
assertThat(test.ramBytesUsed(), greaterThan(0L));
|
||||
}
|
||||
|
||||
public void testMultiClassIrisInference() throws IOException {
|
||||
// Fairly simple, random forest classification model built to fit in our format
|
||||
// Trained on the well known Iris dataset
|
||||
String compressedDef = "H4sIAPbiMl4C/+1b246bMBD9lVWet8jjG3b/oN9QVYgmToLEkghIL6r23wukl90" +
|
||||
"YxRMGlt2WPKwEC/gYe2bOnBl+rOoyzQq3SR4OG5ev3t/9WLmicg+fc9cd1Gm5c3VSfz+2x6t1nlZVts3Wa" +
|
||||
"Z0ditX93Wrr0vpUuqRIH1zVXPJxVbljmie5K3b1vr3ifPw125wPj65+9u/z8fnfn+4vh0jy9LPLzw/+UGb" +
|
||||
"Vu8rVhyptb+wOv7iyytaH/FD+PZWVu6xo7u8e92x+3XOaSZVurtm1QydVXZ7W7XPPcIoGWpIVG/etOWbNR" +
|
||||
"Ru3zqp28r+B5bVrH5a7bZ2s91m+aU5Cc6LMdvu/Z3gL55hndfILdnNOtGPuS1ftD901LDKs+wFYziy3j/d" +
|
||||
"3FwjgKoJ0m3xJ81N7kvn3cix64aEH1gOfX8CXkVEtemFAahvz2IcgsBCkB0GhEMTKH1Ri3xn49yosYO0Bj" +
|
||||
"hErDpGy3Y9JLbjSRvoQNAF+jIVvPPi2Bz67gK8iK1v0ptmsWoHoWXFDQG+x9/IeQ8Hbqm+swBGT15dr1wM" +
|
||||
"CKDNA2yv0GKxE7b4+cwFBWDKQ+BlfDSgsat43tH94xD49diMtoeEVhgaN2mi6iwzMKqFjKUDPEBqCrmq6O" +
|
||||
"HHd0PViMreajEEFJxlaccAi4B4CgdhzHBHdOcFqCSYTI14g2WS2z0007DfAe4Hy7DdkrI2I+9yGIhitJhh" +
|
||||
"tTBjXYN+axcX1Ab7Oom2P+RgAtffDLj/A0a5vfkAbL/jWCwJHj9jT3afMzSQtQJYEhR6ibQ984+McsYQqg" +
|
||||
"m4baTBKMB6LHhDo/Aj8BInDcI6q0ePG/rgMx+57hkXnU+AnVGBxCWH3zq3ijclwI/tW3lC2jSVsWM4oN1O" +
|
||||
"SIc4XkjRGXjGEosylOUkUQ7AhhkBgSXYc1YvAksw4PG1kGWsAT5tOxbruOKbTnwIkSYxD1MbXsWAIUwMKz" +
|
||||
"eGUeDUbRwI9Fkek5CiwqAM3Bz6NUgdUt+vBslhIo8UM6kDQac4kDiicpHfe+FwY2SQI5q3oadvnoQ3hMHE" +
|
||||
"pCaHUgkqoVcRCG5aiKzCUCN03cUtJ4ikJxZTVlcWvDvarL626DiiVLH71pf0qG1y9H7mEPSQBNoTtQpFba" +
|
||||
"NzfDFfXSNJqPFJBkFb/1iiNLxhSAW3u4Ns7qHHi+i1F9fmyj1vV0sDIZonP0wh+waxjLr1vOPcmxORe7n3" +
|
||||
"pKOKIhVp9Rtb4+Owa3xCX/TpFPnrig6nKTNisNl8aNEKQRfQITh9kG/NhTzcvpwRZoARZvkh8S6h7Oz1zI" +
|
||||
"atZeuYWk5nvC4TJ2aFFJXBCTkcO9UuQQ0qb3FXdx4xTPH6dBeApP0CQ43QejN8kd7l64jI1krMVgJfPEf7" +
|
||||
"h3uq3o/K/ztZqP1QKFagz/G+t1XxwjeIFuqkRbXoTdlOTGnwCIoKZ6ku1AbrBoN6oCdX56w3UEOO0y2B9g" +
|
||||
"aLbAYWcAdpeweKa2IfIT2jz5QzXxD6AoP+DrdXtxeluV7pdWrvkcKqPp7rjS19d+wp/fff/5Ez3FPjzFNy" +
|
||||
"fdpTi9JB0sDp2JR7b309mn5HuPkEAAA==";
|
||||
|
||||
TrainedModelDefinition definition = InferenceToXContentCompressor.inflate(compressedDef,
|
||||
parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
|
||||
xContentRegistry());
|
||||
|
||||
Map<String, Object> fields = new HashMap<String, Object>(){{
|
||||
put("sepal_length", 5.1);
|
||||
put("sepal_width", 3.5);
|
||||
put("petal_length", 1.4);
|
||||
put("petal_width", 0.2);
|
||||
}};
|
||||
|
||||
assertThat(
|
||||
((ClassificationInferenceResults)definition.getTrainedModel()
|
||||
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
|
||||
.getClassificationLabel(),
|
||||
equalTo("Iris-setosa"));
|
||||
|
||||
fields = new HashMap<String, Object>(){{
|
||||
put("sepal_length", 7.0);
|
||||
put("sepal_width", 3.2);
|
||||
put("petal_length", 4.7);
|
||||
put("petal_width", 1.4);
|
||||
}};
|
||||
assertThat(
|
||||
((ClassificationInferenceResults)definition.getTrainedModel()
|
||||
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
|
||||
.getClassificationLabel(),
|
||||
equalTo("Iris-versicolor"));
|
||||
|
||||
fields = new HashMap<String, Object>(){{
|
||||
put("sepal_length", 6.5);
|
||||
put("sepal_width", 3.0);
|
||||
put("petal_length", 5.2);
|
||||
put("petal_width", 2.0);
|
||||
}};
|
||||
assertThat(
|
||||
((ClassificationInferenceResults)definition.getTrainedModel()
|
||||
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
|
||||
.getClassificationLabel(),
|
||||
equalTo("Iris-virginica"));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -72,14 +72,19 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
|
|||
double[] weights = randomBoolean() ?
|
||||
null :
|
||||
Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).mapToDouble(Double::valueOf).toArray();
|
||||
OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights),
|
||||
new WeightedSum(weights),
|
||||
new LogisticRegression(weights));
|
||||
TargetType targetType = randomFrom(TargetType.values());
|
||||
List<String> categoryLabels = null;
|
||||
if (randomBoolean() && targetType == TargetType.CLASSIFICATION) {
|
||||
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
|
||||
categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));
|
||||
}
|
||||
|
||||
OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) :
|
||||
randomFrom(
|
||||
new WeightedMode(
|
||||
weights,
|
||||
categoryLabels != null ? categoryLabels.size() : randomIntBetween(2, 10)),
|
||||
new LogisticRegression(weights));
|
||||
|
||||
double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
|
||||
Stream.generate(ESTestCase::randomDouble)
|
||||
.limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())
|
||||
|
@ -122,7 +127,7 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
|
|||
for (int i = 0; i < numberOfModels + 2; i++) {
|
||||
weights[i] = randomDouble();
|
||||
}
|
||||
OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
|
||||
OutputAggregator outputAggregator = new WeightedSum(weights);
|
||||
|
||||
List<TrainedModel> models = new ArrayList<>(numberOfModels);
|
||||
for (int i = 0; i < numberOfModels; i++) {
|
||||
|
@ -252,7 +257,7 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
|
|||
.setTargetType(TargetType.CLASSIFICATION)
|
||||
.setFeatureNames(featureNames)
|
||||
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
|
||||
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
|
||||
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 2))
|
||||
.setClassificationWeights(Arrays.asList(0.7, 0.3))
|
||||
.build();
|
||||
|
||||
|
@ -350,7 +355,7 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
|
|||
.setTargetType(TargetType.CLASSIFICATION)
|
||||
.setFeatureNames(featureNames)
|
||||
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
|
||||
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
|
||||
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 2))
|
||||
.build();
|
||||
|
||||
List<Double> featureVector = Arrays.asList(0.4, 0.0);
|
||||
|
@ -376,6 +381,77 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
|
|||
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
|
||||
}
|
||||
|
||||
public void testMultiClassClassificationInference() {
|
||||
List<String> featureNames = Arrays.asList("foo", "bar");
|
||||
Tree tree1 = Tree.builder()
|
||||
.setFeatureNames(featureNames)
|
||||
.setRoot(TreeNode.builder(0)
|
||||
.setLeftChild(1)
|
||||
.setRightChild(2)
|
||||
.setSplitFeature(0)
|
||||
.setThreshold(0.5))
|
||||
.addNode(TreeNode.builder(1).setLeafValue(2.0))
|
||||
.addNode(TreeNode.builder(2)
|
||||
.setThreshold(0.8)
|
||||
.setSplitFeature(1)
|
||||
.setLeftChild(3)
|
||||
.setRightChild(4))
|
||||
.addNode(TreeNode.builder(3).setLeafValue(0.0))
|
||||
.addNode(TreeNode.builder(4).setLeafValue(1.0))
|
||||
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
|
||||
.build();
|
||||
Tree tree2 = Tree.builder()
|
||||
.setFeatureNames(featureNames)
|
||||
.setRoot(TreeNode.builder(0)
|
||||
.setLeftChild(1)
|
||||
.setRightChild(2)
|
||||
.setSplitFeature(1)
|
||||
.setThreshold(0.5))
|
||||
.addNode(TreeNode.builder(1).setLeafValue(2.0))
|
||||
.addNode(TreeNode.builder(2).setLeafValue(1.0))
|
||||
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
|
||||
.build();
|
||||
Tree tree3 = Tree.builder()
|
||||
.setFeatureNames(featureNames)
|
||||
.setRoot(TreeNode.builder(0)
|
||||
.setLeftChild(1)
|
||||
.setRightChild(2)
|
||||
.setSplitFeature(1)
|
||||
.setThreshold(2.0))
|
||||
.addNode(TreeNode.builder(1).setLeafValue(1.0))
|
||||
.addNode(TreeNode.builder(2).setLeafValue(0.0))
|
||||
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
|
||||
.build();
|
||||
Ensemble ensemble = Ensemble.builder()
|
||||
.setTargetType(TargetType.CLASSIFICATION)
|
||||
.setFeatureNames(featureNames)
|
||||
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
|
||||
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 3))
|
||||
.build();
|
||||
|
||||
List<Double> featureVector = Arrays.asList(0.4, 0.0);
|
||||
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
|
||||
assertThat(2.0,
|
||||
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
|
||||
|
||||
featureVector = Arrays.asList(2.0, 0.7);
|
||||
featureMap = zipObjMap(featureNames, featureVector);
|
||||
assertThat(1.0,
|
||||
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
|
||||
|
||||
featureVector = Arrays.asList(0.0, 1.0);
|
||||
featureMap = zipObjMap(featureNames, featureVector);
|
||||
assertThat(1.0,
|
||||
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
|
||||
|
||||
featureMap = new HashMap<String, Object>(2) {{
|
||||
put("foo", 0.6);
|
||||
put("bar", null);
|
||||
}};
|
||||
assertThat(1.0,
|
||||
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
|
||||
}
|
||||
|
||||
public void testRegressionInference() {
|
||||
List<String> featureNames = Arrays.asList("foo", "bar");
|
||||
Tree tree1 = Tree.builder()
|
||||
|
|
|
@ -16,6 +16,7 @@ import java.util.List;
|
|||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
|
||||
|
@ -23,7 +24,7 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
|
|||
@Override
|
||||
WeightedMode createTestInstance(int numberOfWeights) {
|
||||
double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray();
|
||||
return new WeightedMode(weights);
|
||||
return new WeightedMode(weights, randomIntBetween(2, 10));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -33,7 +34,7 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
|
|||
|
||||
@Override
|
||||
protected WeightedMode createTestInstance() {
|
||||
return randomBoolean() ? new WeightedMode() : createTestInstance(randomIntBetween(1, 100));
|
||||
return randomBoolean() ? new WeightedMode(randomIntBetween(2, 10)) : createTestInstance(randomIntBetween(1, 100));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -45,21 +46,33 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
|
|||
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
|
||||
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
|
||||
|
||||
WeightedMode weightedMode = new WeightedMode(ones);
|
||||
WeightedMode weightedMode = new WeightedMode(ones, 6);
|
||||
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
|
||||
|
||||
double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};
|
||||
|
||||
weightedMode = new WeightedMode(variedWeights);
|
||||
weightedMode = new WeightedMode(variedWeights, 6);
|
||||
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0));
|
||||
|
||||
weightedMode = new WeightedMode();
|
||||
weightedMode = new WeightedMode(6);
|
||||
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
|
||||
|
||||
values = Arrays.asList(1.0, 1.0, 1.0, 1.0, 2.0);
|
||||
weightedMode = new WeightedMode(6);
|
||||
List<Double> processedValues = weightedMode.processValues(values);
|
||||
assertThat(processedValues.size(), equalTo(6));
|
||||
assertThat(processedValues.get(0), equalTo(0.0));
|
||||
assertThat(processedValues.get(1), closeTo(0.95257412, 0.00001));
|
||||
assertThat(processedValues.get(2), closeTo((1.0 - 0.95257412), 0.00001));
|
||||
assertThat(processedValues.get(3), equalTo(0.0));
|
||||
assertThat(processedValues.get(4), equalTo(0.0));
|
||||
assertThat(processedValues.get(5), equalTo(0.0));
|
||||
assertThat(weightedMode.aggregate(processedValues), equalTo(1.0));
|
||||
}
|
||||
|
||||
public void testCompatibleWith() {
|
||||
WeightedMode weightedMode = createTestInstance();
|
||||
assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true));
|
||||
assertThat(weightedMode.compatibleWith(TargetType.REGRESSION), is(true));
|
||||
assertThat(weightedMode.compatibleWith(TargetType.REGRESSION), is(false));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -424,6 +424,7 @@ public class InferenceIngestIT extends ESRestTestCase {
|
|||
" ],\n" +
|
||||
" \"aggregate_output\": {\n" +
|
||||
" \"weighted_mode\": {\n" +
|
||||
" \"num_classes\": \"2\",\n" +
|
||||
" \"weights\": [\n" +
|
||||
" 0.5,\n" +
|
||||
" 0.5\n" +
|
||||
|
|
|
@ -188,7 +188,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
.setTargetType(TargetType.CLASSIFICATION)
|
||||
.setFeatureNames(featureNames)
|
||||
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
|
||||
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
|
||||
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 2))
|
||||
.build();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue