diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java index 0be0e8f6c58..d8a00f1d8f6 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java @@ -19,6 +19,9 @@ package org.elasticsearch.client.ml.inference; import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding; +import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression; @@ -61,6 +64,14 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider { new ParseField(LangIdentNeuralNetwork.NAME), LangIdentNeuralNetwork::fromXContent)); + // Inference Config + namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, + ClassificationConfig.NAME, + ClassificationConfig::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, + RegressionConfig.NAME, + RegressionConfig::fromXContent)); + // Aggregating output namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, new ParseField(WeightedMode.NAME), diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java index 1795f5da495..10a739280be 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java @@ -54,4 +54,14 @@ public final class NamedXContentObjectHelper { } return builder; } + + public static XContentBuilder writeNamedObject(XContentBuilder builder, + ToXContent.Params params, + String namedObjectName, + NamedXContentObject namedObject) throws IOException { + builder.startObject(namedObjectName); + builder.field(namedObject.getName(), namedObject, params); + builder.endObject(); + return builder; + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index d8749c0a6cd..9d821752240 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -20,6 +20,7 @@ package org.elasticsearch.client.ml.inference; import org.elasticsearch.Version; import org.elasticsearch.client.common.TimeUtil; +import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.ByteSizeValue; @@ -36,6 +37,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import static org.elasticsearch.client.ml.inference.NamedXContentObjectHelper.writeNamedObject; + public class TrainedModelConfig implements ToXContentObject { public static final String NAME = "trained_model_config"; @@ -54,6 +57,7 @@ public class TrainedModelConfig implements ToXContentObject { public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations"); public static final ParseField LICENSE_LEVEL = new ParseField("license_level"); public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map"); + public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config"); public static final ObjectParser PARSER = new ObjectParser<>(NAME, true, @@ -78,6 +82,9 @@ public class TrainedModelConfig implements ToXContentObject { PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS); PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL); PARSER.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP); + PARSER.declareNamedObject(TrainedModelConfig.Builder::setInferenceConfig, + (p, c, n) -> p.namedObject(InferenceConfig.class, n, null), + INFERENCE_CONFIG); } public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException { @@ -98,6 +105,7 @@ public class TrainedModelConfig implements ToXContentObject { private final Long estimatedOperations; private final String licenseLevel; private final Map defaultFieldMap; + private final InferenceConfig inferenceConfig; TrainedModelConfig(String modelId, String createdBy, @@ -112,7 +120,8 @@ public class TrainedModelConfig implements ToXContentObject { Long estimatedHeapMemory, Long estimatedOperations, String licenseLevel, - Map defaultFieldMap) { + Map defaultFieldMap, + InferenceConfig inferenceConfig) { this.modelId = modelId; this.createdBy = createdBy; this.version = version; @@ -127,6 +136,7 @@ public class TrainedModelConfig implements ToXContentObject { this.estimatedOperations = estimatedOperations; this.licenseLevel = licenseLevel; this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap); + this.inferenceConfig = inferenceConfig; } public String getModelId() { @@ -189,6 +199,10 @@ public class TrainedModelConfig implements ToXContentObject { return defaultFieldMap; } + public InferenceConfig getInferenceConfig() { + return inferenceConfig; + } + public static Builder builder() { return new Builder(); } @@ -238,6 +252,9 @@ public class TrainedModelConfig implements ToXContentObject { if (defaultFieldMap != null) { builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap); } + if (inferenceConfig != null) { + writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig); + } builder.endObject(); return builder; } @@ -265,6 +282,7 @@ public class TrainedModelConfig implements ToXContentObject { Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(licenseLevel, that.licenseLevel) && Objects.equals(defaultFieldMap, that.defaultFieldMap) && + Objects.equals(inferenceConfig, that.inferenceConfig) && Objects.equals(metadata, that.metadata); } @@ -283,6 +301,7 @@ public class TrainedModelConfig implements ToXContentObject { metadata, licenseLevel, input, + inferenceConfig, defaultFieldMap); } @@ -303,6 +322,7 @@ public class TrainedModelConfig implements ToXContentObject { private Long estimatedOperations; private String licenseLevel; private Map defaultFieldMap; + private InferenceConfig inferenceConfig; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -387,6 +407,11 @@ public class TrainedModelConfig implements ToXContentObject { return this; } + public Builder setInferenceConfig(InferenceConfig inferenceConfig) { + this.inferenceConfig = inferenceConfig; + return this; + } + public TrainedModelConfig build() { return new TrainedModelConfig( modelId, @@ -402,7 +427,8 @@ public class TrainedModelConfig implements ToXContentObject { estimatedHeapMemory, estimatedOperations, licenseLevel, - defaultFieldMap); + defaultFieldMap, + inferenceConfig); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfig.java new file mode 100644 index 00000000000..87b37434075 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfig.java @@ -0,0 +1,129 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class ClassificationConfig implements InferenceConfig { + + public static final ParseField NAME = new ParseField("classification"); + + public static final ParseField RESULTS_FIELD = new ParseField("results_field"); + public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field"); + public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); + + + private final Integer numTopClasses; + private final String topClassesResultsField; + private final String resultsField; + private final Integer numTopFeatureImportanceValues; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new ClassificationConfig( + (Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3])); + + static { + PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); + PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); + PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD); + PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); + } + + public static ClassificationConfig fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public ClassificationConfig() { + this(null, null, null, null); + } + + public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) { + this.numTopClasses = numTopClasses; + this.topClassesResultsField = topClassesResultsField; + this.resultsField = resultsField; + this.numTopFeatureImportanceValues = featureImportance; + } + + public Integer getNumTopClasses() { + return numTopClasses; + } + + public String getTopClassesResultsField() { + return topClassesResultsField; + } + + public String getResultsField() { + return resultsField; + } + + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassificationConfig that = (ClassificationConfig) o; + return Objects.equals(numTopClasses, that.numTopClasses) + && Objects.equals(topClassesResultsField, that.topClassesResultsField) + && Objects.equals(resultsField, that.resultsField) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); + } + + @Override + public int hashCode() { + return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (numTopClasses != null) { + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + } + if (topClassesResultsField != null) { + builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField); + } + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } + builder.endObject(); + return builder; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/InferenceConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/InferenceConfig.java new file mode 100644 index 00000000000..0e3e911cd2c --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/InferenceConfig.java @@ -0,0 +1,26 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.client.ml.inference.NamedXContentObject; + + +public interface InferenceConfig extends NamedXContentObject { + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfig.java new file mode 100644 index 00000000000..dddd9c5ab48 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfig.java @@ -0,0 +1,104 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class RegressionConfig implements InferenceConfig { + + public static final ParseField NAME = new ParseField("regression"); + public static final ParseField RESULTS_FIELD = new ParseField("results_field"); + public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME.getPreferredName(), + true, + args -> new RegressionConfig((String) args[0], (Integer)args[1])); + + static { + PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); + PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); + } + + public static RegressionConfig fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final String resultsField; + private final Integer numTopFeatureImportanceValues; + + public RegressionConfig() { + this(null, null); + } + + public RegressionConfig(String resultsField, Integer numTopFeatureImportanceValues) { + this.resultsField = resultsField; + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + } + + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + + public String getResultsField() { + return resultsField; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RegressionConfig that = (RegressionConfig)o; + return Objects.equals(this.resultsField, that.resultsField) + && Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); + } + + @Override + public int hashCode() { + return Objects.hash(resultsField, numTopFeatureImportanceValues); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index d1b8188b100..08ce2d5f5cc 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -158,6 +158,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition; import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.client.ml.inference.TrainedModelInput; import org.elasticsearch.client.ml.inference.TrainedModelStats; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; import org.elasticsearch.client.ml.job.config.AnalysisConfig; @@ -2272,6 +2273,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() .setDefinition(definition) .setModelId(modelId) + .setInferenceConfig(new RegressionConfig()) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) .setDescription("test model") .build(); @@ -2285,6 +2287,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { trainedModelConfig = TrainedModelConfig.builder() .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) .setModelId(modelIdCompressed) + .setInferenceConfig(new RegressionConfig()) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) .setDescription("test model") .build(); @@ -2591,6 +2594,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() .setDefinition(definition) .setModelId(modelId) + .setInferenceConfig(new RegressionConfig()) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) .setDescription("test model") .build(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index a93da23328c..e7575b737b8 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -75,6 +75,8 @@ import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding; import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding; +import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode; @@ -699,7 +701,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(62, namedXContents.size()); + assertEquals(64, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -709,7 +711,7 @@ public class RestHighLevelClientTests extends ESTestCase { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 13, categories.size()); + assertEquals("Had: " + categories, 14, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -783,6 +785,9 @@ public class RestHighLevelClientTests extends ESTestCase { assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class)); assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME)); + assertEquals(Integer.valueOf(2), + categories.get(org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig.class)); + assertThat(names, hasItems(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName())); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index ac43ac07a3e..b11804a6884 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -174,6 +174,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition; import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.client.ml.inference.TrainedModelInput; import org.elasticsearch.client.ml.inference.TrainedModelStats; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; @@ -3646,11 +3647,13 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setDescription("test model") // <5> .setMetadata(new HashMap<>()) // <6> .setTags("my_regression_models") // <7> + .setInferenceConfig(new RegressionConfig("value", 0)) // <8> .build(); // end::put-trained-model-config trainedModelConfig = TrainedModelConfig.builder() .setDefinition(definition) + .setInferenceConfig(new RegressionConfig(null, null)) .setModelId("my-new-trained-model") .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) .setDescription("test model") @@ -4234,6 +4237,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() .setDefinition(definition) .setModelId(modelId) + .setInferenceConfig(new RegressionConfig("value", 0)) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) .setDescription("test model") .build(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index be4b0443e59..429ab53dcc9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -19,6 +19,9 @@ package org.elasticsearch.client.ml.inference; import org.elasticsearch.Version; +import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfigTests; +import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; @@ -39,13 +42,14 @@ import java.util.stream.Stream; public class TrainedModelConfigTests extends AbstractXContentTestCase { public static TrainedModelConfig createTestTrainedModelConfig() { + TargetType targetType = randomFrom(TargetType.values()); return new TrainedModelConfig( randomAlphaOfLength(10), randomAlphaOfLength(10), Version.CURRENT, randomBoolean() ? null : randomAlphaOfLength(100), Instant.ofEpochMilli(randomNonNegativeLong()), - randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(), + randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder(targetType).build(), randomBoolean() ? null : randomAlphaOfLength(100), randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), @@ -57,7 +61,10 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase randomAlphaOfLength(10)) .limit(randomIntBetween(1, 10)) - .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10)))); + .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))), + targetType.equals(TargetType.CLASSIFICATION) ? + ClassificationConfigTests.randomClassificationConfig() : + RegressionConfigTests.randomRegressionConfig()); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfigTests.java new file mode 100644 index 00000000000..3dab960fcc3 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfigTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class ClassificationConfigTests extends AbstractXContentTestCase { + + public static ClassificationConfig randomClassificationConfig() { + return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomIntBetween(0, 10) + ); + } + + @Override + protected ClassificationConfig createTestInstance() { + return randomClassificationConfig(); + } + + @Override + protected ClassificationConfig doParseInstance(XContentParser parser) throws IOException { + return ClassificationConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfigTests.java new file mode 100644 index 00000000000..4057b961b96 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfigTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class RegressionConfigTests extends AbstractXContentTestCase { + + public static RegressionConfig randomRegressionConfig() { + return new RegressionConfig( + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomIntBetween(0, 10)); + } + + + @Override + protected RegressionConfig createTestInstance() { + return randomRegressionConfig(); + } + + @Override + protected RegressionConfig doParseInstance(XContentParser parser) throws IOException { + return RegressionConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc index 6a0f96a78b9..7637b739940 100644 --- a/docs/java-rest/high-level/ml/put-trained-model.asciidoc +++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc @@ -39,6 +39,8 @@ include-tagged::{doc-tests-file}[{api}-config] <5> Optionally, a human-readable description <6> Optionally, an object map contain metadata about the model <7> Optionally, an array of tags to organize the model +<8> The default inference config to use with the model. Must match the underlying + definition target_type. include::../execution.asciidoc[] diff --git a/docs/reference/ingest/processors/inference.asciidoc b/docs/reference/ingest/processors/inference.asciidoc index 6066a78feaa..4f5a1d76c7f 100644 --- a/docs/reference/ingest/processors/inference.asciidoc +++ b/docs/reference/ingest/processors/inference.asciidoc @@ -38,44 +38,38 @@ include::common-options.asciidoc[] [[inference-processor-regression-opt]] ==== {regression-cap} configuration options +Regression configuration for inference. + `results_field`:: (Optional, string) -Specifies the field to which the inference prediction is written. Defaults to -`predicted_value`. +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-results-field] -`num_top_feature_importance_values`:::: +`num_top_feature_importance_values`:: (Optional, integer) -Specifies the maximum number of -{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature -importance] values per document. By default, it is zero and no feature importance -calculation occurs. +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values] [discrete] [[inference-processor-classification-opt]] ==== {classification-cap} configuration options -`results_field`:: -(Optional, string) -The field that is added to incoming documents to contain the inference prediction. Defaults to -`predicted_value`. +Classification configuration for inference. `num_top_classes`:: -(Optional, integer) -Specifies the number of top class predictions to return. Defaults to 0. +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes] + +`num_top_feature_importance_values`:: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values] + +`results_field`:: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-results-field] `top_classes_results_field`:: -(Optional, string) -Specifies the field to which the top classes are written. Defaults to -`top_classes`. - -`num_top_feature_importance_values`:::: -(Optional, integer) -Specifies the maximum number of -{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature -importance] values per document. By default, it is zero and no feature -importance calculation occurs. - +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field] [discrete] [[inference-processor-config-example]] @@ -178,4 +172,4 @@ You can also specify a target field as follows: // NOTCONSOLE In this case, {feat-imp} is exposed in the -`my_field.foo.feature_importance` field. \ No newline at end of file +`my_field.foo.feature_importance` field. diff --git a/docs/reference/ml/df-analytics/apis/put-inference.asciidoc b/docs/reference/ml/df-analytics/apis/put-inference.asciidoc index 44ccfc445b8..f49b288c306 100644 --- a/docs/reference/ml/df-analytics/apis/put-inference.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-inference.asciidoc @@ -43,7 +43,7 @@ is not created by {dfanalytics}. (Required, string) include::{docdir}/ml/ml-shared.asciidoc[tag=model-id] - +[role="child_attributes"] [[ml-put-inference-request-body]] ==== {api-request-body-title} @@ -52,32 +52,96 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=model-id] The compressed (GZipped and Base64 encoded) {infer} definition of the model. If `compressed_definition` is specified, then `definition` cannot be specified. +//Begin definition `definition`:: (Required, object) The {infer} definition for the model. If `definition` is specified, then `compressed_definition` cannot be specified. - -`definition`.`preprocessors`::: ++ +.Properties of `definition` +[%collapsible%open] +==== +`preprocessors`::: (Optional, object) Collection of preprocessors. See <> for the full list of available preprocessors. -`definition`.`trained_model`::: +`trained_model`::: (Required, object) The definition of the trained model. See <> for details. +==== +//End definition `description`:: (Optional, string) A human-readable description of the {infer} trained model. +//Begin inference_config +`inference_config`:: +(Required, object) +The default configuration for inference. This can be either a `regression` +or `classification` configuration. It must match the underlying +`definition.trained_model`'s `target_type`. ++ +.Properties of `inference_config` +[%collapsible%open] +==== +`regression`::: +(Optional, object) +Regression configuration for inference. ++ +.Properties of regression inference +[%collapsible%open] +===== +`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values] + +`results_field`:::: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-results-field] +===== + +`classification`::: +(Optional, object) +Classification configuration for inference. ++ +.Properties of classification inference +[%collapsible%open] +===== +`num_top_classes`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes] + +`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values] + +`results_field`:::: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-results-field] + +`top_classes_results_field`:::: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field] +===== +==== +//End of inference_config + +//Begin input `input`:: (Required, object) The input field names for the model definition. - -`input`.`field_names`::: ++ +.Properties of `input` +[%collapsible%open] +==== +`field_names`::: (Required, string) An array of input field names for the model. +==== +//End input `metadata`:: (Optional, object) @@ -87,7 +151,6 @@ An object map that contains metadata about the model. (Optional, string) An array of tags to organize the model. - [[ml-put-inference-preprocessors]] ===== {infer-cap} preprocessor definitions @@ -491,4 +554,4 @@ Example of a `weighted_mode` object: ===== {infer-cap} JSON schema For the full JSON schema of model {infer}, -https://github.com/elastic/ml-json-schemas[click here]. \ No newline at end of file +https://github.com/elastic/ml-json-schemas[click here]. diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index c144df477c4..da1a4d60a8e 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -1213,6 +1213,39 @@ For more information about these options, see <>. -- end::indices-options[] +tag::inference-config-classification-num-top-classes[] +Specifies the number of top class predictions to return. Defaults to 0. +end::inference-config-classification-num-top-classes[] + +tag::inference-config-classification-num-top-feature-importance-values[] +Specifies the maximum number of +{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature +importance] values per document. By default, it is zero and no feature +importance calculation occurs. +end::inference-config-classification-num-top-feature-importance-values[] + +tag::inference-config-classification-results-field[] +The field that is added to incoming documents to contain the inference +prediction. Defaults to `predicted_value`. +end::inference-config-classification-results-field[] + +tag::inference-config-classification-top-classes-results-field[] +Specifies the field to which the top classes are written. Defaults to +`top_classes`. +end::inference-config-classification-top-classes-results-field[] + +tag::inference-config-regression-num-top-feature-importance-values[] +Specifies the maximum number of +{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature +importance] values per document. By default, it is zero and no feature importance +calculation occurs. +end::inference-config-regression-num-top-feature-importance-values[] + +tag::inference-config-regression-results-field[] +Specifies the field to which the inference prediction is written. Defaults to +`predicted_value`. +end::inference-config-regression-results-field[] + tag::influencers[] A comma separated list of influencer field names. Typically these can be the by, over, or partition fields that are used in the detector configuration. You might diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java index 556fb2b4f02..82c4bc5f435 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; @@ -13,8 +14,12 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -37,27 +42,30 @@ public class InternalInferModelAction extends ActionType> objectsToInfer; - private final InferenceConfig config; + private final InferenceConfigUpdate update; private final boolean previouslyLicensed; public Request(String modelId, boolean previouslyLicensed) { - this(modelId, Collections.emptyList(), RegressionConfig.EMPTY_PARAMS, previouslyLicensed); + this(modelId, Collections.emptyList(), RegressionConfigUpdate.EMPTY_PARAMS, previouslyLicensed); } public Request(String modelId, List> objectsToInfer, - InferenceConfig inferenceConfig, + InferenceConfigUpdate inferenceConfig, boolean previouslyLicensed) { this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer")); - this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config"); + this.update = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config"); this.previouslyLicensed = previouslyLicensed; } - public Request(String modelId, Map objectToInfer, InferenceConfig config, boolean previouslyLicensed) { + public Request(String modelId, + Map objectToInfer, + InferenceConfigUpdate update, + boolean previouslyLicensed) { this(modelId, Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")), - config, + update, previouslyLicensed); } @@ -65,7 +73,18 @@ public class InternalInferModelAction extends ActionType)in.readNamedWriteable(InferenceConfigUpdate.class); + } else { + InferenceConfig oldConfig = in.readNamedWriteable(InferenceConfig.class); + if (oldConfig instanceof RegressionConfig) { + this.update = RegressionConfigUpdate.fromConfig((RegressionConfig)oldConfig); + } else if (oldConfig instanceof ClassificationConfig) { + this.update = ClassificationConfigUpdate.fromConfig((ClassificationConfig) oldConfig); + } else { + throw new IOException("Unexpected configuration type [" + oldConfig.getName() + "]"); + } + } this.previouslyLicensed = in.readBoolean(); } @@ -77,8 +96,8 @@ public class InternalInferModelAction extends ActionType LENIENT_PARSER = createParser(true); @@ -93,6 +98,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { parser.declareString(TrainedModelConfig.Builder::setLazyDefinition, COMPRESSED_DEFINITION); parser.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL); parser.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP); + parser.declareNamedObject(TrainedModelConfig.Builder::setInferenceConfig, (p, c, n) -> ignoreUnknownFields ? + p.namedObject(LenientlyParsedInferenceConfig.class, n, null) : + p.namedObject(StrictlyParsedInferenceConfig.class, n, null), + INFERENCE_CONFIG); return parser; } @@ -112,6 +121,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { private final long estimatedOperations; private final License.OperationMode licenseLevel; private final Map defaultFieldMap; + private final InferenceConfig inferenceConfig; private final LazyModelDefinition definition; @@ -127,7 +137,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { Long estimatedHeapMemory, Long estimatedOperations, String licenseLevel, - Map defaultFieldMap) { + Map defaultFieldMap, + InferenceConfig inferenceConfig) { this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.version = ExceptionsHelper.requireNonNull(version, VERSION); @@ -148,6 +159,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { this.estimatedOperations = estimatedOperations; this.licenseLevel = License.OperationMode.parse(ExceptionsHelper.requireNonNull(licenseLevel, LICENSE_LEVEL)); this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap); + this.inferenceConfig = inferenceConfig; } public TrainedModelConfig(StreamInput in) throws IOException { @@ -170,6 +182,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { } else { this.defaultFieldMap = null; } + if (in.getVersion().onOrAfter(Version.V_7_8_0)) { + this.inferenceConfig = in.readOptionalNamedWriteable(InferenceConfig.class); + } else { + this.inferenceConfig = null; + } } public String getModelId() { @@ -204,6 +221,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return defaultFieldMap; } + @Nullable + public InferenceConfig getInferenceConfig() { + return inferenceConfig; + } + @Nullable public String getCompressedDefinition() throws IOException { if (definition == null) { @@ -274,6 +296,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { out.writeBoolean(false); } } + if (out.getVersion().onOrAfter(Version.V_7_8_0)) { + out.writeOptionalNamedWriteable(inferenceConfig); + } } @Override @@ -311,6 +336,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) { builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap); } + if (inferenceConfig != null) { + writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig); + } builder.endObject(); return builder; } @@ -337,6 +365,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(licenseLevel, that.licenseLevel) && Objects.equals(defaultFieldMap, that.defaultFieldMap) && + Objects.equals(inferenceConfig, that.inferenceConfig) && Objects.equals(metadata, that.metadata); } @@ -354,6 +383,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { estimatedOperations, input, licenseLevel, + inferenceConfig, defaultFieldMap); } @@ -372,6 +402,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { private LazyModelDefinition definition; private String licenseLevel; private Map defaultFieldMap; + private InferenceConfig inferenceConfig; public Builder() {} @@ -389,6 +420,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { this.estimatedHeapMemory = config.estimatedHeapMemory; this.licenseLevel = config.licenseLevel.description(); this.defaultFieldMap = config.defaultFieldMap == null ? null : new HashMap<>(config.defaultFieldMap); + this.inferenceConfig = config.inferenceConfig; } public Builder setModelId(String modelId) { @@ -512,6 +544,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return this; } + public Builder setInferenceConfig(InferenceConfig inferenceConfig) { + this.inferenceConfig = inferenceConfig; + return this; + } + public Builder validate() { return validate(false); } @@ -530,6 +567,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { if (modelId == null) { validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException); } + if (inferenceConfig == null && forCreation) { + validationException = addValidationError("[" + INFERENCE_CONFIG.getPreferredName() + "] must not be null.", + validationException); + } if (modelId != null && MlStrings.isValidId(modelId) == false) { validationException = addValidationError(Messages.getMessage(Messages.INVALID_ID, @@ -605,7 +646,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { estimatedHeapMemory == null ? 0 : estimatedHeapMemory, estimatedOperations == null ? 0 : estimatedOperations, licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel, - defaultFieldMap); + defaultFieldMap, + inferenceConfig); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java index 8621786cedd..7c7ef38d9cc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java @@ -5,19 +5,31 @@ */ package org.elasticsearch.xpack.core.ml.inference.persistence; +import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; +import org.elasticsearch.xpack.core.template.TemplateUtils; /** * Class containing the index constants so that the index version, name, and prefix are available to a wider audience. */ public final class InferenceIndexConstants { - public static final String INDEX_VERSION = "000001"; + /** + * version: 7.8.0: + * - adds inference_config definition to trained model config + * + */ + public static final String INDEX_VERSION = "000002"; public static final String INDEX_NAME_PREFIX = ".ml-inference-"; public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*"; public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION; public static final ParseField DOC_TYPE = new ParseField("doc_type"); private InferenceIndexConstants() {} + private static final String MAPPINGS_VERSION_VARIABLE = "xpack.ml.version"; + public static String mapping() { + return TemplateUtils.loadTemplate("/org/elasticsearch/xpack/core/ml/inference_index_mappings.json", + Version.CURRENT.toString(), MAPPINGS_VERSION_VARIABLE); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java index 1aa8c816ccb..82749c27ce9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java @@ -9,24 +9,19 @@ import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import java.util.Objects; -import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; - -public class ClassificationConfig implements InferenceConfig { +public class ClassificationConfig implements LenientlyParsedInferenceConfig, StrictlyParsedInferenceConfig { public static final ParseField NAME = new ParseField("classification"); public static final String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes"; - private static final String DEFAULT_RESULTS_FIELD = "predicted_value"; + public static final String DEFAULT_RESULTS_FIELD = "predicted_value"; public static final ParseField RESULTS_FIELD = new ParseField("results_field"); public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); @@ -42,32 +37,27 @@ public class ClassificationConfig implements InferenceConfig { private final String resultsField; private final int numTopFeatureImportanceValues; - public static ClassificationConfig fromMap(Map map) { - Map options = new HashMap<>(map); - Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName()); - String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName()); - String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); - Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + private static final ObjectParser LENIENT_PARSER = createParser(true); + private static final ObjectParser STRICT_PARSER = createParser(false); - if (options.isEmpty() == false) { - throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet()); - } - return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, featureImportance); + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME.getPreferredName(), + lenient, + ClassificationConfig.Builder::new); + parser.declareInt(ClassificationConfig.Builder::setNumTopClasses, NUM_TOP_CLASSES); + parser.declareString(ClassificationConfig.Builder::setResultsField, RESULTS_FIELD); + parser.declareString(ClassificationConfig.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD); + parser.declareInt(ClassificationConfig.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); + return parser; } - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig( - (Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3])); - - static { - PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); - PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); - PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD); - PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); + public static ClassificationConfig fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); } - public static ClassificationConfig fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); + public static ClassificationConfig fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null).build(); } public ClassificationConfig(Integer numTopClasses) { @@ -150,14 +140,10 @@ public class ClassificationConfig implements InferenceConfig { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - if (numTopClasses != 0) { - builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); - } + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField); builder.field(RESULTS_FIELD.getPreferredName(), resultsField); - if (numTopFeatureImportanceValues > 0) { - builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); - } + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); builder.endObject(); return builder; } @@ -179,7 +165,50 @@ public class ClassificationConfig implements InferenceConfig { @Override public Version getMinimalSupportedVersion() { - return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION; + return requestingImportance() ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION; } + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Integer numTopClasses; + private String topClassesResultsField; + private String resultsField; + private Integer numTopFeatureImportanceValues; + + Builder() {} + + Builder(ClassificationConfig config) { + this.numTopClasses = config.numTopClasses; + this.topClassesResultsField = config.topClassesResultsField; + this.resultsField = config.resultsField; + this.numTopFeatureImportanceValues = config.numTopFeatureImportanceValues; + } + + public Builder setNumTopClasses(Integer numTopClasses) { + this.numTopClasses = numTopClasses; + return this; + } + + public Builder setTopClassesResultsField(String topClassesResultsField) { + this.topClassesResultsField = topClassesResultsField; + return this; + } + + public Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public ClassificationConfig build() { + return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues); + } + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java new file mode 100644 index 00000000000..2df5678eb5e --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java @@ -0,0 +1,235 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.NUM_TOP_CLASSES; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.RESULTS_FIELD; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.TOP_CLASSES_RESULTS_FIELD; + +public class ClassificationConfigUpdate implements InferenceConfigUpdate { + + public static final ParseField NAME = new ParseField("classification"); + + public static ClassificationConfigUpdate EMPTY_PARAMS = + new ClassificationConfigUpdate(null, null, null, null); + + private final Integer numTopClasses; + private final String topClassesResultsField; + private final String resultsField; + private final Integer numTopFeatureImportanceValues; + + public static ClassificationConfigUpdate fromMap(Map map) { + Map options = new HashMap<>(map); + Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName()); + String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName()); + String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); + Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + + if (options.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet()); + } + return new ClassificationConfigUpdate(numTopClasses, resultsField, topClassesResultsField, featureImportance); + } + + public static ClassificationConfigUpdate fromConfig(ClassificationConfig config) { + return new ClassificationConfigUpdate(config.getNumTopClasses(), + config.getResultsField(), + config.getTopClassesResultsField(), + config.getNumTopFeatureImportanceValues()); + } + + private static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME.getPreferredName(), + lenient, + ClassificationConfigUpdate.Builder::new); + parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopClasses, NUM_TOP_CLASSES); + parser.declareString(ClassificationConfigUpdate.Builder::setResultsField, RESULTS_FIELD); + parser.declareString(ClassificationConfigUpdate.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD); + parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); + return parser; + } + + public static ClassificationConfigUpdate fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); + } + + public ClassificationConfigUpdate(Integer numTopClasses, + String resultsField, + String topClassesResultsField, + Integer featureImportance) { + this.numTopClasses = numTopClasses; + this.topClassesResultsField = topClassesResultsField; + this.resultsField = resultsField; + if (featureImportance != null && featureImportance < 0) { + throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() + + "] must be greater than or equal to 0"); + } + this.numTopFeatureImportanceValues = featureImportance; + } + + public ClassificationConfigUpdate(StreamInput in) throws IOException { + this.numTopClasses = in.readOptionalInt(); + this.topClassesResultsField = in.readOptionalString(); + this.resultsField = in.readOptionalString(); + this.numTopFeatureImportanceValues = in.readOptionalVInt(); + } + + public Integer getNumTopClasses() { + return numTopClasses; + } + + public String getTopClassesResultsField() { + return topClassesResultsField; + } + + public String getResultsField() { + return resultsField; + } + + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInt(numTopClasses); + out.writeOptionalString(topClassesResultsField); + out.writeOptionalString(resultsField); + out.writeOptionalVInt(numTopFeatureImportanceValues); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassificationConfigUpdate that = (ClassificationConfigUpdate) o; + return Objects.equals(numTopClasses, that.numTopClasses) + && Objects.equals(topClassesResultsField, that.topClassesResultsField) + && Objects.equals(resultsField, that.resultsField) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); + } + + @Override + public int hashCode() { + return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (numTopClasses != null) { + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + } + if (topClassesResultsField != null) { + builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField); + } + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public ClassificationConfig apply(ClassificationConfig originalConfig) { + if (isNoop(originalConfig)) { + return originalConfig; + } + ClassificationConfig.Builder builder = new ClassificationConfig.Builder(originalConfig); + if (resultsField != null) { + builder.setResultsField(resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues); + } + if (topClassesResultsField != null) { + builder.setTopClassesResultsField(topClassesResultsField); + } + if (numTopClasses != null) { + builder.setNumTopClasses(numTopClasses); + } + return builder.build(); + } + + @Override + public InferenceConfig toConfig() { + return apply(ClassificationConfig.EMPTY_PARAMS); + } + + @Override + public boolean isSupported(InferenceConfig inferenceConfig) { + return inferenceConfig instanceof ClassificationConfig; + } + + boolean isNoop(ClassificationConfig originalConfig) { + return (resultsField == null || resultsField.equals(originalConfig.getResultsField())) + && (numTopFeatureImportanceValues == null + || originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues) + && (topClassesResultsField == null || topClassesResultsField.equals(originalConfig.getTopClassesResultsField())) + && (numTopClasses == null || originalConfig.getNumTopClasses() == numTopClasses); + } + + public static class Builder { + private Integer numTopClasses; + private String topClassesResultsField; + private String resultsField; + private Integer numTopFeatureImportanceValues; + + public Builder setNumTopClasses(int numTopClasses) { + this.numTopClasses = numTopClasses; + return this; + } + + public Builder setTopClassesResultsField(String topClassesResultsField) { + this.topClassesResultsField = topClassesResultsField; + return this; + } + + public Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public Builder setNumTopFeatureImportanceValues(int numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public ClassificationConfigUpdate build() { + return new ClassificationConfigUpdate(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigUpdate.java new file mode 100644 index 00000000000..72c8cff3c24 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigUpdate.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + + +public interface InferenceConfigUpdate extends NamedXContentObject, NamedWriteable { + + T apply(T originalConfig); + + InferenceConfig toConfig(); + + boolean isSupported(InferenceConfig config); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedInferenceConfig.java new file mode 100644 index 00000000000..501a9dff1ad --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedInferenceConfig.java @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +public interface LenientlyParsedInferenceConfig extends InferenceConfig { +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java index 4c8244c734c..ad5cfa4244e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -9,48 +9,42 @@ import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import java.util.Objects; -import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; - -public class RegressionConfig implements InferenceConfig { +public class RegressionConfig implements LenientlyParsedInferenceConfig, StrictlyParsedInferenceConfig { public static final ParseField NAME = new ParseField("regression"); private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0; public static final ParseField RESULTS_FIELD = new ParseField("results_field"); public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); - private static final String DEFAULT_RESULTS_FIELD = "predicted_value"; + public static final String DEFAULT_RESULTS_FIELD = "predicted_value"; public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null); - public static RegressionConfig fromMap(Map map) { - Map options = new HashMap<>(map); - String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); - Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); - if (options.isEmpty() == false) { - throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); - } - return new RegressionConfig(resultsField, featureImportance); + private static final ObjectParser LENIENT_PARSER = createParser(true); + private static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME.getPreferredName(), + lenient, + RegressionConfig.Builder::new); + parser.declareString(RegressionConfig.Builder::setResultsField, RESULTS_FIELD); + parser.declareInt(RegressionConfig.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); + return parser; } - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0], (Integer)args[1])); - - static { - PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); - PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); + public static RegressionConfig fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); } - public static RegressionConfig fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); + public static RegressionConfig fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null).build(); } private final String resultsField; @@ -113,9 +107,7 @@ public class RegressionConfig implements InferenceConfig { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(RESULTS_FIELD.getPreferredName(), resultsField); - if (numTopFeatureImportanceValues > 0) { - builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); - } + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); builder.endObject(); return builder; } @@ -141,7 +133,36 @@ public class RegressionConfig implements InferenceConfig { @Override public Version getMinimalSupportedVersion() { - return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION; + return requestingImportance() ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION; } + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String resultsField; + private Integer numTopFeatureImportanceValues; + + Builder() {} + + Builder(RegressionConfig config) { + this.resultsField = config.resultsField; + this.numTopFeatureImportanceValues = config.numTopFeatureImportanceValues; + } + + public Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public RegressionConfig build() { + return new RegressionConfig(resultsField, numTopFeatureImportanceValues); + } + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java new file mode 100644 index 00000000000..b9def23de3b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java @@ -0,0 +1,178 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.DEFAULT_RESULTS_FIELD; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.RESULTS_FIELD; + +public class RegressionConfigUpdate implements InferenceConfigUpdate { + + public static final ParseField NAME = new ParseField("regression"); + + public static RegressionConfigUpdate EMPTY_PARAMS = new RegressionConfigUpdate(null, null); + + public static RegressionConfigUpdate fromMap(Map map) { + Map options = new HashMap<>(map); + String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); + Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + if (options.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); + } + return new RegressionConfigUpdate(resultsField, featureImportance); + } + + public static RegressionConfigUpdate fromConfig(RegressionConfig config) { + return new RegressionConfigUpdate(config.getResultsField(), config.getNumTopFeatureImportanceValues()); + } + + private static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME.getPreferredName(), + lenient, + RegressionConfigUpdate.Builder::new); + parser.declareString(RegressionConfigUpdate.Builder::setResultsField, RESULTS_FIELD); + parser.declareInt(RegressionConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); + return parser; + } + + public static RegressionConfigUpdate fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); + } + + private final String resultsField; + private final Integer numTopFeatureImportanceValues; + + public RegressionConfigUpdate(String resultsField, Integer numTopFeatureImportanceValues) { + this.resultsField = resultsField; + if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) { + throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() + + "] must be greater than or equal to 0"); + } + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + } + + public RegressionConfigUpdate(StreamInput in) throws IOException { + this.resultsField = in.readOptionalString(); + this.numTopFeatureImportanceValues = in.readOptionalVInt(); + } + + public int getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues; + } + + public String getResultsField() { + return resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField; + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(resultsField); + out.writeOptionalVInt(numTopFeatureImportanceValues); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RegressionConfigUpdate that = (RegressionConfigUpdate)o; + return Objects.equals(this.resultsField, that.resultsField) + && Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); + } + + @Override + public int hashCode() { + return Objects.hash(resultsField, numTopFeatureImportanceValues); + } + + @Override + public RegressionConfig apply(RegressionConfig originalConfig) { + if (isNoop(originalConfig)) { + return originalConfig; + } + RegressionConfig.Builder builder = new RegressionConfig.Builder(originalConfig); + if (resultsField != null) { + builder.setResultsField(resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues); + } + return builder.build(); + } + + @Override + public InferenceConfig toConfig() { + return apply(RegressionConfig.EMPTY_PARAMS); + } + + @Override + public boolean isSupported(InferenceConfig inferenceConfig) { + return inferenceConfig instanceof RegressionConfig; + } + + boolean isNoop(RegressionConfig originalConfig) { + return (resultsField == null || originalConfig.getResultsField().equals(resultsField)) + && (numTopFeatureImportanceValues == null + || originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues); + } + + public static class Builder { + private String resultsField; + private Integer numTopFeatureImportanceValues; + + public Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public RegressionConfigUpdate build() { + return new RegressionConfigUpdate(resultsField, numTopFeatureImportanceValues); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedInferenceConfig.java new file mode 100644 index 00000000000..e6ea5129b31 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedInferenceConfig.java @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +public interface StrictlyParsedInferenceConfig extends InferenceConfig { +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java index a7a6d22ae3e..d4906173674 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java @@ -41,4 +41,14 @@ public final class NamedXContentObjectHelper { } return builder; } + + public static XContentBuilder writeNamedObject(XContentBuilder builder, + ToXContent.Params params, + String namedObjectName, + NamedXContentObject namedObject) throws IOException { + builder.startObject(namedObjectName); + builder.field(namedObject.getName(), namedObject, params); + builder.endObject(); + return builder; + } } diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json index 2a1131ca557..a77a0119e95 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json @@ -2,7 +2,7 @@ "order" : 0, "version" : ${xpack.ml.version.id}, "index_patterns" : [ - ".ml-inference-000001" + ".ml-inference-000002" ], "settings" : { "index" : { @@ -67,6 +67,9 @@ }, "default_field_map": { "enabled": false + }, + "inference_config": { + "enabled": false } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java index f34f8fc008f..01351f055f1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java @@ -11,22 +11,12 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; -import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; -import static org.elasticsearch.Version.getDeclaredVersions; +import static org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase.DEFAULT_BWC_VERSIONS; public abstract class AbstractBWCSerializationTestCase extends AbstractSerializingTestCase { - private static final List ALL_VERSIONS = Collections.unmodifiableList(getDeclaredVersions(Version.class)); - - public static List getAllBWCVersions(Version version) { - return ALL_VERSIONS.stream().filter(v -> v.before(version) && version.isCompatible(v)).collect(Collectors.toList()); - } - - private static final List DEFAULT_BWC_VERSIONS = getAllBWCVersions(Version.CURRENT); - /** * Returns the expected instance if serialized from the given version. */ diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCWireSerializationTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCWireSerializationTestCase.java new file mode 100644 index 00000000000..300370c3fc4 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCWireSerializationTestCase.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.elasticsearch.Version.getDeclaredVersions; + +public abstract class AbstractBWCWireSerializationTestCase extends AbstractWireSerializingTestCase { + + static final List ALL_VERSIONS = Collections.unmodifiableList(getDeclaredVersions(Version.class)); + + public static List getAllBWCVersions(Version version) { + return ALL_VERSIONS.stream().filter(v -> v.before(version) && version.isCompatible(v)).collect(Collectors.toList()); + } + + static final List DEFAULT_BWC_VERSIONS = getAllBWCVersions(Version.CURRENT); + + /** + * Returns the expected instance if serialized from the given version. + */ + protected abstract T mutateInstanceForVersion(T instance, Version version); + + /** + * The bwc versions to test serialization against + */ + protected List bwcVersions() { + return DEFAULT_BWC_VERSIONS; + } + + /** + * Test serialization and deserialization of the test instance across versions + */ + public final void testBwcSerialization() throws IOException { + for (int runs = 0; runs < NUMBER_OF_TEST_RUNS; runs++) { + T testInstance = createTestInstance(); + for (Version bwcVersion : bwcVersions()) { + assertBwcSerialization(testInstance, bwcVersion); + } + } + } + + /** + * Assert that instances copied at a particular version are equal. The version is useful + * for sanity checking the backwards compatibility of the wire. It isn't a substitute for + * real backwards compatibility tests but it is *so* much faster. + */ + protected final void assertBwcSerialization(T testInstance, Version version) throws IOException { + T deserializedInstance = copyWriteable(testInstance, getNamedWriteableRegistry(), instanceReader(), version); + assertOnBWCObject(deserializedInstance, mutateInstanceForVersion(testInstance, version), version); + } + + /** + * @param bwcSerializedObject The object deserialized from the previous version + * @param testInstance The original test instance + * @param version The version which serialized + */ + protected void assertOnBWCObject(T bwcSerializedObject, T testInstance, Version version) { + assertNotSame(version.toString(), bwcSerializedObject, testInstance); + assertEquals(version.toString(), bwcSerializedObject, testInstance); + assertEquals(version.toString(), bwcSerializedObject.hashCode(), testInstance.hashCode()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java index d743a579f85..80693c94be9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java @@ -5,14 +5,21 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests; import java.util.ArrayList; import java.util.List; @@ -22,25 +29,27 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class InternalInferModelActionRequestTests extends AbstractWireSerializingTestCase { +public class InternalInferModelActionRequestTests extends AbstractBWCWireSerializationTestCase { @Override + @SuppressWarnings("unchecked") protected Request createTestInstance() { return randomBoolean() ? new Request( randomAlphaOfLength(10), Stream.generate(InternalInferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), - randomInferenceConfig(), + randomInferenceConfigUpdate(), randomBoolean()) : new Request( randomAlphaOfLength(10), randomMap(), - randomInferenceConfig(), + randomInferenceConfigUpdate(), randomBoolean()); } - private static InferenceConfig randomInferenceConfig() { - return randomFrom(RegressionConfigTests.randomRegressionConfig(), ClassificationConfigTests.randomClassificationConfig()); + private static InferenceConfigUpdate randomInferenceConfigUpdate() { + return randomFrom(RegressionConfigUpdateTests.randomRegressionConfig(), + ClassificationConfigUpdateTests.randomClassificationConfig()); } private static Map randomMap() { @@ -60,4 +69,26 @@ public class InternalInferModelActionRequestTests extends AbstractWireSerializin entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); return new NamedWriteableRegistry(entries); } + + @Override + @SuppressWarnings("unchecked") + protected Request mutateInstanceForVersion(Request instance, Version version) { + if (version.before(Version.V_7_8_0)) { + InferenceConfigUpdate update = null; + if (instance.getUpdate() instanceof ClassificationConfigUpdate) { + update = ClassificationConfigUpdate.fromConfig( + ClassificationConfigTests.mutateForVersion((ClassificationConfig) instance.getUpdate().toConfig(), version)); + } + else if (instance.getUpdate() instanceof RegressionConfigUpdate) { + update = RegressionConfigUpdate.fromConfig( + RegressionConfigTests.mutateForVersion((RegressionConfig) instance.getUpdate().toConfig(), version)); + } + else { + fail("unknown update type " + instance.getUpdate().getName()); + } + return new Request(instance.getModelId(), instance.getObjectsToInfer(), update, instance.isPreviouslyLicensed()); + } + return instance; + } + } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index ace9cad261f..06e004850f7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -21,7 +21,9 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.license.License; import org.elasticsearch.search.SearchModule; -import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.MlStrings; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; @@ -46,7 +48,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; -public class TrainedModelConfigTests extends AbstractSerializingTestCase { +public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase { private boolean lenient; @@ -66,6 +68,8 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase randomAlphaOfLength(10)) .limit(randomIntBetween(1, 10)) - .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10)))); + .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))), + randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig())); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); assertThat(reference.utf8ToString(), containsString("\"compressed_definition\"")); @@ -182,7 +187,8 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase randomAlphaOfLength(10)) .limit(randomIntBetween(1, 10)) - .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10)))); + .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))), + randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig())); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); Map objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2(); @@ -311,4 +317,12 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase { +public class ClassificationConfigTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; public static ClassificationConfig randomClassificationConfig() { return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10), @@ -26,23 +25,17 @@ public class ClassificationConfigTests extends AbstractSerializingTestCase configMap = new HashMap<>(); - configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3); - configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo"); - configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar"); - configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2); - assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected)); + public static ClassificationConfig mutateForVersion(ClassificationConfig instance, Version version) { + ClassificationConfig.Builder builder = new ClassificationConfig.Builder(instance); + if (version.before(Version.V_7_7_0)) { + builder.setNumTopFeatureImportanceValues(0); + } + return builder.build(); } - public void testFromMapWithUnknownField() { - ElasticsearchException ex = expectThrows(ElasticsearchException.class, - () -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1))); - assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); } @Override @@ -57,6 +50,16 @@ public class ClassificationConfigTests extends AbstractSerializingTestCase { + + public static ClassificationConfigUpdate randomClassificationConfig() { + return new ClassificationConfigUpdate(randomBoolean() ? null : randomIntBetween(-1, 10), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomIntBetween(0, 10) + ); + } + + public void testFromMap() { + ClassificationConfigUpdate expected = new ClassificationConfigUpdate(null, null, null, null); + assertThat(ClassificationConfigUpdate.fromMap(Collections.emptyMap()), equalTo(expected)); + + expected = new ClassificationConfigUpdate(3, "foo", "bar", 2); + Map configMap = new HashMap<>(); + configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3); + configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo"); + configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar"); + configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2); + assertThat(ClassificationConfigUpdate.fromMap(configMap), equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> ClassificationConfigUpdate.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + + @Override + protected ClassificationConfigUpdate createTestInstance() { + return randomClassificationConfig(); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationConfigUpdate::new; + } + + @Override + protected ClassificationConfigUpdate doParseInstance(XContentParser parser) throws IOException { + return ClassificationConfigUpdate.fromXContentStrict(parser); + } + + @Override + protected ClassificationConfigUpdate mutateInstanceForVersion(ClassificationConfigUpdate instance, Version version) { + return instance; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java index ba8b8b62cd5..57dfa18520a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java @@ -5,37 +5,32 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; -import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import static org.hamcrest.Matchers.equalTo; - -public class RegressionConfigTests extends AbstractSerializingTestCase { +public class RegressionConfigTests extends AbstractBWCSerializationTestCase { + private boolean lenient; public static RegressionConfig randomRegressionConfig() { return new RegressionConfig(randomBoolean() ? null : randomAlphaOfLength(10)); } - public void testFromMap() { - RegressionConfig expected = new RegressionConfig("foo", 3); - Map config = new HashMap(){{ - put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo"); - put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3); - }}; - assertThat(RegressionConfig.fromMap(config), equalTo(expected)); + public static RegressionConfig mutateForVersion(RegressionConfig instance, Version version) { + RegressionConfig.Builder builder = new RegressionConfig.Builder(instance); + if (version.before(Version.V_7_7_0)) { + builder.setNumTopFeatureImportanceValues(0); + } + return builder.build(); } - public void testFromMapWithUnknownField() { - ElasticsearchException ex = expectThrows(ElasticsearchException.class, - () -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1))); - assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); } @Override @@ -50,6 +45,16 @@ public class RegressionConfigTests extends AbstractSerializingTestCase { + + public static RegressionConfigUpdate randomRegressionConfig() { + return new RegressionConfigUpdate(randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomIntBetween(0, 10)); + } + + public void testFromMap() { + RegressionConfigUpdate expected = new RegressionConfigUpdate("foo", 3); + Map config = new HashMap(){{ + put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo"); + put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3); + }}; + assertThat(RegressionConfigUpdate.fromMap(config), equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> RegressionConfigUpdate.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + + @Override + protected RegressionConfigUpdate createTestInstance() { + return randomRegressionConfig(); + } + + @Override + protected Writeable.Reader instanceReader() { + return RegressionConfigUpdate::new; + } + + @Override + protected RegressionConfigUpdate doParseInstance(XContentParser parser) throws IOException { + return RegressionConfigUpdate.fromXContentStrict(parser); + } + + @Override + protected RegressionConfigUpdate mutateInstanceForVersion(RegressionConfigUpdate instance, Version version) { + return instance; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index e6c368fefa4..4e2a5398d79 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -65,6 +65,10 @@ public class EnsembleTests extends AbstractSerializingTestCase { } public static Ensemble createRandom() { + return createRandom(randomFrom(TargetType.values())); + } + + public static Ensemble createRandom(TargetType targetType) { int numberOfFeatures = randomIntBetween(1, 10); List featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList()); int numberOfModels = randomIntBetween(1, 10); @@ -74,7 +78,6 @@ public class EnsembleTests extends AbstractSerializingTestCase { double[] weights = randomBoolean() ? null : Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).mapToDouble(Double::valueOf).toArray(); - TargetType targetType = randomFrom(TargetType.values()); List categoryLabels = null; if (randomBoolean() && targetType == TargetType.CLASSIFICATION) { categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index d801dec470e..dd418c57cfa 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -64,16 +64,20 @@ public class TreeTests extends AbstractSerializingTestCase { return createRandom(); } - public static Tree createRandom() { + public static Tree createRandom(TargetType targetType) { int numberOfFeatures = randomIntBetween(1, 10); List featureNames = new ArrayList<>(); for (int i = 0; i < numberOfFeatures; i++) { featureNames.add(randomAlphaOfLength(10)); } - return buildRandomTree(featureNames, 6); + return buildRandomTree(targetType, featureNames, 6); } - public static Tree buildRandomTree(List featureNames, int depth) { + public static Tree createRandom() { + return createRandom(randomFrom(TargetType.values())); + } + + public static Tree buildRandomTree(TargetType targetType, List featureNames, int depth) { Tree.Builder builder = Tree.builder(); int maxFeatureIndex = featureNames.size() - 1; builder.setFeatureNames(featureNames); @@ -96,7 +100,6 @@ public class TreeTests extends AbstractSerializingTestCase { } childNodes = nextNodes; } - TargetType targetType = randomFrom(TargetType.values()); List categoryLabels = null; if (randomBoolean() && targetType == TargetType.CLASSIFICATION) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); @@ -105,6 +108,10 @@ public class TreeTests extends AbstractSerializingTestCase { return builder.setTargetType(targetType).setClassificationLabels(categoryLabels).build(); } + public static Tree buildRandomTree(List featureNames, int depth) { + return buildRandomTree(randomFrom(TargetType.values()), featureNames, depth); + } + @Override protected Writeable.Reader instanceReader() { return Tree::new; diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 1863bb096ee..5eebec9559d 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -141,9 +141,11 @@ integTest.runner { 'ml/inference_crud/Test put ensemble with empty models', 'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index', 'ml/inference_crud/Test put model with empty input.field_names', + 'ml/inference_crud/Test PUT model where target type and inference config mismatch', 'ml/inference_processor/Test create processor with missing mandatory fields', 'ml/inference_processor/Test create and delete pipeline with inference processor', 'ml/inference_processor/Test create processor with deprecated fields', + 'ml/inference_processor/Test simulate', 'ml/inference_stats_crud/Test get stats given missing trained model', 'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false', 'ml/jobs_crud/Test cannot create job with existing categorizer state document', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 12a0f158b34..2f79a540e08 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -424,6 +424,7 @@ public class InferenceIngestIT extends ESRestTestCase { private static final String REGRESSION_CONFIG = "{" + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for regression\",\n" + + " \"inference_config\": {\"regression\": {}},\n" + " \"definition\": " + REGRESSION_DEFINITION + "}"; @@ -564,6 +565,7 @@ public class InferenceIngestIT extends ESRestTestCase { " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for classification\",\n" + " \"default_field_map\": {\"col_1_alias\": \"col1\"},\n" + + " \"inference_config\": {\"classification\": {}},\n" + " \"definition\": " + CLASSIFICATION_DEFINITION + "}"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index afa5f1c6dcb..db27ce4aeeb 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -12,6 +12,7 @@ import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.ml.inference.TrainedModelConfig; import org.elasticsearch.client.ml.inference.TrainedModelDefinition; import org.elasticsearch.client.ml.inference.TrainedModelInput; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; @@ -193,6 +194,7 @@ public class TrainedModelIT extends ESRestTestCase { .setTrainedModel(buildRegression()); TrainedModelConfig.builder() .setDefinition(definition) + .setInferenceConfig(new RegressionConfig()) .setModelId(modelId) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3"))) .build().toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index a32367a6766..40a74d7405a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.ml.inference.loadingservice.Model; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -48,11 +49,12 @@ public class TransportInternalInferModelAction extends HandledTransportAction listener) { Response.Builder responseBuilder = Response.builder(); - ActionListener getModelListener = ActionListener.wrap( + ActionListener> getModelListener = ActionListener.wrap( model -> { TypedChainTaskExecutor typedChainTaskExecutor = new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), @@ -62,7 +64,9 @@ public class TransportInternalInferModelAction extends HandledTransportAction true); request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(chainedTask -> - model.infer(stringObjectMap, request.getConfig(), chainedTask))); + // The InferenceConfigUpdate here is unchecked, initially. + // It gets checked when it is applied + model.infer(stringObjectMap, request.getUpdate(), chainedTask))); typedChainTaskExecutor.execute(ActionListener.wrap( inferenceResultsInterfaces -> diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index fc6d5b92f36..503393a836a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -93,6 +93,22 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction inferenceConfig; private final Map fieldMap; private final InferenceAuditor auditor; private volatile boolean previouslyLicensed; @@ -82,7 +85,7 @@ public class InferenceProcessor extends AbstractProcessor { String tag, String targetField, String modelId, - InferenceConfig inferenceConfig, + InferenceConfigUpdate inferenceConfig, Map fieldMap) { super(tag); this.client = ExceptionsHelper.requireNonNull(client, "client"); @@ -245,7 +248,8 @@ public class InferenceProcessor extends AbstractProcessor { LoggingDeprecationHandler.INSTANCE.usedDeprecatedName(null, () -> null, FIELD_MAPPINGS, FIELD_MAP); } } - InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); + InferenceConfigUpdate inferenceConfig = + inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); return new InferenceProcessor(client, auditor, @@ -262,7 +266,7 @@ public class InferenceProcessor extends AbstractProcessor { this.maxIngestProcessors = maxIngestProcessors; } - InferenceConfig inferenceConfigFromMap(Map inferenceConfig) { + InferenceConfigUpdate inferenceConfigFromMap(Map inferenceConfig) { ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); if (inferenceConfig.size() != 1) { throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", @@ -279,12 +283,12 @@ public class InferenceProcessor extends AbstractProcessor { if (inferenceConfig.containsKey(ClassificationConfig.NAME.getPreferredName())) { checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS); - ClassificationConfig config = ClassificationConfig.fromMap(valueMap); + ClassificationConfigUpdate config = ClassificationConfigUpdate.fromMap(valueMap); checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField()); return config; } else if (inferenceConfig.containsKey(RegressionConfig.NAME.getPreferredName())) { checkSupportedVersion(RegressionConfig.EMPTY_PARAMS); - RegressionConfig config = RegressionConfig.fromMap(valueMap); + RegressionConfigUpdate config = RegressionConfigUpdate.fromMap(valueMap); checkFieldUniqueness(config.getResultsField()); return config; } else { @@ -298,6 +302,9 @@ public class InferenceProcessor extends AbstractProcessor { Set duplicatedFieldNames = new HashSet<>(); Set currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES); for(String fieldName : fieldNames) { + if (fieldName == null) { + continue; + } if (currentFieldNames.contains(fieldName)) { duplicatedFieldNames.add(fieldName); } else { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index f77188853ad..fa436e0c53f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -10,6 +10,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; @@ -24,21 +25,24 @@ import java.util.Set; import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING; -public class LocalModel implements Model { +public class LocalModel implements Model { private final TrainedModelDefinition trainedModelDefinition; private final String modelId; private final Set fieldNames; private final Map defaultFieldMap; + private final T inferenceConfig; public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition, TrainedModelInput input, - Map defaultFieldMap) { + Map defaultFieldMap, + T modelInferenceConfig) { this.trainedModelDefinition = trainedModelDefinition; this.modelId = modelId; this.fieldNames = new HashSet<>(input.getFieldNames()); this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap); + this.inferenceConfig = modelInferenceConfig; } long ramBytesUsed() { @@ -65,7 +69,15 @@ public class LocalModel implements Model { } @Override - public void infer(Map fields, InferenceConfig config, ActionListener listener) { + public void infer(Map fields, InferenceConfigUpdate update, ActionListener listener) { + if (update.isSupported(this.inferenceConfig) == false) { + listener.onFailure(ExceptionsHelper.badRequestException( + "Model [{}] has inference config of type [{}] which is not supported by inference request of type [{}]", + this.modelId, + this.inferenceConfig.getName(), + update.getName())); + return; + } try { Model.mapFieldsIfNecessary(fields, defaultFieldMap); if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) { @@ -73,7 +85,7 @@ public class LocalModel implements Model { return; } - listener.onResponse(trainedModelDefinition.infer(fields, config)); + listener.onResponse(trainedModelDefinition.infer(fields, update.apply(inferenceConfig))); } catch (Exception e) { listener.onFailure(e); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java index f1df6908d63..c11d735e26a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -8,15 +8,16 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.MapHelper; import java.util.Map; -public interface Model { +public interface Model { String getResultsType(); - void infer(Map fields, InferenceConfig inferenceConfig, ActionListener listener); + void infer(Map fields, InferenceConfigUpdate inferenceConfig, ActionListener listener); String getModelId(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 71cfe710821..b8ecfa424f6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -27,6 +27,11 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -78,9 +83,9 @@ public class ModelLoadingService implements ClusterStateListener { Setting.Property.NodeScope); private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); - private final Cache localModelCache; + private final Cache> localModelCache; private final Set referencedModels = new HashSet<>(); - private final Map>> loadingListeners = new HashMap<>(); + private final Map>>> loadingListeners = new HashMap<>(); private final TrainedModelProvider provider; private final Set shouldNotAudit; private final ThreadPool threadPool; @@ -100,7 +105,7 @@ public class ModelLoadingService implements ClusterStateListener { this.auditor = auditor; this.shouldNotAudit = new HashSet<>(); this.namedXContentRegistry = namedXContentRegistry; - this.localModelCache = CacheBuilder.builder() + this.localModelCache = CacheBuilder.>builder() .setMaximumWeight(this.maxCacheSize.getBytes()) .weigher((id, localModel) -> localModel.ramBytesUsed()) .removalListener(this::cacheEvictionListener) @@ -126,8 +131,8 @@ public class ModelLoadingService implements ClusterStateListener { * @param modelId the model to get * @param modelActionListener the listener to alert when the model has been retrieved. */ - public void getModel(String modelId, ActionListener modelActionListener) { - LocalModel cachedModel = localModelCache.get(modelId); + public void getModel(String modelId, ActionListener> modelActionListener) { + LocalModel cachedModel = localModelCache.get(modelId); if (cachedModel != null) { modelActionListener.onResponse(cachedModel); logger.trace("[{}] loaded from cache", modelId); @@ -138,12 +143,18 @@ public class ModelLoadingService implements ClusterStateListener { // by a simulated pipeline logger.trace("[{}] not actively loading, eager loading without cache", modelId); provider.getTrainedModel(modelId, true, ActionListener.wrap( - trainedModelConfig -> - modelActionListener.onResponse(new LocalModel( + trainedModelConfig -> { + trainedModelConfig.ensureParsedDefinition(namedXContentRegistry); + InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? + inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) : + trainedModelConfig.getInferenceConfig(); + modelActionListener.onResponse(new LocalModel<>( trainedModelConfig.getModelId(), - trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(), + trainedModelConfig.getModelDefinition(), trainedModelConfig.getInput(), - trainedModelConfig.getDefaultFieldMap())), + trainedModelConfig.getDefaultFieldMap(), + inferenceConfig)); + }, modelActionListener::onFailure )); } else { @@ -156,9 +167,9 @@ public class ModelLoadingService implements ClusterStateListener { * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded * Returns false if the model is not loaded or actively being loaded */ - private boolean loadModelIfNecessary(String modelId, ActionListener modelActionListener) { + private boolean loadModelIfNecessary(String modelId, ActionListener> modelActionListener) { synchronized (loadingListeners) { - Model cachedModel = localModelCache.get(modelId); + Model cachedModel = localModelCache.get(modelId); if (cachedModel != null) { modelActionListener.onResponse(cachedModel); return true; @@ -197,12 +208,17 @@ public class ModelLoadingService implements ClusterStateListener { } private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) throws IOException { - Queue> listeners; - LocalModel loadedModel = new LocalModel( + Queue>> listeners; + trainedModelConfig.ensureParsedDefinition(namedXContentRegistry); + InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? + inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) : + trainedModelConfig.getInferenceConfig(); + LocalModel loadedModel = new LocalModel<>( trainedModelConfig.getModelId(), - trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(), + trainedModelConfig.getModelDefinition(), trainedModelConfig.getInput(), - trainedModelConfig.getDefaultFieldMap()); + trainedModelConfig.getDefaultFieldMap(), + inferenceConfig); synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); // If there is no loadingListener that means the loading was canceled and the listener was already notified as such @@ -213,13 +229,13 @@ public class ModelLoadingService implements ClusterStateListener { localModelCache.put(modelId, loadedModel); shouldNotAudit.remove(modelId); } // synchronized (loadingListeners) - for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + for (ActionListener> listener = listeners.poll(); listener != null; listener = listeners.poll()) { listener.onResponse(loadedModel); } } private void handleLoadFailure(String modelId, Exception failure) { - Queue> listeners; + Queue>> listeners; synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); if (listeners == null) { @@ -228,12 +244,12 @@ public class ModelLoadingService implements ClusterStateListener { } // synchronized (loadingListeners) // If we failed to load and there were listeners present, that means that this model is referenced by a processor // Alert the listeners to the failure - for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + for (ActionListener> listener = listeners.poll(); listener != null; listener = listeners.poll()) { listener.onFailure(failure); } } - private void cacheEvictionListener(RemovalNotification notification) { + private void cacheEvictionListener(RemovalNotification> notification) { if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { String msg = new ParameterizedMessage( "model cache entry evicted." + @@ -263,7 +279,7 @@ public class ModelLoadingService implements ClusterStateListener { return; } // The listeners still waiting for a model and we are canceling the load? - List>>> drainWithFailure = new ArrayList<>(); + List>>>> drainWithFailure = new ArrayList<>(); Set referencedModelsBeforeClusterState = null; Set loadingModelBeforeClusterState = null; Set removedModels = null; @@ -306,11 +322,11 @@ public class ModelLoadingService implements ClusterStateListener { referencedModels); } } - for (Tuple>> modelAndListeners : drainWithFailure) { + for (Tuple>>> modelAndListeners : drainWithFailure) { final String msg = new ParameterizedMessage( "Cancelling load of model [{}] as it is no longer referenced by a pipeline", modelAndListeners.v1()).getFormat(); - for (ActionListener listener : modelAndListeners.v2()) { + for (ActionListener> listener : modelAndListeners.v2()) { listener.onFailure(new ElasticsearchException(msg)); } } @@ -379,4 +395,14 @@ public class ModelLoadingService implements ClusterStateListener { return allReferencedModelKeys; } + private static InferenceConfig inferenceConfigFromTargetType(TargetType targetType) { + switch(targetType) { + case REGRESSION: + return RegressionConfig.EMPTY_PARAMS; + case CLASSIFICATION: + return ClassificationConfig.EMPTY_PARAMS; + default: + throw ExceptionsHelper.badRequestException("unsupported target type [{}]", targetType); + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java index 23e32174361..80620c8eed2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java @@ -45,6 +45,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; @@ -694,7 +695,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request( modelId, Collections.singletonList(Collections.emptyMap()), - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, false ), inferModelSuccess); InternalInferModelAction.Response response = inferModelSuccess.actionGet(); @@ -711,7 +712,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request( modelId, Collections.singletonList(Collections.emptyMap()), - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, false )).actionGet(); }); @@ -724,7 +725,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request( modelId, Collections.singletonList(Collections.emptyMap()), - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, true ), inferModelSuccess); response = inferModelSuccess.actionGet(); @@ -740,7 +741,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request( modelId, Collections.singletonList(Collections.emptyMap()), - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, false ), listener); assertThat(listener.actionGet().getInferenceResults(), is(not(empty()))); @@ -760,6 +761,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { .setModelId(modelId) .setDescription("test model for classification") .setInput(new TrainedModelInput(Arrays.asList("feature1"))) + .setInferenceConfig(RegressionConfig.EMPTY_PARAMS) .build(); client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index ee2b399ef2d..1d9c748f376 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; @@ -168,7 +169,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { extractedFieldList.add(new DocValueField("foo", Collections.emptySet())); extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet()))); extractedFieldList.add(new DocValueField("baz", Collections.emptySet())); - TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(); + TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; + TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); @@ -190,6 +192,11 @@ public class AnalyticsResultProcessorTests extends ESTestCase { assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar.keyword", "baz"))); assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed())); assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations())); + if (targetType.equals(TargetType.CLASSIFICATION)) { + assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification")); + } else { + assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression")); + } Map metadata = storedModel.getMetadata(); assertThat(metadata.size(), equalTo(1)); assertThat(metadata, hasKey("analytics_config")); @@ -213,7 +220,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { return null; }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); - TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(); + TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; + TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 9b6dfd734d0..81b842047ff 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -14,7 +14,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.junit.Before; @@ -53,7 +55,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", targetField, "classification_model", - ClassificationConfig.EMPTY_PARAMS, + ClassificationConfigUpdate.EMPTY_PARAMS, Collections.emptyMap()); Map source = new HashMap<>(); @@ -75,13 +77,14 @@ public class InferenceProcessorTests extends ESTestCase { @SuppressWarnings("unchecked") public void testMutateDocumentClassificationTopNClasses() { - ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null); + ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, null); + ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, null); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", "ml.my_processor", "classification_model", - classificationConfig, + classificationConfigUpdate, Collections.emptyMap()); Map source = new HashMap<>(); @@ -105,12 +108,13 @@ public class InferenceProcessorTests extends ESTestCase { public void testMutateDocumentClassificationFeatureInfluence() { ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2); + ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, 2); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", "ml.my_processor", "classification_model", - classificationConfig, + classificationConfigUpdate, Collections.emptyMap()); Map source = new HashMap<>(); @@ -145,12 +149,13 @@ public class InferenceProcessorTests extends ESTestCase { @SuppressWarnings("unchecked") public void testMutateDocumentClassificationTopNClassesWithSpecificField() { ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops"); + ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, "result", "tops", null); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", "ml.my_processor", "classification_model", - classificationConfig, + classificationConfigUpdate, Collections.emptyMap()); Map source = new HashMap<>(); @@ -174,12 +179,13 @@ public class InferenceProcessorTests extends ESTestCase { public void testMutateDocumentRegression() { RegressionConfig regressionConfig = new RegressionConfig("foo"); + RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", null); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", "ml.my_processor", "regression_model", - regressionConfig, + regressionConfigUpdate, Collections.emptyMap()); Map source = new HashMap<>(); @@ -196,12 +202,13 @@ public class InferenceProcessorTests extends ESTestCase { public void testMutateDocumentRegressionWithTopFetures() { RegressionConfig regressionConfig = new RegressionConfig("foo", 2); + RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", 2); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", "ml.my_processor", "regression_model", - regressionConfig, + regressionConfigUpdate, Collections.emptyMap()); Map source = new HashMap<>(); @@ -233,7 +240,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", "my_field", modelId, - new ClassificationConfig(topNClasses, null, null), + new ClassificationConfigUpdate(topNClasses, null, null, null), Collections.emptyMap()); Map source = new HashMap(){{ @@ -262,7 +269,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", "my_field", modelId, - new ClassificationConfig(topNClasses, null, null), + new ClassificationConfigUpdate(topNClasses, null, null, null), fieldMapping); Map source = new HashMap(5){{ @@ -298,7 +305,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", "my_field", modelId, - new ClassificationConfig(topNClasses, null, null), + new ClassificationConfigUpdate(topNClasses, null, null, null), fieldMapping); Map source = new HashMap(5){{ @@ -326,7 +333,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", targetField, "regression_model", - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, Collections.emptyMap()); Map source = new HashMap<>(); @@ -369,7 +376,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", "ml", "regression_model", - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, Collections.emptyMap()); Map source = new HashMap<>(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 1f91f8724a6..a12f2821d97 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -13,8 +13,11 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -31,7 +34,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutionException; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.closeTo; @@ -48,22 +50,23 @@ public class LocalModelTests extends ESTestCase { .setTrainedModel(buildClassification(false)) .build(); - Model model = new LocalModel(modelId, + Model model = new LocalModel<>(modelId, definition, new TrainedModelInput(inputFields), - Collections.singletonMap("field.foo", "field.foo.keyword")); + Collections.singletonMap("field.foo", "field.foo.keyword"), + ClassificationConfig.EMPTY_PARAMS); Map fields = new HashMap() {{ put("field.foo", 1.0); put("field.bar", 0.5); put("categorical", "dog"); }}; - SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0)); + SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null)); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), is("0")); ClassificationInferenceResults classificationResult = - (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1)); + (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); @@ -72,22 +75,29 @@ public class LocalModelTests extends ESTestCase { .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(true)) .build(); - model = new LocalModel(modelId, + model = new LocalModel<>(modelId, definition, new TrainedModelInput(inputFields), - Collections.singletonMap("field.foo", "field.foo.keyword")); - result = getSingleValue(model, fields, new ClassificationConfig(0)); + Collections.singletonMap("field.foo", "field.foo.keyword"), + ClassificationConfig.EMPTY_PARAMS); + result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null)); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), equalTo("not_to_be")); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, + fields, + new ClassificationConfigUpdate(1, null, null, null)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(2)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, + fields, + new ClassificationConfigUpdate(2, null, null, null)); assertThat(classificationResult.getTopClasses(), hasSize(2)); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(-1)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, + fields, + new ClassificationConfigUpdate(-1, null, null, null)); assertThat(classificationResult.getTopClasses(), hasSize(2)); } @@ -97,10 +107,11 @@ public class LocalModelTests extends ESTestCase { .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildRegression()) .build(); - Model model = new LocalModel("regression_model", + Model model = new LocalModel<>("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields), - Collections.singletonMap("bar", "bar.keyword")); + Collections.singletonMap("bar", "bar.keyword"), + RegressionConfig.EMPTY_PARAMS); Map fields = new HashMap() {{ put("foo", 1.0); @@ -108,14 +119,8 @@ public class LocalModelTests extends ESTestCase { put("categorical", "dog"); }}; - SingleValueInferenceResults results = getSingleValue(model, fields, RegressionConfig.EMPTY_PARAMS); + SingleValueInferenceResults results = getSingleValue(model, fields, RegressionConfigUpdate.EMPTY_PARAMS); assertThat(results.value(), equalTo(1.3)); - - PlainActionFuture failedFuture = new PlainActionFuture<>(); - model.infer(fields, new ClassificationConfig(2), failedFuture); - ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); - assertThat(ex.getCause().getMessage(), - equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]")); } public void testAllFieldsMissing() throws Exception { @@ -124,7 +129,12 @@ public class LocalModelTests extends ESTestCase { .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildRegression()) .build(); - Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields), null); + Model model = new LocalModel<>( + "regression_model", + trainedModelDefinition, + new TrainedModelInput(inputFields), + null, + RegressionConfig.EMPTY_PARAMS); Map fields = new HashMap() {{ put("something", 1.0); @@ -132,18 +142,21 @@ public class LocalModelTests extends ESTestCase { put("baz", "dog"); }}; - WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfig.EMPTY_PARAMS); + WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfigUpdate.EMPTY_PARAMS); assertThat(results.getWarning(), equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model"))); } - private static SingleValueInferenceResults getSingleValue(Model model, - Map fields, - InferenceConfig config) throws Exception { + private static SingleValueInferenceResults getSingleValue(Model model, + Map fields, + InferenceConfigUpdate config) + throws Exception { return (SingleValueInferenceResults)getInferenceResult(model, fields, config); } - private static InferenceResults getInferenceResult(Model model, Map fields, InferenceConfig config) throws Exception { + private static InferenceResults getInferenceResult(Model model, + Map fields, + InferenceConfigUpdate config) throws Exception { PlainActionFuture future = new PlainActionFuture<>(); model.infer(fields, config, future); return future.get(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index b9e33815fcb..50adfe6d861 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -36,6 +36,8 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -111,7 +113,7 @@ public class ModelLoadingServiceTests extends ESTestCase { String[] modelIds = new String[]{model1, model2, model3}; for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -124,7 +126,7 @@ public class ModelLoadingServiceTests extends ESTestCase { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -164,7 +166,7 @@ public class ModelLoadingServiceTests extends ESTestCase { for(int i = 0; i < 10; i++) { // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) String model = modelIds[i%2]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -176,7 +178,7 @@ public class ModelLoadingServiceTests extends ESTestCase { // Load model 3, should invalidate 1 for(int i = 0; i < 10; i++) { - PlainActionFuture future3 = new PlainActionFuture<>(); + PlainActionFuture> future3 = new PlainActionFuture<>(); modelLoadingService.getModel(model3, future3); assertThat(future3.get(), is(not(nullValue()))); } @@ -184,7 +186,7 @@ public class ModelLoadingServiceTests extends ESTestCase { // Load model 1, should invalidate 2 for(int i = 0; i < 10; i++) { - PlainActionFuture future1 = new PlainActionFuture<>(); + PlainActionFuture> future1 = new PlainActionFuture<>(); modelLoadingService.getModel(model1, future1); assertThat(future1.get(), is(not(nullValue()))); } @@ -192,7 +194,7 @@ public class ModelLoadingServiceTests extends ESTestCase { // Load model 2, should invalidate 3 for(int i = 0; i < 10; i++) { - PlainActionFuture future2 = new PlainActionFuture<>(); + PlainActionFuture> future2 = new PlainActionFuture<>(); modelLoadingService.getModel(model2, future2); assertThat(future2.get(), is(not(nullValue()))); } @@ -204,7 +206,7 @@ public class ModelLoadingServiceTests extends ESTestCase { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -230,7 +232,7 @@ public class ModelLoadingServiceTests extends ESTestCase { modelLoadingService.clusterChanged(ingestChangedEvent(false, model1)); for(int i = 0; i < 10; i++) { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model1, future); assertThat(future.get(), is(not(nullValue()))); } @@ -250,7 +252,7 @@ public class ModelLoadingServiceTests extends ESTestCase { Settings.EMPTY); modelLoadingService.clusterChanged(ingestChangedEvent(model)); - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); try { @@ -274,7 +276,7 @@ public class ModelLoadingServiceTests extends ESTestCase { NamedXContentRegistry.EMPTY, Settings.EMPTY); - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); try { future.get(); @@ -296,7 +298,7 @@ public class ModelLoadingServiceTests extends ESTestCase { Settings.EMPTY); for(int i = 0; i < 3; i++) { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -310,6 +312,7 @@ public class ModelLoadingServiceTests extends ESTestCase { when(definition.ramBytesUsed()).thenReturn(size); TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); when(trainedModelConfig.getModelDefinition()).thenReturn(definition); + when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS); when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz"))); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 11037925735..1749d00f821 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -20,8 +20,8 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -146,20 +146,20 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { // Test regression InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId1, toInfer, - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, true); InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.3, 1.25)); - request = new InternalInferModelAction.Request(modelId1, toInfer2, RegressionConfig.EMPTY_PARAMS, true); + request = new InternalInferModelAction.Request(modelId1, toInfer2, RegressionConfigUpdate.EMPTY_PARAMS, true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.65, 1.55)); // Test classification - request = new InternalInferModelAction.Request(modelId2, toInfer, ClassificationConfig.EMPTY_PARAMS, true); + request = new InternalInferModelAction.Request(modelId2, toInfer, ClassificationConfigUpdate.EMPTY_PARAMS, true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults() .stream() @@ -168,7 +168,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { contains("not_to_be", "to_be")); // Get top classes - request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2, null, null), true); + request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null), true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); ClassificationInferenceResults classificationInferenceResults = @@ -187,7 +187,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); // Test that top classes restrict the number returned - request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1, null, null), true); + request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfigUpdate(1, null, null, null), true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); @@ -262,7 +262,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { // Test regression InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId, toInfer, - ClassificationConfig.EMPTY_PARAMS, + ClassificationConfigUpdate.EMPTY_PARAMS, true); InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults() @@ -271,7 +271,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { .collect(Collectors.toList()), contains("option_0", "option_2")); - request = new InternalInferModelAction.Request(modelId, toInfer2, ClassificationConfig.EMPTY_PARAMS, true); + request = new InternalInferModelAction.Request(modelId, toInfer2, ClassificationConfigUpdate.EMPTY_PARAMS, true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults() .stream() @@ -281,7 +281,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { // Get top classes - request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfig(3, null, null), true); + request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null), true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); ClassificationInferenceResults classificationInferenceResults = @@ -303,7 +303,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { InternalInferModelAction.Request request = new InternalInferModelAction.Request( model, Collections.emptyList(), - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, true); try { client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); @@ -344,7 +344,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { InternalInferModelAction.Request request = new InternalInferModelAction.Request( modelId, toInferMissingField, - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, true); try { InferenceResults result = diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 8c1f1e581de..dd49427f2d8 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -11,6 +11,7 @@ setup: "description": "empty model for tests", "tags": ["regression", "tag1"], "input": {"field_names": ["field1", "field2"]}, + "inference_config": {"regression": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -35,6 +36,7 @@ setup: "description": "empty model for tests", "input": {"field_names": ["field1", "field2"]}, "tags": ["regression", "tag2"], + "inference_config": {"regression": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -58,6 +60,7 @@ setup: "description": "empty model for tests", "input": {"field_names": ["field1", "field2"]}, "tags": ["classification", "tag2"], + "inference_config": {"classification": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -83,6 +86,7 @@ setup: "description": "empty model for tests", "input": {"field_names": ["field1", "field2"]}, "tags": ["classification", "tag3"], + "inference_config": {"classification": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -108,6 +112,7 @@ setup: "description": "empty model for tests", "input": {"field_names": ["field1", "field2"]}, "tags": ["classification", "tag3"], + "inference_config": {"classification": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -343,6 +348,7 @@ setup: "input": { "field_names": "fieldy_mc_fieldname" }, + "inference_config": {"regression": {}}, "definition": { "trained_model": { "ensemble": { @@ -377,6 +383,7 @@ setup: "input": { "field_names": "fieldy_mc_fieldname" }, + "inference_config": {"regression": {}}, "definition": { "trained_model": { "ensemble": { @@ -397,6 +404,7 @@ setup: "input": { "field_names": "fieldy_mc_fieldname" }, + "inference_config": {"regression": {}}, "definition": { "trained_model": { "ensemble": { @@ -434,6 +442,7 @@ setup: "input": { "field_names": [] }, + "inference_config": {"regression": {}}, "definition": { "trained_model": { "ensemble": { @@ -469,6 +478,7 @@ setup: { "description": "model for tests", "input": {"field_names": ["field1", "field2"]}, + "inference_config": {"regression": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -510,3 +520,47 @@ setup: - is_true: create_time - is_true: version - is_true: estimated_heap_memory_usage_bytes +--- +"Test PUT model where target type and inference config mismatch": + - do: + catch: /Model \[my-regression-model\] inference config type \[classification\] does not support definition target type \[regression\]/ + ml.put_trained_model: + model_id: my-regression-model + body: > + { + "description": "model for tests", + "input": {"field_names": ["field1", "field2"]}, + "inference_config": {"classification": {}}, + "definition": { + "preprocessors": [], + "trained_model": { + "ensemble": { + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + } + ] + } + } + } + } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_processor.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_processor.yml index 9c043eee3cd..08b6603b4f0 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_processor.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_processor.yml @@ -13,6 +13,7 @@ setup: "description": "empty model for tests", "tags": ["regression", "tag1"], "input": {"field_names": ["field1", "field2"]}, + "inference_config": { "regression": {"results_field": "my_regression"}}, "definition": { "preprocessors": [], "trained_model": { @@ -112,3 +113,42 @@ setup: - 'Deprecated field [field_mappings] used, expected [field_map] instead' ingest.delete_pipeline: id: "regression-model-pipeline" +--- +"Test simulate": + - do: + ingest.simulate: + body: > + { + "pipeline": { + "processors": [ + { + "inference" : { + "model_id" : "a-perfect-regression-model", + "inference_config": {"regression": {}}, + "target_field": "regression_field", + "field_map": {} + } + } + ]}, + "docs": [{"_source": {"field1": 1, "field2": 2}}] + } + - match: { docs.0.doc._source.regression_field.my_regression: 42.0 } + + - do: + ingest.simulate: + body: > + { + "pipeline": { + "processors": [ + { + "inference" : { + "model_id" : "a-perfect-regression-model", + "inference_config": {"regression": {"results_field": "value"}}, + "target_field": "regression_field", + "field_map": {} + } + } + ]}, + "docs": [{"_source": {"field1": 1, "field2": 2}}] + } + - match: { docs.0.doc._source.regression_field.value: 42.0 } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml index 40d0f6ba01a..805a50d04a5 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml @@ -6,12 +6,13 @@ setup: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: id: trained_model_config-a-unused-regression-model1-0 - index: .ml-inference-000001 + index: .ml-inference-000002 body: > { "model_id": "a-unused-regression-model1", "created_by": "ml_tests", "version": "8.0.0", + "inference_config": {"regression": {}}, "description": "empty model for tests", "create_time": 0, "doc_type": "trained_model_config" @@ -22,12 +23,13 @@ setup: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: id: trained_model_config-a-unused-regression-model-0 - index: .ml-inference-000001 + index: .ml-inference-000002 body: > { "model_id": "a-unused-regression-model", "created_by": "ml_tests", "version": "8.0.0", + "inference_config": {"regression": {}}, "description": "empty model for tests", "create_time": 0, "doc_type": "trained_model_config" @@ -37,12 +39,13 @@ setup: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: id: trained_model_config-a-used-regression-model-0 - index: .ml-inference-000001 + index: .ml-inference-000002 body: > { "model_id": "a-used-regression-model", "created_by": "ml_tests", "version": "8.0.0", + "inference_config": {"regression": {}}, "description": "empty model for tests", "create_time": 0, "doc_type": "trained_model_config" diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml index 888eaad4403..0d7ecee9571 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml @@ -13,6 +13,7 @@ setup: { "description": "empty model for tests", "tags": ["regression", "tag1"], + "inference_config": {"regression":{}}, "input": {"field_names": ["field1", "field2"]}, "definition": { "preprocessors": [], @@ -36,6 +37,7 @@ setup: body: > { "description": "empty model for tests", + "inference_config": {"regression":{}}, "input": {"field_names": ["field1", "field2"]}, "definition": { "preprocessors": [],