diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java index 29bd8e4579f..0574f19f823 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java @@ -41,7 +41,6 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHe import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.FEATURE_NAMES; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TARGET_TYPE; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS; @@ -53,14 +52,12 @@ public class EnsembleInferenceModel implements InferenceModel { private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "ensemble_inference_model", true, - a -> new EnsembleInferenceModel((List)a[0], - (List)a[1], - (OutputAggregator)a[2], - TargetType.fromString((String)a[3]), - (List)a[4], - (List)a[5])); + a -> new EnsembleInferenceModel((List)a[0], + (OutputAggregator)a[1], + TargetType.fromString((String)a[2]), + (List)a[3], + (List)a[4])); static { - PARSER.declareStringArray(constructorArg(), FEATURE_NAMES); PARSER.declareNamedObjects(constructorArg(), (p, c, n) -> p.namedObject(InferenceModel.class, n, null), (ensembleBuilder) -> {}, @@ -77,20 +74,19 @@ public class EnsembleInferenceModel implements InferenceModel { return PARSER.apply(parser, null); } - private String[] featureNames; + private String[] featureNames = new String[0]; private final List models; private final OutputAggregator outputAggregator; private final TargetType targetType; private final List classificationLabels; private final double[] classificationWeights; + private volatile boolean preparedForInference = false; - EnsembleInferenceModel(List featureNames, - List models, - OutputAggregator outputAggregator, - TargetType targetType, - List classificationLabels, - List classificationWeights) { - this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]); + private EnsembleInferenceModel(List models, + OutputAggregator outputAggregator, + TargetType targetType, + List classificationLabels, + List classificationWeights) { this.models = ExceptionsHelper.requireNonNull(models, TRAINED_MODELS); this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT); this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); @@ -135,6 +131,10 @@ public class EnsembleInferenceModel implements InferenceModel { throw ExceptionsHelper.badRequestException( "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); } + if (preparedForInference == false) { + throw ExceptionsHelper.serverError("model is not prepared for inference"); + } + assert featureNames != null && featureNames.length > 0; double[][] inferenceResults = new double[this.models.size()][]; double[][] featureInfluence = new double[features.length][]; int i = 0; @@ -232,6 +232,10 @@ public class EnsembleInferenceModel implements InferenceModel { @Override public void rewriteFeatureIndices(Map newFeatureIndexMapping) { + if (preparedForInference) { + return; + } + preparedForInference = true; if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) { Set referencedFeatures = subModelFeatures(); int newFeatureIndex = 0; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index 68bbbf84a20..f9a0a81700f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -80,6 +80,7 @@ public class TreeInferenceModel implements InferenceModel { private final double highOrderCategory; private final int maxDepth; private final int leafSize; + private volatile boolean preparedForInference = false; TreeInferenceModel(List featureNames, List nodes, TargetType targetType, List classificationLabels) { this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]); @@ -136,6 +137,9 @@ public class TreeInferenceModel implements InferenceModel { throw ExceptionsHelper.badRequestException( "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); } + if (preparedForInference == false) { + throw ExceptionsHelper.serverError("model is not prepared for inference"); + } double[][] featureImportance = config.requestingImportance() ? featureImportance(features) : new double[0][]; @@ -288,6 +292,10 @@ public class TreeInferenceModel implements InferenceModel { @Override public void rewriteFeatureIndices(Map newFeatureIndexMapping) { + if (preparedForInference) { + return; + } + preparedForInference = true; if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) { return; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 9b1c2088aee..6717ef32c20 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -66,7 +66,10 @@ public class EnsembleTests extends AbstractSerializingTestCase { public static Ensemble createRandom(TargetType targetType, List featureNames) { int numberOfModels = randomIntBetween(1, 10); - List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6)) + List treeFeatureNames = featureNames.isEmpty() ? + Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()) : + featureNames; + List models = Stream.generate(() -> TreeTests.buildRandomTree(treeFeatureNames, 6)) .limit(numberOfModels) .collect(Collectors.toList()); double[] weights = randomBoolean() ? diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java index f59fe338fd7..fb9372c2ce4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; @@ -15,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; @@ -30,19 +32,26 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel; import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; public class EnsembleInferenceModelTests extends ESTestCase { + private static final int NUMBER_OF_TEST_RUNS = 20; private final double eps = 1.0E-8; public static EnsembleInferenceModel serializeFromTrainedModel(Ensemble ensemble) throws IOException { NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - return deserializeFromTrainedModel(ensemble, + EnsembleInferenceModel model = deserializeFromTrainedModel(ensemble, registry, EnsembleInferenceModel::fromXContent); + model.rewriteFeatureIndices(Collections.emptyMap()); + return model; } @Override @@ -52,6 +61,27 @@ public class EnsembleInferenceModelTests extends ESTestCase { return new NamedXContentRegistry(namedXContent); } + public void testSerializationFromEnsemble() throws Exception { + for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) { + int numberOfFeatures = randomIntBetween(1, 10); + Ensemble ensemble = EnsembleTests.createRandom(randomFrom(TargetType.values()), + randomBoolean() ? + Collections.emptyList() : + Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList())); + assertThat(serializeFromTrainedModel(ensemble), is(not(nullValue()))); + } + } + + public void testInferenceWithoutPreparing() throws IOException { + Ensemble ensemble = EnsembleTests.createRandom(TargetType.REGRESSION, + Stream.generate(() -> randomAlphaOfLength(10)).limit(4).collect(Collectors.toList())); + + EnsembleInferenceModel model = deserializeFromTrainedModel(ensemble, + xContentRegistry(), + EnsembleInferenceModel::fromXContent); + expectThrows(ElasticsearchException.class, () -> model.infer(Collections.emptyMap(), RegressionConfig.EMPTY_PARAMS, null)); + } + public void testClassificationProbability() throws IOException { List featureNames = Arrays.asList("foo", "bar"); Tree tree1 = Tree.builder() @@ -100,6 +130,7 @@ public class EnsembleInferenceModelTests extends ESTestCase { EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject, xContentRegistry(), EnsembleInferenceModel::fromXContent); + ensemble.rewriteFeatureIndices(Collections.emptyMap()); List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); @@ -205,6 +236,7 @@ public class EnsembleInferenceModelTests extends ESTestCase { EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject, xContentRegistry(), EnsembleInferenceModel::fromXContent); + ensemble.rewriteFeatureIndices(Collections.emptyMap()); List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); @@ -284,6 +316,7 @@ public class EnsembleInferenceModelTests extends ESTestCase { EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject, xContentRegistry(), EnsembleInferenceModel::fromXContent); + ensemble.rewriteFeatureIndices(Collections.emptyMap()); List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); @@ -349,6 +382,7 @@ public class EnsembleInferenceModelTests extends ESTestCase { EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject, xContentRegistry(), EnsembleInferenceModel::fromXContent); + ensemble.rewriteFeatureIndices(Collections.emptyMap()); List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); @@ -373,6 +407,7 @@ public class EnsembleInferenceModelTests extends ESTestCase { ensemble = deserializeFromTrainedModel(ensembleObject, xContentRegistry(), EnsembleInferenceModel::fromXContent); + ensemble.rewriteFeatureIndices(Collections.emptyMap()); featureVector = Arrays.asList(0.4, 0.0); featureMap = zipObjMap(featureNames, featureVector); @@ -466,6 +501,7 @@ public class EnsembleInferenceModelTests extends ESTestCase { EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject, xContentRegistry(), EnsembleInferenceModel::fromXContent); + ensemble.rewriteFeatureIndices(Collections.emptyMap()); double[][] featureImportance = ensemble.featureImportance(new double[]{0.0, 0.9}); assertThat(featureImportance[0][0], closeTo(-1.653200025, eps)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java index a337d7d2f98..e0d40848376 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; @@ -16,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import org.elasticsearch.xpack.core.ml.job.config.Operator; import java.io.IOException; @@ -31,16 +33,22 @@ import java.util.stream.IntStream; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; public class TreeInferenceModelTests extends ESTestCase { + private static final int NUMBER_OF_TEST_RUNS = 20; private final double eps = 1.0E-8; public static TreeInferenceModel serializeFromTrainedModel(Tree tree) throws IOException { NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - return deserializeFromTrainedModel(tree, + TreeInferenceModel model = deserializeFromTrainedModel(tree, registry, TreeInferenceModel::fromXContent); + model.rewriteFeatureIndices(Collections.emptyMap()); + return model; } @Override @@ -50,6 +58,22 @@ public class TreeInferenceModelTests extends ESTestCase { return new NamedXContentRegistry(namedXContent); } + public void testSerializationFromEnsemble() throws Exception { + for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) { + Tree tree = TreeTests.createRandom(randomFrom(TargetType.values())); + assertThat(serializeFromTrainedModel(tree), is(not(nullValue()))); + } + } + + public void testInferenceWithoutPreparing() throws IOException { + Tree tree = TreeTests.createRandom(randomFrom(TargetType.values())); + + TreeInferenceModel model = deserializeFromTrainedModel(tree, + xContentRegistry(), + TreeInferenceModel::fromXContent); + expectThrows(ElasticsearchException.class, () -> model.infer(Collections.emptyMap(), RegressionConfig.EMPTY_PARAMS, null)); + } + public void testInferWithStump() throws IOException { Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION); builder.setRoot(TreeNode.builder(0).setLeafValue(Collections.singletonList(42.0))); @@ -59,6 +83,7 @@ public class TreeInferenceModelTests extends ESTestCase { TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); + tree.rewriteFeatureIndices(Collections.emptyMap()); List featureNames = Arrays.asList("foo", "bar"); List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump @@ -82,6 +107,7 @@ public class TreeInferenceModelTests extends ESTestCase { TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); + tree.rewriteFeatureIndices(Collections.emptyMap()); // This feature vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); @@ -137,6 +163,7 @@ public class TreeInferenceModelTests extends ESTestCase { TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); + tree.rewriteFeatureIndices(Collections.emptyMap()); double eps = 0.000001; // This feature vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); @@ -211,6 +238,7 @@ public class TreeInferenceModelTests extends ESTestCase { TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); + tree.rewriteFeatureIndices(Collections.emptyMap()); double[][] featureImportance = tree.featureImportance(new double[]{0.25, 0.25}); assertThat(featureImportance[0][0], closeTo(-5.0, eps));