diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java index 37d589badd1..2ba08effb7b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -34,13 +34,15 @@ public class WeightedMode implements OutputAggregator { public static final String NAME = "weighted_mode"; public static final ParseField WEIGHTS = new ParseField("weights"); + public static final ParseField NUM_CLASSES = new ParseField("num_classes"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, true, - a -> new WeightedMode((List)a[0])); + a -> new WeightedMode((Integer)a[0], (List)a[1])); static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_CLASSES); PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); } @@ -49,9 +51,11 @@ public class WeightedMode implements OutputAggregator { } private final List weights; + private final int numClasses; - public WeightedMode(List weights) { + public WeightedMode(int numClasses, List weights) { this.weights = weights; + this.numClasses = numClasses; } @Override @@ -65,6 +69,7 @@ public class WeightedMode implements OutputAggregator { if (weights != null) { builder.field(WEIGHTS.getPreferredName(), weights); } + builder.field(NUM_CLASSES.getPreferredName(), numClasses); builder.endObject(); return builder; } @@ -74,11 +79,11 @@ public class WeightedMode implements OutputAggregator { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; WeightedMode that = (WeightedMode) o; - return Objects.equals(weights, that.weights); + return Objects.equals(weights, that.weights) && numClasses == that.numClasses; } @Override public int hashCode() { - return Objects.hash(weights); + return Objects.hash(weights, numClasses); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 6347454f41e..a0470ea2500 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -31,7 +31,6 @@ import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.function.Predicate; @@ -69,17 +68,17 @@ public class EnsembleTests extends AbstractXContentTestCase { List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType)) .limit(numberOfModels) .collect(Collectors.toList()); - List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); - List possibleAggregators = new ArrayList<>(Arrays.asList(new WeightedMode(weights), - new LogisticRegression(weights))); - if (targetType.equals(TargetType.REGRESSION)) { - possibleAggregators.add(new WeightedSum(weights)); - } - OutputAggregator outputAggregator = randomFrom(possibleAggregators.toArray(new OutputAggregator[0])); List categoryLabels = null; if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) { - categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); + categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10)); } + List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); + OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) : + randomFrom( + new WeightedMode( + categoryLabels != null ? categoryLabels.size() : randomIntBetween(2, 10), + weights), + new LogisticRegression(weights)); double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ? Stream.generate(ESTestCase::randomDouble) .limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size()) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java index a04652c1d38..65f89f614c2 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -30,7 +30,9 @@ import java.util.stream.Stream; public class WeightedModeTests extends AbstractXContentTestCase { WeightedMode createTestInstance(int numberOfWeights) { - return new WeightedMode(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList())); + return new WeightedMode( + randomIntBetween(2, 10), + Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList())); } @Override @@ -45,7 +47,7 @@ public class WeightedModeTests extends AbstractXContentTestCase { @Override protected WeightedMode createTestInstance() { - return randomBoolean() ? new WeightedMode(null) : createTestInstance(randomIntBetween(1, 100)); + return randomBoolean() ? new WeightedMode(randomIntBetween(2, 10), null) : createTestInstance(randomIntBetween(1, 100)); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 73689d16b1c..df2e33e4e6f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.ArrayList; @@ -29,6 +30,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedMode.class); public static final ParseField NAME = new ParseField("weighted_mode"); public static final ParseField WEIGHTS = new ParseField("weights"); + public static final ParseField NUM_CLASSES = new ParseField("num_classes"); private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); @@ -38,7 +40,8 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), lenient, - a -> new WeightedMode((List)a[0])); + a -> new WeightedMode((Integer) a[0], (List)a[1])); + parser.declareInt(ConstructingObjectParser.constructorArg(), NUM_CLASSES); parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); return parser; } @@ -52,17 +55,23 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa } private final double[] weights; + private final int numClasses; - WeightedMode() { - this((List) null); + WeightedMode(int numClasses) { + this(numClasses, null); } - private WeightedMode(List weights) { - this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray()); + private WeightedMode(Integer numClasses, List weights) { + this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray(), numClasses); } - public WeightedMode(double[] weights) { + public WeightedMode(double[] weights, Integer numClasses) { this.weights = weights; + this.numClasses = ExceptionsHelper.requireNonNull(numClasses, NUM_CLASSES); + if (this.numClasses <= 1) { + throw new IllegalArgumentException("[" + NUM_CLASSES.getPreferredName() + "] must be greater than 1."); + } + } public WeightedMode(StreamInput in) throws IOException { @@ -71,6 +80,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa } else { this.weights = null; } + this.numClasses = in.readVInt(); } @Override @@ -99,7 +109,10 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa maxVal = integerValue; } } - List frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY)); + if (maxVal >= numClasses) { + throw new IllegalArgumentException("values contain entries larger than expected max of [" + (numClasses - 1) + "]"); + } + List frequencies = new ArrayList<>(Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY)); for (int i = 0; i < freqArray.size(); i++) { Double weight = weights == null ? 1.0 : weights[i]; Integer value = freqArray.get(i); @@ -133,7 +146,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa @Override public boolean compatibleWith(TargetType targetType) { - return true; + return targetType.equals(TargetType.CLASSIFICATION); } @Override @@ -147,6 +160,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa if (weights != null) { out.writeDoubleArray(weights); } + out.writeVInt(numClasses); } @Override @@ -155,6 +169,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa if (weights != null) { builder.field(WEIGHTS.getPreferredName(), weights); } + builder.field(NUM_CLASSES.getPreferredName(), numClasses); builder.endObject(); return builder; } @@ -164,12 +179,12 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; WeightedMode that = (WeightedMode) o; - return Arrays.equals(weights, that.weights); + return Arrays.equals(weights, that.weights) && numClasses == that.numClasses; } @Override public int hashCode() { - return Arrays.hashCode(weights); + return Objects.hash(Arrays.hashCode(weights), numClasses); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java index e8547bf8534..dbebcd91921 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java @@ -18,6 +18,8 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; @@ -26,7 +28,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -302,4 +306,70 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase TrainedModelDefinition.fromXContent(parser, true).build(), + xContentRegistry()); + + Map fields = new HashMap(){{ + put("sepal_length", 5.1); + put("sepal_width", 3.5); + put("petal_length", 1.4); + put("petal_width", 0.2); + }}; + + assertThat( + ((ClassificationInferenceResults)definition.getTrainedModel() + .infer(fields, ClassificationConfig.EMPTY_PARAMS)) + .getClassificationLabel(), + equalTo("Iris-setosa")); + + fields = new HashMap(){{ + put("sepal_length", 7.0); + put("sepal_width", 3.2); + put("petal_length", 4.7); + put("petal_width", 1.4); + }}; + assertThat( + ((ClassificationInferenceResults)definition.getTrainedModel() + .infer(fields, ClassificationConfig.EMPTY_PARAMS)) + .getClassificationLabel(), + equalTo("Iris-versicolor")); + + fields = new HashMap(){{ + put("sepal_length", 6.5); + put("sepal_width", 3.0); + put("petal_length", 5.2); + put("petal_width", 2.0); + }}; + assertThat( + ((ClassificationInferenceResults)definition.getTrainedModel() + .infer(fields, ClassificationConfig.EMPTY_PARAMS)) + .getClassificationLabel(), + equalTo("Iris-virginica")); + } + } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index d1f148dcab6..195ab3cf2b9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -72,14 +72,19 @@ public class EnsembleTests extends AbstractSerializingTestCase { double[] weights = randomBoolean() ? null : Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).mapToDouble(Double::valueOf).toArray(); - OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), - new WeightedSum(weights), - new LogisticRegression(weights)); TargetType targetType = randomFrom(TargetType.values()); List categoryLabels = null; if (randomBoolean() && targetType == TargetType.CLASSIFICATION) { - categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); + categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10)); } + + OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) : + randomFrom( + new WeightedMode( + weights, + categoryLabels != null ? categoryLabels.size() : randomIntBetween(2, 10)), + new LogisticRegression(weights)); + double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ? Stream.generate(ESTestCase::randomDouble) .limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size()) @@ -122,7 +127,7 @@ public class EnsembleTests extends AbstractSerializingTestCase { for (int i = 0; i < numberOfModels + 2; i++) { weights[i] = randomDouble(); } - OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); + OutputAggregator outputAggregator = new WeightedSum(weights); List models = new ArrayList<>(numberOfModels); for (int i = 0; i < numberOfModels; i++) { @@ -252,7 +257,7 @@ public class EnsembleTests extends AbstractSerializingTestCase { .setTargetType(TargetType.CLASSIFICATION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0})) + .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 2)) .setClassificationWeights(Arrays.asList(0.7, 0.3)) .build(); @@ -350,7 +355,7 @@ public class EnsembleTests extends AbstractSerializingTestCase { .setTargetType(TargetType.CLASSIFICATION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0})) + .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 2)) .build(); List featureVector = Arrays.asList(0.4, 0.0); @@ -376,6 +381,77 @@ public class EnsembleTests extends AbstractSerializingTestCase { closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); } + public void testMultiClassClassificationInference() { + List featureNames = Arrays.asList("foo", "bar"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(2.0)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.0)) + .addNode(TreeNode.builder(4).setLeafValue(1.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) + .build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(2.0)) + .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(2.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) + .build(); + Ensemble ensemble = Ensemble.builder() + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 3)) + .build(); + + List featureVector = Arrays.asList(0.4, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + assertThat(2.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + + featureVector = Arrays.asList(2.0, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + + featureVector = Arrays.asList(0.0, 1.0); + featureMap = zipObjMap(featureNames, featureVector); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + + featureMap = new HashMap(2) {{ + put("foo", 0.6); + put("bar", null); + }}; + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); + } + public void testRegressionInference() { List featureNames = Arrays.asList("foo", "bar"); Tree tree1 = Tree.builder() diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java index 4421a8fbb93..6f0496772be 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -16,6 +16,7 @@ import java.util.List; import java.util.stream.Stream; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; public class WeightedModeTests extends WeightedAggregatorTests { @@ -23,7 +24,7 @@ public class WeightedModeTests extends WeightedAggregatorTests { @Override WeightedMode createTestInstance(int numberOfWeights) { double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray(); - return new WeightedMode(weights); + return new WeightedMode(weights, randomIntBetween(2, 10)); } @Override @@ -33,7 +34,7 @@ public class WeightedModeTests extends WeightedAggregatorTests { @Override protected WeightedMode createTestInstance() { - return randomBoolean() ? new WeightedMode() : createTestInstance(randomIntBetween(1, 100)); + return randomBoolean() ? new WeightedMode(randomIntBetween(2, 10)) : createTestInstance(randomIntBetween(1, 100)); } @Override @@ -45,21 +46,33 @@ public class WeightedModeTests extends WeightedAggregatorTests { double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0}; List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); - WeightedMode weightedMode = new WeightedMode(ones); + WeightedMode weightedMode = new WeightedMode(ones, 6); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0}; - weightedMode = new WeightedMode(variedWeights); + weightedMode = new WeightedMode(variedWeights, 6); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0)); - weightedMode = new WeightedMode(); + weightedMode = new WeightedMode(6); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); + + values = Arrays.asList(1.0, 1.0, 1.0, 1.0, 2.0); + weightedMode = new WeightedMode(6); + List processedValues = weightedMode.processValues(values); + assertThat(processedValues.size(), equalTo(6)); + assertThat(processedValues.get(0), equalTo(0.0)); + assertThat(processedValues.get(1), closeTo(0.95257412, 0.00001)); + assertThat(processedValues.get(2), closeTo((1.0 - 0.95257412), 0.00001)); + assertThat(processedValues.get(3), equalTo(0.0)); + assertThat(processedValues.get(4), equalTo(0.0)); + assertThat(processedValues.get(5), equalTo(0.0)); + assertThat(weightedMode.aggregate(processedValues), equalTo(1.0)); } public void testCompatibleWith() { WeightedMode weightedMode = createTestInstance(); assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true)); - assertThat(weightedMode.compatibleWith(TargetType.REGRESSION), is(true)); + assertThat(weightedMode.compatibleWith(TargetType.REGRESSION), is(false)); } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index efe2bb2c95f..d0be015d2fb 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -424,6 +424,7 @@ public class InferenceIngestIT extends ESRestTestCase { " ],\n" + " \"aggregate_output\": {\n" + " \"weighted_mode\": {\n" + + " \"num_classes\": \"2\",\n" + " \"weights\": [\n" + " 0.5,\n" + " 0.5\n" + diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index ff0741d5dc4..12d26d85c3d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -188,7 +188,7 @@ public class LocalModelTests extends ESTestCase { .setTargetType(TargetType.CLASSIFICATION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0})) + .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 2)) .build(); }