With this change, `EnsembleInferenceModel` no longer parses `feature_names` from the XContent. Instead, it relies on `rewriteFeatureIndices` being called ahead of time. Consequently, a fail-fast check is added: `infer` now throws if `rewriteFeatureIndices` has not been called first.
This commit is contained in:
parent
0ce102a5f4
commit
79c784932f
|
@ -41,7 +41,6 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHe
|
|||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.FEATURE_NAMES;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TARGET_TYPE;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;
|
||||
|
||||
|
@ -53,14 +52,12 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|||
private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
|
||||
"ensemble_inference_model",
|
||||
true,
|
||||
a -> new EnsembleInferenceModel((List<String>)a[0],
|
||||
(List<InferenceModel>)a[1],
|
||||
(OutputAggregator)a[2],
|
||||
TargetType.fromString((String)a[3]),
|
||||
(List<String>)a[4],
|
||||
(List<Double>)a[5]));
|
||||
a -> new EnsembleInferenceModel((List<InferenceModel>)a[0],
|
||||
(OutputAggregator)a[1],
|
||||
TargetType.fromString((String)a[2]),
|
||||
(List<String>)a[3],
|
||||
(List<Double>)a[4]));
|
||||
static {
|
||||
PARSER.declareStringArray(constructorArg(), FEATURE_NAMES);
|
||||
PARSER.declareNamedObjects(constructorArg(),
|
||||
(p, c, n) -> p.namedObject(InferenceModel.class, n, null),
|
||||
(ensembleBuilder) -> {},
|
||||
|
@ -77,20 +74,19 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private String[] featureNames;
|
||||
private String[] featureNames = new String[0];
|
||||
private final List<InferenceModel> models;
|
||||
private final OutputAggregator outputAggregator;
|
||||
private final TargetType targetType;
|
||||
private final List<String> classificationLabels;
|
||||
private final double[] classificationWeights;
|
||||
private volatile boolean preparedForInference = false;
|
||||
|
||||
EnsembleInferenceModel(List<String> featureNames,
|
||||
List<InferenceModel> models,
|
||||
OutputAggregator outputAggregator,
|
||||
TargetType targetType,
|
||||
List<String> classificationLabels,
|
||||
List<Double> classificationWeights) {
|
||||
this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
|
||||
private EnsembleInferenceModel(List<InferenceModel> models,
|
||||
OutputAggregator outputAggregator,
|
||||
TargetType targetType,
|
||||
List<String> classificationLabels,
|
||||
List<Double> classificationWeights) {
|
||||
this.models = ExceptionsHelper.requireNonNull(models, TRAINED_MODELS);
|
||||
this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
|
||||
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
|
||||
|
@ -135,6 +131,10 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|||
throw ExceptionsHelper.badRequestException(
|
||||
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
|
||||
}
|
||||
if (preparedForInference == false) {
|
||||
throw ExceptionsHelper.serverError("model is not prepared for inference");
|
||||
}
|
||||
assert featureNames != null && featureNames.length > 0;
|
||||
double[][] inferenceResults = new double[this.models.size()][];
|
||||
double[][] featureInfluence = new double[features.length][];
|
||||
int i = 0;
|
||||
|
@ -232,6 +232,10 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|||
|
||||
@Override
|
||||
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
|
||||
if (preparedForInference) {
|
||||
return;
|
||||
}
|
||||
preparedForInference = true;
|
||||
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
|
||||
Set<String> referencedFeatures = subModelFeatures();
|
||||
int newFeatureIndex = 0;
|
||||
|
|
|
@ -80,6 +80,7 @@ public class TreeInferenceModel implements InferenceModel {
|
|||
private final double highOrderCategory;
|
||||
private final int maxDepth;
|
||||
private final int leafSize;
|
||||
private volatile boolean preparedForInference = false;
|
||||
|
||||
TreeInferenceModel(List<String> featureNames, List<NodeBuilder> nodes, TargetType targetType, List<String> classificationLabels) {
|
||||
this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
|
||||
|
@ -136,6 +137,9 @@ public class TreeInferenceModel implements InferenceModel {
|
|||
throw ExceptionsHelper.badRequestException(
|
||||
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
|
||||
}
|
||||
if (preparedForInference == false) {
|
||||
throw ExceptionsHelper.serverError("model is not prepared for inference");
|
||||
}
|
||||
double[][] featureImportance = config.requestingImportance() ?
|
||||
featureImportance(features) :
|
||||
new double[0][];
|
||||
|
@ -288,6 +292,10 @@ public class TreeInferenceModel implements InferenceModel {
|
|||
|
||||
@Override
|
||||
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
|
||||
if (preparedForInference) {
|
||||
return;
|
||||
}
|
||||
preparedForInference = true;
|
||||
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -66,7 +66,10 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
|
|||
|
||||
public static Ensemble createRandom(TargetType targetType, List<String> featureNames) {
|
||||
int numberOfModels = randomIntBetween(1, 10);
|
||||
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
|
||||
List<String> treeFeatureNames = featureNames.isEmpty() ?
|
||||
Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()) :
|
||||
featureNames;
|
||||
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(treeFeatureNames, 6))
|
||||
.limit(numberOfModels)
|
||||
.collect(Collectors.toList());
|
||||
double[] weights = randomBoolean() ?
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
|
@ -15,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
|
|||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
|
||||
|
@ -30,19 +32,26 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
||||
public class EnsembleInferenceModelTests extends ESTestCase {
|
||||
|
||||
private static final int NUMBER_OF_TEST_RUNS = 20;
|
||||
private final double eps = 1.0E-8;
|
||||
|
||||
public static EnsembleInferenceModel serializeFromTrainedModel(Ensemble ensemble) throws IOException {
|
||||
NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
return deserializeFromTrainedModel(ensemble,
|
||||
EnsembleInferenceModel model = deserializeFromTrainedModel(ensemble,
|
||||
registry,
|
||||
EnsembleInferenceModel::fromXContent);
|
||||
model.rewriteFeatureIndices(Collections.emptyMap());
|
||||
return model;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -52,6 +61,27 @@ public class EnsembleInferenceModelTests extends ESTestCase {
|
|||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
/**
 * Round-trips randomly generated ensembles through {@code serializeFromTrainedModel}
 * and checks that a non-null inference model comes back.
 *
 * Each of the NUMBER_OF_TEST_RUNS iterations picks a random target type and, at random,
 * either an empty feature-name list or a list of 1 to 10 random 10-character names, so
 * both the "no ensemble-level feature names" and "explicit feature names" inputs are covered.
 */
public void testSerializationFromEnsemble() throws Exception {
|
||||
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
|
||||
int numberOfFeatures = randomIntBetween(1, 10);
|
||||
Ensemble ensemble = EnsembleTests.createRandom(randomFrom(TargetType.values()),
|
||||
randomBoolean() ?
|
||||
Collections.emptyList() :
|
||||
Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList()));
|
||||
assertThat(serializeFromTrainedModel(ensemble), is(not(nullValue())));
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Checks that calling {@code infer} on a freshly deserialized ensemble fails.
 *
 * The model is built from a random regression ensemble over four random feature names and is
 * deliberately NOT passed through {@code rewriteFeatureIndices} before {@code infer} is invoked
 * with an empty field map and the regression config, so an ElasticsearchException is expected.
 */
public void testInferenceWithoutPreparing() throws IOException {
|
||||
Ensemble ensemble = EnsembleTests.createRandom(TargetType.REGRESSION,
|
||||
Stream.generate(() -> randomAlphaOfLength(10)).limit(4).collect(Collectors.toList()));
|
||||
|
||||
EnsembleInferenceModel model = deserializeFromTrainedModel(ensemble,
|
||||
xContentRegistry(),
|
||||
EnsembleInferenceModel::fromXContent);
|
||||
expectThrows(ElasticsearchException.class, () -> model.infer(Collections.emptyMap(), RegressionConfig.EMPTY_PARAMS, null));
|
||||
}
|
||||
|
||||
public void testClassificationProbability() throws IOException {
|
||||
List<String> featureNames = Arrays.asList("foo", "bar");
|
||||
Tree tree1 = Tree.builder()
|
||||
|
@ -100,6 +130,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
|
|||
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
|
||||
xContentRegistry(),
|
||||
EnsembleInferenceModel::fromXContent);
|
||||
ensemble.rewriteFeatureIndices(Collections.emptyMap());
|
||||
|
||||
List<Double> featureVector = Arrays.asList(0.4, 0.0);
|
||||
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
|
||||
|
@ -205,6 +236,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
|
|||
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
|
||||
xContentRegistry(),
|
||||
EnsembleInferenceModel::fromXContent);
|
||||
ensemble.rewriteFeatureIndices(Collections.emptyMap());
|
||||
|
||||
List<Double> featureVector = Arrays.asList(0.4, 0.0);
|
||||
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
|
||||
|
@ -284,6 +316,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
|
|||
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
|
||||
xContentRegistry(),
|
||||
EnsembleInferenceModel::fromXContent);
|
||||
ensemble.rewriteFeatureIndices(Collections.emptyMap());
|
||||
|
||||
List<Double> featureVector = Arrays.asList(0.4, 0.0);
|
||||
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
|
||||
|
@ -349,6 +382,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
|
|||
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
|
||||
xContentRegistry(),
|
||||
EnsembleInferenceModel::fromXContent);
|
||||
ensemble.rewriteFeatureIndices(Collections.emptyMap());
|
||||
|
||||
List<Double> featureVector = Arrays.asList(0.4, 0.0);
|
||||
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
|
||||
|
@ -373,6 +407,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
|
|||
ensemble = deserializeFromTrainedModel(ensembleObject,
|
||||
xContentRegistry(),
|
||||
EnsembleInferenceModel::fromXContent);
|
||||
ensemble.rewriteFeatureIndices(Collections.emptyMap());
|
||||
|
||||
featureVector = Arrays.asList(0.4, 0.0);
|
||||
featureMap = zipObjMap(featureNames, featureVector);
|
||||
|
@ -466,6 +501,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
|
|||
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
|
||||
xContentRegistry(),
|
||||
EnsembleInferenceModel::fromXContent);
|
||||
ensemble.rewriteFeatureIndices(Collections.emptyMap());
|
||||
|
||||
double[][] featureImportance = ensemble.featureImportance(new double[]{0.0, 0.9});
|
||||
assertThat(featureImportance[0][0], closeTo(-1.653200025, eps));
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
|
@ -16,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
|||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.Operator;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -31,16 +33,22 @@ import java.util.stream.IntStream;
|
|||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
||||
public class TreeInferenceModelTests extends ESTestCase {
|
||||
|
||||
private static final int NUMBER_OF_TEST_RUNS = 20;
|
||||
private final double eps = 1.0E-8;
|
||||
|
||||
public static TreeInferenceModel serializeFromTrainedModel(Tree tree) throws IOException {
|
||||
NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
return deserializeFromTrainedModel(tree,
|
||||
TreeInferenceModel model = deserializeFromTrainedModel(tree,
|
||||
registry,
|
||||
TreeInferenceModel::fromXContent);
|
||||
model.rewriteFeatureIndices(Collections.emptyMap());
|
||||
return model;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -50,6 +58,22 @@ public class TreeInferenceModelTests extends ESTestCase {
|
|||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
/**
 * Round-trips randomly generated trees through {@code serializeFromTrainedModel} and checks
 * that a non-null inference model comes back, once per NUMBER_OF_TEST_RUNS iteration with a
 * random target type.
 *
 * NOTE(review): the method name mentions "Ensemble" but the body only builds {@code Tree}
 * instances; the name looks copied from the ensemble test class. Confirm and consider renaming.
 */
public void testSerializationFromEnsemble() throws Exception {
|
||||
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
|
||||
Tree tree = TreeTests.createRandom(randomFrom(TargetType.values()));
|
||||
assertThat(serializeFromTrainedModel(tree), is(not(nullValue())));
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Checks that calling {@code infer} on a freshly deserialized tree fails.
 *
 * The model is deliberately NOT passed through {@code rewriteFeatureIndices} before
 * {@code infer} is invoked with an empty field map and the regression config, so an
 * ElasticsearchException is expected.
 *
 * NOTE(review): the tree's target type is random while the config is always the regression
 * config. If a non-regression target type is drawn, infer presumably throws its
 * config/target_type mismatch error first, which would also satisfy expectThrows without
 * reaching the not-prepared check. Confirm, and consider pinning TargetType.REGRESSION.
 */
public void testInferenceWithoutPreparing() throws IOException {
|
||||
Tree tree = TreeTests.createRandom(randomFrom(TargetType.values()));
|
||||
|
||||
TreeInferenceModel model = deserializeFromTrainedModel(tree,
|
||||
xContentRegistry(),
|
||||
TreeInferenceModel::fromXContent);
|
||||
expectThrows(ElasticsearchException.class, () -> model.infer(Collections.emptyMap(), RegressionConfig.EMPTY_PARAMS, null));
|
||||
}
|
||||
|
||||
public void testInferWithStump() throws IOException {
|
||||
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
|
||||
builder.setRoot(TreeNode.builder(0).setLeafValue(Collections.singletonList(42.0)));
|
||||
|
@ -59,6 +83,7 @@ public class TreeInferenceModelTests extends ESTestCase {
|
|||
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
|
||||
xContentRegistry(),
|
||||
TreeInferenceModel::fromXContent);
|
||||
tree.rewriteFeatureIndices(Collections.emptyMap());
|
||||
List<String> featureNames = Arrays.asList("foo", "bar");
|
||||
List<Double> featureVector = Arrays.asList(0.6, 0.0);
|
||||
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump
|
||||
|
@ -82,6 +107,7 @@ public class TreeInferenceModelTests extends ESTestCase {
|
|||
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
|
||||
xContentRegistry(),
|
||||
TreeInferenceModel::fromXContent);
|
||||
tree.rewriteFeatureIndices(Collections.emptyMap());
|
||||
// This feature vector should hit the right child of the root node
|
||||
List<Double> featureVector = Arrays.asList(0.6, 0.0);
|
||||
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
|
||||
|
@ -137,6 +163,7 @@ public class TreeInferenceModelTests extends ESTestCase {
|
|||
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
|
||||
xContentRegistry(),
|
||||
TreeInferenceModel::fromXContent);
|
||||
tree.rewriteFeatureIndices(Collections.emptyMap());
|
||||
double eps = 0.000001;
|
||||
// This feature vector should hit the right child of the root node
|
||||
List<Double> featureVector = Arrays.asList(0.6, 0.0);
|
||||
|
@ -211,6 +238,7 @@ public class TreeInferenceModelTests extends ESTestCase {
|
|||
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
|
||||
xContentRegistry(),
|
||||
TreeInferenceModel::fromXContent);
|
||||
tree.rewriteFeatureIndices(Collections.emptyMap());
|
||||
|
||||
double[][] featureImportance = tree.featureImportance(new double[]{0.25, 0.25});
|
||||
assertThat(featureImportance[0][0], closeTo(-5.0, eps));
|
||||
|
|
Loading…
Reference in New Issue