[ML] allow feature_names to be optional in ensemble inference model (#58059) (#58067)

This has `EnsembleInferenceModel` not parse feature_names from the XContent.

Instead, it will rely on `rewriteFeatureIndices` to be called ahead of time.

Consequently, protections are made for a fail fast path if `rewriteFeatureIndices` has not been called before `infer`.
This commit is contained in:
Benjamin Trent 2020-06-12 16:33:54 -04:00 committed by GitHub
parent 0ce102a5f4
commit 79c784932f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 98 additions and 19 deletions

View File

@ -41,7 +41,6 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHe
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.FEATURE_NAMES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TARGET_TYPE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;
@ -53,14 +52,12 @@ public class EnsembleInferenceModel implements InferenceModel {
private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
"ensemble_inference_model",
true,
a -> new EnsembleInferenceModel((List<String>)a[0],
(List<InferenceModel>)a[1],
(OutputAggregator)a[2],
TargetType.fromString((String)a[3]),
(List<String>)a[4],
(List<Double>)a[5]));
a -> new EnsembleInferenceModel((List<InferenceModel>)a[0],
(OutputAggregator)a[1],
TargetType.fromString((String)a[2]),
(List<String>)a[3],
(List<Double>)a[4]));
static {
PARSER.declareStringArray(constructorArg(), FEATURE_NAMES);
PARSER.declareNamedObjects(constructorArg(),
(p, c, n) -> p.namedObject(InferenceModel.class, n, null),
(ensembleBuilder) -> {},
@ -77,20 +74,19 @@ public class EnsembleInferenceModel implements InferenceModel {
return PARSER.apply(parser, null);
}
private String[] featureNames;
private String[] featureNames = new String[0];
private final List<InferenceModel> models;
private final OutputAggregator outputAggregator;
private final TargetType targetType;
private final List<String> classificationLabels;
private final double[] classificationWeights;
private volatile boolean preparedForInference = false;
EnsembleInferenceModel(List<String> featureNames,
List<InferenceModel> models,
OutputAggregator outputAggregator,
TargetType targetType,
List<String> classificationLabels,
List<Double> classificationWeights) {
this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
private EnsembleInferenceModel(List<InferenceModel> models,
OutputAggregator outputAggregator,
TargetType targetType,
List<String> classificationLabels,
List<Double> classificationWeights) {
this.models = ExceptionsHelper.requireNonNull(models, TRAINED_MODELS);
this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
@ -135,6 +131,10 @@ public class EnsembleInferenceModel implements InferenceModel {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
if (preparedForInference == false) {
throw ExceptionsHelper.serverError("model is not prepared for inference");
}
assert featureNames != null && featureNames.length > 0;
double[][] inferenceResults = new double[this.models.size()][];
double[][] featureInfluence = new double[features.length][];
int i = 0;
@ -232,6 +232,10 @@ public class EnsembleInferenceModel implements InferenceModel {
@Override
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
if (preparedForInference) {
return;
}
preparedForInference = true;
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
Set<String> referencedFeatures = subModelFeatures();
int newFeatureIndex = 0;

View File

@ -80,6 +80,7 @@ public class TreeInferenceModel implements InferenceModel {
private final double highOrderCategory;
private final int maxDepth;
private final int leafSize;
private volatile boolean preparedForInference = false;
TreeInferenceModel(List<String> featureNames, List<NodeBuilder> nodes, TargetType targetType, List<String> classificationLabels) {
this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
@ -136,6 +137,9 @@ public class TreeInferenceModel implements InferenceModel {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
if (preparedForInference == false) {
throw ExceptionsHelper.serverError("model is not prepared for inference");
}
double[][] featureImportance = config.requestingImportance() ?
featureImportance(features) :
new double[0][];
@ -288,6 +292,10 @@ public class TreeInferenceModel implements InferenceModel {
@Override
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
if (preparedForInference) {
return;
}
preparedForInference = true;
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
return;
}

View File

@ -66,7 +66,10 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
public static Ensemble createRandom(TargetType targetType, List<String> featureNames) {
int numberOfModels = randomIntBetween(1, 10);
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
List<String> treeFeatureNames = featureNames.isEmpty() ?
Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()) :
featureNames;
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(treeFeatureNames, 6))
.limit(numberOfModels)
.collect(Collectors.toList());
double[] weights = randomBoolean() ?

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
@ -15,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
@ -30,19 +32,26 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
public class EnsembleInferenceModelTests extends ESTestCase {
private static final int NUMBER_OF_TEST_RUNS = 20;
private final double eps = 1.0E-8;
public static EnsembleInferenceModel serializeFromTrainedModel(Ensemble ensemble) throws IOException {
NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return deserializeFromTrainedModel(ensemble,
EnsembleInferenceModel model = deserializeFromTrainedModel(ensemble,
registry,
EnsembleInferenceModel::fromXContent);
model.rewriteFeatureIndices(Collections.emptyMap());
return model;
}
@Override
@ -52,6 +61,27 @@ public class EnsembleInferenceModelTests extends ESTestCase {
return new NamedXContentRegistry(namedXContent);
}
public void testSerializationFromEnsemble() throws Exception {
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
int numberOfFeatures = randomIntBetween(1, 10);
Ensemble ensemble = EnsembleTests.createRandom(randomFrom(TargetType.values()),
randomBoolean() ?
Collections.emptyList() :
Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList()));
assertThat(serializeFromTrainedModel(ensemble), is(not(nullValue())));
}
}
public void testInferenceWithoutPreparing() throws IOException {
Ensemble ensemble = EnsembleTests.createRandom(TargetType.REGRESSION,
Stream.generate(() -> randomAlphaOfLength(10)).limit(4).collect(Collectors.toList()));
EnsembleInferenceModel model = deserializeFromTrainedModel(ensemble,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
expectThrows(ElasticsearchException.class, () -> model.infer(Collections.emptyMap(), RegressionConfig.EMPTY_PARAMS, null));
}
public void testClassificationProbability() throws IOException {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()
@ -100,6 +130,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
ensemble.rewriteFeatureIndices(Collections.emptyMap());
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
@ -205,6 +236,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
ensemble.rewriteFeatureIndices(Collections.emptyMap());
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
@ -284,6 +316,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
ensemble.rewriteFeatureIndices(Collections.emptyMap());
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
@ -349,6 +382,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
ensemble.rewriteFeatureIndices(Collections.emptyMap());
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
@ -373,6 +407,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
ensemble.rewriteFeatureIndices(Collections.emptyMap());
featureVector = Arrays.asList(0.4, 0.0);
featureMap = zipObjMap(featureNames, featureVector);
@ -466,6 +501,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
EnsembleInferenceModel ensemble = deserializeFromTrainedModel(ensembleObject,
xContentRegistry(),
EnsembleInferenceModel::fromXContent);
ensemble.rewriteFeatureIndices(Collections.emptyMap());
double[][] featureImportance = ensemble.featureImportance(new double[]{0.0, 0.9});
assertThat(featureImportance[0][0], closeTo(-1.653200025, eps));

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
@ -16,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import java.io.IOException;
@ -31,16 +33,22 @@ import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
public class TreeInferenceModelTests extends ESTestCase {
private static final int NUMBER_OF_TEST_RUNS = 20;
private final double eps = 1.0E-8;
public static TreeInferenceModel serializeFromTrainedModel(Tree tree) throws IOException {
NamedXContentRegistry registry = new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return deserializeFromTrainedModel(tree,
TreeInferenceModel model = deserializeFromTrainedModel(tree,
registry,
TreeInferenceModel::fromXContent);
model.rewriteFeatureIndices(Collections.emptyMap());
return model;
}
@Override
@ -50,6 +58,22 @@ public class TreeInferenceModelTests extends ESTestCase {
return new NamedXContentRegistry(namedXContent);
}
public void testSerializationFromEnsemble() throws Exception {
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
Tree tree = TreeTests.createRandom(randomFrom(TargetType.values()));
assertThat(serializeFromTrainedModel(tree), is(not(nullValue())));
}
}
public void testInferenceWithoutPreparing() throws IOException {
Tree tree = TreeTests.createRandom(randomFrom(TargetType.values()));
TreeInferenceModel model = deserializeFromTrainedModel(tree,
xContentRegistry(),
TreeInferenceModel::fromXContent);
expectThrows(ElasticsearchException.class, () -> model.infer(Collections.emptyMap(), RegressionConfig.EMPTY_PARAMS, null));
}
public void testInferWithStump() throws IOException {
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
builder.setRoot(TreeNode.builder(0).setLeafValue(Collections.singletonList(42.0)));
@ -59,6 +83,7 @@ public class TreeInferenceModelTests extends ESTestCase {
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
xContentRegistry(),
TreeInferenceModel::fromXContent);
tree.rewriteFeatureIndices(Collections.emptyMap());
List<String> featureNames = Arrays.asList("foo", "bar");
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump
@ -82,6 +107,7 @@ public class TreeInferenceModelTests extends ESTestCase {
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
xContentRegistry(),
TreeInferenceModel::fromXContent);
tree.rewriteFeatureIndices(Collections.emptyMap());
// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
@ -137,6 +163,7 @@ public class TreeInferenceModelTests extends ESTestCase {
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
xContentRegistry(),
TreeInferenceModel::fromXContent);
tree.rewriteFeatureIndices(Collections.emptyMap());
double eps = 0.000001;
// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
@ -211,6 +238,7 @@ public class TreeInferenceModelTests extends ESTestCase {
TreeInferenceModel tree = deserializeFromTrainedModel(treeObject,
xContentRegistry(),
TreeInferenceModel::fromXContent);
tree.rewriteFeatureIndices(Collections.emptyMap());
double[][] featureImportance = tree.featureImportance(new double[]{0.25, 0.25});
assertThat(featureImportance[0][0], closeTo(-5.0, eps));