From 4a1610265f05ac6a1a52f724d14b3c2141081176 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 2 Apr 2020 12:25:10 -0400 Subject: [PATCH] [7.x] [ML] add new inference_config field to trained model config (#54421) (#54647) * [ML] add new inference_config field to trained model config (#54421) A new field called `inference_config` is now added to the trained model config object. This new field allows for default inference settings from analytics or some external model builder. The inference processor can still override whatever is set as the default in the trained model config. * fixing for backport --- .../MlInferenceNamedXContentProvider.java | 11 + .../inference/NamedXContentObjectHelper.java | 10 + .../ml/inference/TrainedModelConfig.java | 30 ++- .../trainedmodel/ClassificationConfig.java | 129 ++++++++++ .../trainedmodel/InferenceConfig.java | 26 ++ .../trainedmodel/RegressionConfig.java | 104 ++++++++ .../client/MachineLearningIT.java | 4 + .../client/RestHighLevelClientTests.java | 9 +- .../MlClientDocumentationIT.java | 4 + .../ml/inference/TrainedModelConfigTests.java | 11 +- .../ClassificationConfigTests.java | 51 ++++ .../trainedmodel/RegressionConfigTests.java | 50 ++++ .../high-level/ml/put-trained-model.asciidoc | 2 + .../ingest/processors/inference.asciidoc | 44 ++-- .../df-analytics/apis/put-inference.asciidoc | 79 +++++- docs/reference/ml/ml-shared.asciidoc | 33 +++ .../ml/action/InternalInferModelAction.java | 47 +++- .../dataframe/analyses/BoostedTreeParams.java | 24 ++ .../MlInferenceNamedXContentProvider.java | 31 ++- .../core/ml/inference/TrainedModelConfig.java | 46 +++- .../persistence/InferenceIndexConstants.java | 14 +- .../trainedmodel/ClassificationConfig.java | 101 +++++--- .../ClassificationConfigUpdate.java | 235 ++++++++++++++++++ .../trainedmodel/InferenceConfigUpdate.java | 19 ++ .../LenientlyParsedInferenceConfig.java | 9 + .../trainedmodel/RegressionConfig.java | 77 +++--- .../trainedmodel/RegressionConfigUpdate.java | 178 +++++++++++++ .../StrictlyParsedInferenceConfig.java | 9 + .../ml/utils/NamedXContentObjectHelper.java | 10 + .../core/ml/inference_index_template.json | 5 +- .../ml/AbstractBWCSerializationTestCase.java | 12 +- .../AbstractBWCWireSerializationTestCase.java | 73 ++++++ .../InternalInferModelActionRequestTests.java | 45 +++- .../ml/inference/TrainedModelConfigTests.java | 22 +- .../TrainedModelDefinitionTests.java | 9 +- .../ClassificationConfigTests.java | 49 ++-- .../ClassificationConfigUpdateTests.java | 69 +++++ .../trainedmodel/RegressionConfigTests.java | 45 ++-- .../RegressionConfigUpdateTests.java | 62 +++++ .../trainedmodel/ensemble/EnsembleTests.java | 5 +- .../trainedmodel/tree/TreeTests.java | 15 +- .../ml/qa/ml-with-security/build.gradle | 2 + .../ml/integration/InferenceIngestIT.java | 2 + .../xpack/ml/integration/TrainedModelIT.java | 2 + .../TransportInternalInferModelAction.java | 8 +- .../TransportPutTrainedModelAction.java | 16 ++ .../process/AnalyticsResultProcessor.java | 29 +++ .../inference/ingest/InferenceProcessor.java | 19 +- .../inference/loadingservice/LocalModel.java | 20 +- .../ml/inference/loadingservice/Model.java | 5 +- .../loadingservice/ModelLoadingService.java | 70 ++++-- .../MachineLearningLicensingTests.java | 10 +- .../AnalyticsResultProcessorTests.java | 12 +- .../ingest/InferenceProcessorTests.java | 31 ++- .../loadingservice/LocalModelTests.java | 65 +++-- .../ModelLoadingServiceTests.java | 25 +- .../integration/ModelInferenceActionIT.java | 24 +- .../rest-api-spec/test/ml/inference_crud.yml | 54 ++++ .../test/ml/inference_processor.yml | 40 +++ .../test/ml/inference_stats_crud.yml | 9 +- .../test/ml/trained_model_cat_apis.yml | 2 + 61 files changed, 1950 insertions(+), 303 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfig.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/InferenceConfig.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfig.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfigTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfigTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigUpdate.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedInferenceConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedInferenceConfig.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCWireSerializationTestCase.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdateTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdateTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java index 0be0e8f6c58..d8a00f1d8f6 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java @@ -19,6 +19,9 @@ package org.elasticsearch.client.ml.inference; import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding; +import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression; @@ -61,6 +64,14 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider { new ParseField(LangIdentNeuralNetwork.NAME), LangIdentNeuralNetwork::fromXContent)); + // Inference Config + namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, + ClassificationConfig.NAME, + ClassificationConfig::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, + RegressionConfig.NAME, + RegressionConfig::fromXContent)); + // Aggregating output namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, new ParseField(WeightedMode.NAME), diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java index 1795f5da495..10a739280be 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java @@ -54,4 +54,14 @@ public final class NamedXContentObjectHelper { } return builder; } + + public static XContentBuilder writeNamedObject(XContentBuilder builder, + ToXContent.Params params, + String namedObjectName, + NamedXContentObject namedObject) throws IOException { + builder.startObject(namedObjectName); + builder.field(namedObject.getName(), namedObject, params); + builder.endObject(); + return builder; + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index d8749c0a6cd..9d821752240 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -20,6 +20,7 @@ package org.elasticsearch.client.ml.inference; import org.elasticsearch.Version; import org.elasticsearch.client.common.TimeUtil; +import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.ByteSizeValue; @@ -36,6 +37,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import static org.elasticsearch.client.ml.inference.NamedXContentObjectHelper.writeNamedObject; + public class TrainedModelConfig implements ToXContentObject { public static final String NAME = "trained_model_config"; @@ -54,6 +57,7 @@ public class TrainedModelConfig implements ToXContentObject { public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations"); public static final ParseField LICENSE_LEVEL = new ParseField("license_level"); public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map"); + public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config"); public static final ObjectParser PARSER = new ObjectParser<>(NAME, true, @@ -78,6 +82,9 @@ public class TrainedModelConfig implements ToXContentObject { PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS); PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL); PARSER.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP); + PARSER.declareNamedObject(TrainedModelConfig.Builder::setInferenceConfig, + (p, c, n) -> p.namedObject(InferenceConfig.class, n, null), + INFERENCE_CONFIG); } public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException { @@ -98,6 +105,7 @@ public class TrainedModelConfig implements ToXContentObject { private final Long estimatedOperations; private final String licenseLevel; private final Map defaultFieldMap; + private final InferenceConfig inferenceConfig; TrainedModelConfig(String modelId, String createdBy, @@ -112,7 +120,8 @@ public class TrainedModelConfig implements ToXContentObject { Long estimatedHeapMemory, Long estimatedOperations, String licenseLevel, - Map defaultFieldMap) { + Map defaultFieldMap, + InferenceConfig inferenceConfig) { this.modelId = modelId; this.createdBy = createdBy; this.version = version; @@ -127,6 +136,7 @@ public class TrainedModelConfig implements ToXContentObject { this.estimatedOperations = estimatedOperations; this.licenseLevel = licenseLevel; this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap); + this.inferenceConfig = inferenceConfig; } public String getModelId() { @@ -189,6 +199,10 @@ public class TrainedModelConfig implements ToXContentObject { return defaultFieldMap; } + public InferenceConfig getInferenceConfig() { + return inferenceConfig; + } + public static Builder builder() { return new Builder(); } @@ -238,6 +252,9 @@ public class TrainedModelConfig implements ToXContentObject { if (defaultFieldMap != null) { builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap); } + if (inferenceConfig != null) { + writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig); + } builder.endObject(); return builder; } @@ -265,6 +282,7 @@ public class TrainedModelConfig implements ToXContentObject { Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(licenseLevel, that.licenseLevel) && Objects.equals(defaultFieldMap, that.defaultFieldMap) && + Objects.equals(inferenceConfig, that.inferenceConfig) && Objects.equals(metadata, that.metadata); } @@ -283,6 +301,7 @@ public class TrainedModelConfig implements ToXContentObject { metadata, licenseLevel, input, + inferenceConfig, defaultFieldMap); } @@ -303,6 +322,7 @@ public class TrainedModelConfig implements ToXContentObject { private Long estimatedOperations; private String licenseLevel; private Map defaultFieldMap; + private InferenceConfig inferenceConfig; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -387,6 +407,11 @@ public class TrainedModelConfig implements ToXContentObject { return this; } + public Builder setInferenceConfig(InferenceConfig inferenceConfig) { + this.inferenceConfig = inferenceConfig; + return this; + } + public TrainedModelConfig build() { return new TrainedModelConfig( modelId, @@ -402,7 +427,8 @@ public class TrainedModelConfig implements ToXContentObject { estimatedHeapMemory, estimatedOperations, licenseLevel, - defaultFieldMap); + defaultFieldMap, + inferenceConfig); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfig.java new file mode 100644 index 00000000000..87b37434075 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfig.java @@ -0,0 +1,129 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class ClassificationConfig implements InferenceConfig { + + public static final ParseField NAME = new ParseField("classification"); + + public static final ParseField RESULTS_FIELD = new ParseField("results_field"); + public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field"); + public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); + + + private final Integer numTopClasses; + private final String topClassesResultsField; + private final String resultsField; + private final Integer numTopFeatureImportanceValues; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new ClassificationConfig( + (Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3])); + + static { + PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); + PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); + PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD); + PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); + } + + public static ClassificationConfig fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public ClassificationConfig() { + this(null, null, null, null); + } + + public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) { + this.numTopClasses = numTopClasses; + this.topClassesResultsField = topClassesResultsField; + this.resultsField = resultsField; + this.numTopFeatureImportanceValues = featureImportance; + } + + public Integer getNumTopClasses() { + return numTopClasses; + } + + public String getTopClassesResultsField() { + return topClassesResultsField; + } + + public String getResultsField() { + return resultsField; + } + + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassificationConfig that = (ClassificationConfig) o; + return Objects.equals(numTopClasses, that.numTopClasses) + && Objects.equals(topClassesResultsField, that.topClassesResultsField) + && Objects.equals(resultsField, that.resultsField) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); + } + + @Override + public int hashCode() { + return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (numTopClasses != null) { + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + } + if (topClassesResultsField != null) { + builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField); + } + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } + builder.endObject(); + return builder; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/InferenceConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/InferenceConfig.java new file mode 100644 index 00000000000..0e3e911cd2c --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/InferenceConfig.java @@ -0,0 +1,26 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.client.ml.inference.NamedXContentObject; + + +public interface InferenceConfig extends NamedXContentObject { + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfig.java new file mode 100644 index 00000000000..dddd9c5ab48 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfig.java @@ -0,0 +1,104 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class RegressionConfig implements InferenceConfig { + + public static final ParseField NAME = new ParseField("regression"); + public static final ParseField RESULTS_FIELD = new ParseField("results_field"); + public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME.getPreferredName(), + true, + args -> new RegressionConfig((String) args[0], (Integer)args[1])); + + static { + PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); + PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); + } + + public static RegressionConfig fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final String resultsField; + private final Integer numTopFeatureImportanceValues; + + public RegressionConfig() { + this(null, null); + } + + public RegressionConfig(String resultsField, Integer numTopFeatureImportanceValues) { + this.resultsField = resultsField; + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + } + + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + + public String getResultsField() { + return resultsField; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RegressionConfig that = (RegressionConfig)o; + return Objects.equals(this.resultsField, that.resultsField) + && Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); + } + + @Override + public int hashCode() { + return Objects.hash(resultsField, numTopFeatureImportanceValues); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index d1b8188b100..08ce2d5f5cc 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -158,6 +158,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition; import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.client.ml.inference.TrainedModelInput; import org.elasticsearch.client.ml.inference.TrainedModelStats; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; import org.elasticsearch.client.ml.job.config.AnalysisConfig; @@ -2272,6 +2273,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() .setDefinition(definition) .setModelId(modelId) + .setInferenceConfig(new RegressionConfig()) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) .setDescription("test model") .build(); @@ -2285,6 +2287,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { trainedModelConfig = TrainedModelConfig.builder() .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) .setModelId(modelIdCompressed) + .setInferenceConfig(new RegressionConfig()) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) .setDescription("test model") .build(); @@ -2591,6 +2594,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() .setDefinition(definition) .setModelId(modelId) + .setInferenceConfig(new RegressionConfig()) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) .setDescription("test model") .build(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index a93da23328c..e7575b737b8 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -75,6 +75,8 @@ import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding; import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding; +import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode; @@ -699,7 +701,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(62, namedXContents.size()); + assertEquals(64, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -709,7 +711,7 @@ public class RestHighLevelClientTests extends ESTestCase { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 13, categories.size()); + assertEquals("Had: " + categories, 14, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -783,6 +785,9 @@ public class RestHighLevelClientTests extends ESTestCase { assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class)); assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME)); + assertEquals(Integer.valueOf(2), + categories.get(org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig.class)); + assertThat(names, hasItems(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName())); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index ac43ac07a3e..b11804a6884 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -174,6 +174,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition; import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.client.ml.inference.TrainedModelInput; import org.elasticsearch.client.ml.inference.TrainedModelStats; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; @@ -3646,11 +3647,13 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setDescription("test model") // <5> .setMetadata(new HashMap<>()) // <6> .setTags("my_regression_models") // <7> + .setInferenceConfig(new RegressionConfig("value", 0)) // <8> .build(); // end::put-trained-model-config trainedModelConfig = TrainedModelConfig.builder() .setDefinition(definition) + .setInferenceConfig(new RegressionConfig(null, null)) .setModelId("my-new-trained-model") .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) .setDescription("test model") @@ -4234,6 +4237,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder() .setDefinition(definition) .setModelId(modelId) + .setInferenceConfig(new RegressionConfig("value", 0)) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4"))) .setDescription("test model") .build(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index be4b0443e59..429ab53dcc9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -19,6 +19,9 @@ package org.elasticsearch.client.ml.inference; import org.elasticsearch.Version; +import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfigTests; +import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; @@ -39,13 +42,14 @@ import java.util.stream.Stream; public class TrainedModelConfigTests extends AbstractXContentTestCase { public static TrainedModelConfig createTestTrainedModelConfig() { + TargetType targetType = randomFrom(TargetType.values()); return new TrainedModelConfig( randomAlphaOfLength(10), randomAlphaOfLength(10), Version.CURRENT, randomBoolean() ? null : randomAlphaOfLength(100), Instant.ofEpochMilli(randomNonNegativeLong()), - randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(), + randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder(targetType).build(), randomBoolean() ? null : randomAlphaOfLength(100), randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), @@ -57,7 +61,10 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase randomAlphaOfLength(10)) .limit(randomIntBetween(1, 10)) - .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10)))); + .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))), + targetType.equals(TargetType.CLASSIFICATION) ? + ClassificationConfigTests.randomClassificationConfig() : + RegressionConfigTests.randomRegressionConfig()); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfigTests.java new file mode 100644 index 00000000000..3dab960fcc3 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ClassificationConfigTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class ClassificationConfigTests extends AbstractXContentTestCase { + + public static ClassificationConfig randomClassificationConfig() { + return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomIntBetween(0, 10) + ); + } + + @Override + protected ClassificationConfig createTestInstance() { + return randomClassificationConfig(); + } + + @Override + protected ClassificationConfig doParseInstance(XContentParser parser) throws IOException { + return ClassificationConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfigTests.java new file mode 100644 index 00000000000..4057b961b96 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/RegressionConfigTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference.trainedmodel; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class RegressionConfigTests extends AbstractXContentTestCase { + + public static RegressionConfig randomRegressionConfig() { + return new RegressionConfig( + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomIntBetween(0, 10)); + } + + + @Override + protected RegressionConfig createTestInstance() { + return randomRegressionConfig(); + } + + @Override + protected RegressionConfig doParseInstance(XContentParser parser) throws IOException { + return RegressionConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc index 6a0f96a78b9..7637b739940 100644 --- a/docs/java-rest/high-level/ml/put-trained-model.asciidoc +++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc @@ -39,6 +39,8 @@ include-tagged::{doc-tests-file}[{api}-config] <5> Optionally, a human-readable description <6> Optionally, an object map contain metadata about the model <7> Optionally, an array of tags to organize the model +<8> The default inference config to use with the model. Must match the underlying + definition target_type. include::../execution.asciidoc[] diff --git a/docs/reference/ingest/processors/inference.asciidoc b/docs/reference/ingest/processors/inference.asciidoc index 6066a78feaa..4f5a1d76c7f 100644 --- a/docs/reference/ingest/processors/inference.asciidoc +++ b/docs/reference/ingest/processors/inference.asciidoc @@ -38,44 +38,38 @@ include::common-options.asciidoc[] [[inference-processor-regression-opt]] ==== {regression-cap} configuration options +Regression configuration for inference. + `results_field`:: (Optional, string) -Specifies the field to which the inference prediction is written. Defaults to -`predicted_value`. +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-results-field] -`num_top_feature_importance_values`:::: +`num_top_feature_importance_values`:: (Optional, integer) -Specifies the maximum number of -{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature -importance] values per document. By default, it is zero and no feature importance -calculation occurs. +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values] [discrete] [[inference-processor-classification-opt]] ==== {classification-cap} configuration options -`results_field`:: -(Optional, string) -The field that is added to incoming documents to contain the inference prediction. Defaults to -`predicted_value`. +Classification configuration for inference. `num_top_classes`:: -(Optional, integer) -Specifies the number of top class predictions to return. Defaults to 0. +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes] + +`num_top_feature_importance_values`:: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values] + +`results_field`:: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-results-field] `top_classes_results_field`:: -(Optional, string) -Specifies the field to which the top classes are written. Defaults to -`top_classes`. - -`num_top_feature_importance_values`:::: -(Optional, integer) -Specifies the maximum number of -{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature -importance] values per document. By default, it is zero and no feature -importance calculation occurs. - +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field] [discrete] [[inference-processor-config-example]] @@ -178,4 +172,4 @@ You can also specify a target field as follows: // NOTCONSOLE In this case, {feat-imp} is exposed in the -`my_field.foo.feature_importance` field. \ No newline at end of file +`my_field.foo.feature_importance` field. diff --git a/docs/reference/ml/df-analytics/apis/put-inference.asciidoc b/docs/reference/ml/df-analytics/apis/put-inference.asciidoc index 44ccfc445b8..f49b288c306 100644 --- a/docs/reference/ml/df-analytics/apis/put-inference.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-inference.asciidoc @@ -43,7 +43,7 @@ is not created by {dfanalytics}. (Required, string) include::{docdir}/ml/ml-shared.asciidoc[tag=model-id] - +[role="child_attributes"] [[ml-put-inference-request-body]] ==== {api-request-body-title} @@ -52,32 +52,96 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=model-id] The compressed (GZipped and Base64 encoded) {infer} definition of the model. If `compressed_definition` is specified, then `definition` cannot be specified. +//Begin definition `definition`:: (Required, object) The {infer} definition for the model. If `definition` is specified, then `compressed_definition` cannot be specified. - -`definition`.`preprocessors`::: ++ +.Properties of `definition` +[%collapsible%open] +==== +`preprocessors`::: (Optional, object) Collection of preprocessors. See <> for the full list of available preprocessors. -`definition`.`trained_model`::: +`trained_model`::: (Required, object) The definition of the trained model. See <> for details. +==== +//End definition `description`:: (Optional, string) A human-readable description of the {infer} trained model. +//Begin inference_config +`inference_config`:: +(Required, object) +The default configuration for inference. This can be either a `regression` +or `classification` configuration. It must match the underlying +`definition.trained_model`'s `target_type`. ++ +.Properties of `inference_config` +[%collapsible%open] +==== +`regression`::: +(Optional, object) +Regression configuration for inference. ++ +.Properties of regression inference +[%collapsible%open] +===== +`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values] + +`results_field`:::: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-results-field] +===== + +`classification`::: +(Optional, object) +Classification configuration for inference. ++ +.Properties of classification inference +[%collapsible%open] +===== +`num_top_classes`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes] + +`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values] + +`results_field`:::: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-results-field] + +`top_classes_results_field`:::: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field] +===== +==== +//End of inference_config + +//Begin input `input`:: (Required, object) The input field names for the model definition. - -`input`.`field_names`::: ++ +.Properties of `input` +[%collapsible%open] +==== +`field_names`::: (Required, string) An array of input field names for the model. +==== +//End input `metadata`:: (Optional, object) @@ -87,7 +151,6 @@ An object map that contains metadata about the model. (Optional, string) An array of tags to organize the model. - [[ml-put-inference-preprocessors]] ===== {infer-cap} preprocessor definitions @@ -491,4 +554,4 @@ Example of a `weighted_mode` object: ===== {infer-cap} JSON schema For the full JSON schema of model {infer}, -https://github.com/elastic/ml-json-schemas[click here]. \ No newline at end of file +https://github.com/elastic/ml-json-schemas[click here]. diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index c144df477c4..da1a4d60a8e 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -1213,6 +1213,39 @@ For more information about these options, see <>. -- end::indices-options[] +tag::inference-config-classification-num-top-classes[] +Specifies the number of top class predictions to return. Defaults to 0. +end::inference-config-classification-num-top-classes[] + +tag::inference-config-classification-num-top-feature-importance-values[] +Specifies the maximum number of +{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature +importance] values per document. By default, it is zero and no feature +importance calculation occurs. +end::inference-config-classification-num-top-feature-importance-values[] + +tag::inference-config-classification-results-field[] +The field that is added to incoming documents to contain the inference +prediction. Defaults to `predicted_value`. +end::inference-config-classification-results-field[] + +tag::inference-config-classification-top-classes-results-field[] +Specifies the field to which the top classes are written. Defaults to +`top_classes`. +end::inference-config-classification-top-classes-results-field[] + +tag::inference-config-regression-num-top-feature-importance-values[] +Specifies the maximum number of +{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature +importance] values per document. By default, it is zero and no feature importance +calculation occurs. +end::inference-config-regression-num-top-feature-importance-values[] + +tag::inference-config-regression-results-field[] +Specifies the field to which the inference prediction is written. Defaults to +`predicted_value`. +end::inference-config-regression-results-field[] + tag::influencers[] A comma separated list of influencer field names. Typically these can be the by, over, or partition fields that are used in the detector configuration. You might diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java index 556fb2b4f02..82c4bc5f435 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; @@ -13,8 +14,12 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -37,27 +42,30 @@ public class InternalInferModelAction extends ActionType> objectsToInfer; - private final InferenceConfig config; + private final InferenceConfigUpdate update; private final boolean previouslyLicensed; public Request(String modelId, boolean previouslyLicensed) { - this(modelId, Collections.emptyList(), RegressionConfig.EMPTY_PARAMS, previouslyLicensed); + this(modelId, Collections.emptyList(), RegressionConfigUpdate.EMPTY_PARAMS, previouslyLicensed); } public Request(String modelId, List> objectsToInfer, - InferenceConfig inferenceConfig, + InferenceConfigUpdate inferenceConfig, boolean previouslyLicensed) { this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer")); - this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config"); + this.update = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config"); this.previouslyLicensed = previouslyLicensed; } - public Request(String modelId, Map objectToInfer, InferenceConfig config, boolean previouslyLicensed) { + public Request(String modelId, + Map objectToInfer, + InferenceConfigUpdate update, + boolean previouslyLicensed) { this(modelId, Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")), - config, + update, previouslyLicensed); } @@ -65,7 +73,18 @@ public class InternalInferModelAction extends ActionType)in.readNamedWriteable(InferenceConfigUpdate.class); + } else { + InferenceConfig oldConfig = in.readNamedWriteable(InferenceConfig.class); + if (oldConfig instanceof RegressionConfig) { + this.update = RegressionConfigUpdate.fromConfig((RegressionConfig)oldConfig); + } else if (oldConfig instanceof ClassificationConfig) { + this.update = ClassificationConfigUpdate.fromConfig((ClassificationConfig) oldConfig); + } else { + throw new IOException("Unexpected configuration type [" + oldConfig.getName() + "]"); + } + } this.previouslyLicensed = in.readBoolean(); } @@ -77,8 +96,8 @@ public class InternalInferModelAction extends ActionType LENIENT_PARSER = createParser(true); @@ -93,6 +98,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { parser.declareString(TrainedModelConfig.Builder::setLazyDefinition, COMPRESSED_DEFINITION); parser.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL); parser.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP); + parser.declareNamedObject(TrainedModelConfig.Builder::setInferenceConfig, (p, c, n) -> ignoreUnknownFields ? + p.namedObject(LenientlyParsedInferenceConfig.class, n, null) : + p.namedObject(StrictlyParsedInferenceConfig.class, n, null), + INFERENCE_CONFIG); return parser; } @@ -112,6 +121,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { private final long estimatedOperations; private final License.OperationMode licenseLevel; private final Map defaultFieldMap; + private final InferenceConfig inferenceConfig; private final LazyModelDefinition definition; @@ -127,7 +137,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { Long estimatedHeapMemory, Long estimatedOperations, String licenseLevel, - Map defaultFieldMap) { + Map defaultFieldMap, + InferenceConfig inferenceConfig) { this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.version = ExceptionsHelper.requireNonNull(version, VERSION); @@ -148,6 +159,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { this.estimatedOperations = estimatedOperations; this.licenseLevel = License.OperationMode.parse(ExceptionsHelper.requireNonNull(licenseLevel, LICENSE_LEVEL)); this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap); + this.inferenceConfig = inferenceConfig; } public TrainedModelConfig(StreamInput in) throws IOException { @@ -170,6 +182,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { } else { this.defaultFieldMap = null; } + if (in.getVersion().onOrAfter(Version.V_7_8_0)) { + this.inferenceConfig = in.readOptionalNamedWriteable(InferenceConfig.class); + } else { + this.inferenceConfig = null; + } } public String getModelId() { @@ -204,6 +221,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return defaultFieldMap; } + @Nullable + public InferenceConfig getInferenceConfig() { + return inferenceConfig; + } + @Nullable public String getCompressedDefinition() throws IOException { if (definition == null) { @@ -274,6 +296,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { out.writeBoolean(false); } } + if (out.getVersion().onOrAfter(Version.V_7_8_0)) { + out.writeOptionalNamedWriteable(inferenceConfig); + } } @Override @@ -311,6 +336,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) { builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap); } + if (inferenceConfig != null) { + writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig); + } builder.endObject(); return builder; } @@ -337,6 +365,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(licenseLevel, that.licenseLevel) && Objects.equals(defaultFieldMap, that.defaultFieldMap) && + Objects.equals(inferenceConfig, that.inferenceConfig) && Objects.equals(metadata, that.metadata); } @@ -354,6 +383,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { estimatedOperations, input, licenseLevel, + inferenceConfig, defaultFieldMap); } @@ -372,6 +402,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { private LazyModelDefinition definition; private String licenseLevel; private Map defaultFieldMap; + private InferenceConfig inferenceConfig; public Builder() {} @@ -389,6 +420,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { this.estimatedHeapMemory = config.estimatedHeapMemory; this.licenseLevel = config.licenseLevel.description(); this.defaultFieldMap = config.defaultFieldMap == null ? null : new HashMap<>(config.defaultFieldMap); + this.inferenceConfig = config.inferenceConfig; } public Builder setModelId(String modelId) { @@ -512,6 +544,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return this; } + public Builder setInferenceConfig(InferenceConfig inferenceConfig) { + this.inferenceConfig = inferenceConfig; + return this; + } + public Builder validate() { return validate(false); } @@ -530,6 +567,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { if (modelId == null) { validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException); } + if (inferenceConfig == null && forCreation) { + validationException = addValidationError("[" + INFERENCE_CONFIG.getPreferredName() + "] must not be null.", + validationException); + } if (modelId != null && MlStrings.isValidId(modelId) == false) { validationException = addValidationError(Messages.getMessage(Messages.INVALID_ID, @@ -605,7 +646,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { estimatedHeapMemory == null ? 0 : estimatedHeapMemory, estimatedOperations == null ? 0 : estimatedOperations, licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel, - defaultFieldMap); + defaultFieldMap, + inferenceConfig); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java index 8621786cedd..7c7ef38d9cc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java @@ -5,19 +5,31 @@ */ package org.elasticsearch.xpack.core.ml.inference.persistence; +import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; +import org.elasticsearch.xpack.core.template.TemplateUtils; /** * Class containing the index constants so that the index version, name, and prefix are available to a wider audience. */ public final class InferenceIndexConstants { - public static final String INDEX_VERSION = "000001"; + /** + * version: 7.8.0: + * - adds inference_config definition to trained model config + * + */ + public static final String INDEX_VERSION = "000002"; public static final String INDEX_NAME_PREFIX = ".ml-inference-"; public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*"; public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION; public static final ParseField DOC_TYPE = new ParseField("doc_type"); private InferenceIndexConstants() {} + private static final String MAPPINGS_VERSION_VARIABLE = "xpack.ml.version"; + public static String mapping() { + return TemplateUtils.loadTemplate("/org/elasticsearch/xpack/core/ml/inference_index_mappings.json", + Version.CURRENT.toString(), MAPPINGS_VERSION_VARIABLE); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java index 1aa8c816ccb..82749c27ce9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java @@ -9,24 +9,19 @@ import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import java.util.Objects; -import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; - -public class ClassificationConfig implements InferenceConfig { +public class ClassificationConfig implements LenientlyParsedInferenceConfig, StrictlyParsedInferenceConfig { public static final ParseField NAME = new ParseField("classification"); public static final String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes"; - private static final String DEFAULT_RESULTS_FIELD = "predicted_value"; + public static final String DEFAULT_RESULTS_FIELD = "predicted_value"; public static final ParseField RESULTS_FIELD = new ParseField("results_field"); public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); @@ -42,32 +37,27 @@ public class ClassificationConfig implements InferenceConfig { private final String resultsField; private final int numTopFeatureImportanceValues; - public static ClassificationConfig fromMap(Map map) { - Map options = new HashMap<>(map); - Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName()); - String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName()); - String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); - Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + private static final ObjectParser LENIENT_PARSER = createParser(true); + private static final ObjectParser STRICT_PARSER = createParser(false); - if (options.isEmpty() == false) { - throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet()); - } - return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, featureImportance); + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME.getPreferredName(), + lenient, + ClassificationConfig.Builder::new); + parser.declareInt(ClassificationConfig.Builder::setNumTopClasses, NUM_TOP_CLASSES); + parser.declareString(ClassificationConfig.Builder::setResultsField, RESULTS_FIELD); + parser.declareString(ClassificationConfig.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD); + parser.declareInt(ClassificationConfig.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); + return parser; } - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig( - (Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3])); - - static { - PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); - PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); - PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD); - PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); + public static ClassificationConfig fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); } - public static ClassificationConfig fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); + public static ClassificationConfig fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null).build(); } public ClassificationConfig(Integer numTopClasses) { @@ -150,14 +140,10 @@ public class ClassificationConfig implements InferenceConfig { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - if (numTopClasses != 0) { - builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); - } + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField); builder.field(RESULTS_FIELD.getPreferredName(), resultsField); - if (numTopFeatureImportanceValues > 0) { - builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); - } + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); builder.endObject(); return builder; } @@ -179,7 +165,50 @@ public class ClassificationConfig implements InferenceConfig { @Override public Version getMinimalSupportedVersion() { - return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION; + return requestingImportance() ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION; } + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private Integer numTopClasses; + private String topClassesResultsField; + private String resultsField; + private Integer numTopFeatureImportanceValues; + + Builder() {} + + Builder(ClassificationConfig config) { + this.numTopClasses = config.numTopClasses; + this.topClassesResultsField = config.topClassesResultsField; + this.resultsField = config.resultsField; + this.numTopFeatureImportanceValues = config.numTopFeatureImportanceValues; + } + + public Builder setNumTopClasses(Integer numTopClasses) { + this.numTopClasses = numTopClasses; + return this; + } + + public Builder setTopClassesResultsField(String topClassesResultsField) { + this.topClassesResultsField = topClassesResultsField; + return this; + } + + public Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public ClassificationConfig build() { + return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues); + } + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java new file mode 100644 index 00000000000..2df5678eb5e --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java @@ -0,0 +1,235 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.NUM_TOP_CLASSES; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.RESULTS_FIELD; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.TOP_CLASSES_RESULTS_FIELD; + +public class ClassificationConfigUpdate implements InferenceConfigUpdate { + + public static final ParseField NAME = new ParseField("classification"); + + public static ClassificationConfigUpdate EMPTY_PARAMS = + new ClassificationConfigUpdate(null, null, null, null); + + private final Integer numTopClasses; + private final String topClassesResultsField; + private final String resultsField; + private final Integer numTopFeatureImportanceValues; + + public static ClassificationConfigUpdate fromMap(Map map) { + Map options = new HashMap<>(map); + Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName()); + String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName()); + String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); + Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + + if (options.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet()); + } + return new ClassificationConfigUpdate(numTopClasses, resultsField, topClassesResultsField, featureImportance); + } + + public static ClassificationConfigUpdate fromConfig(ClassificationConfig config) { + return new ClassificationConfigUpdate(config.getNumTopClasses(), + config.getResultsField(), + config.getTopClassesResultsField(), + config.getNumTopFeatureImportanceValues()); + } + + private static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME.getPreferredName(), + lenient, + ClassificationConfigUpdate.Builder::new); + parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopClasses, NUM_TOP_CLASSES); + parser.declareString(ClassificationConfigUpdate.Builder::setResultsField, RESULTS_FIELD); + parser.declareString(ClassificationConfigUpdate.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD); + parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); + return parser; + } + + public static ClassificationConfigUpdate fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); + } + + public ClassificationConfigUpdate(Integer numTopClasses, + String resultsField, + String topClassesResultsField, + Integer featureImportance) { + this.numTopClasses = numTopClasses; + this.topClassesResultsField = topClassesResultsField; + this.resultsField = resultsField; + if (featureImportance != null && featureImportance < 0) { + throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() + + "] must be greater than or equal to 0"); + } + this.numTopFeatureImportanceValues = featureImportance; + } + + public ClassificationConfigUpdate(StreamInput in) throws IOException { + this.numTopClasses = in.readOptionalInt(); + this.topClassesResultsField = in.readOptionalString(); + this.resultsField = in.readOptionalString(); + this.numTopFeatureImportanceValues = in.readOptionalVInt(); + } + + public Integer getNumTopClasses() { + return numTopClasses; + } + + public String getTopClassesResultsField() { + return topClassesResultsField; + } + + public String getResultsField() { + return resultsField; + } + + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInt(numTopClasses); + out.writeOptionalString(topClassesResultsField); + out.writeOptionalString(resultsField); + out.writeOptionalVInt(numTopFeatureImportanceValues); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassificationConfigUpdate that = (ClassificationConfigUpdate) o; + return Objects.equals(numTopClasses, that.numTopClasses) + && Objects.equals(topClassesResultsField, that.topClassesResultsField) + && Objects.equals(resultsField, that.resultsField) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); + } + + @Override + public int hashCode() { + return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (numTopClasses != null) { + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + } + if (topClassesResultsField != null) { + builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField); + } + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public ClassificationConfig apply(ClassificationConfig originalConfig) { + if (isNoop(originalConfig)) { + return originalConfig; + } + ClassificationConfig.Builder builder = new ClassificationConfig.Builder(originalConfig); + if (resultsField != null) { + builder.setResultsField(resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues); + } + if (topClassesResultsField != null) { + builder.setTopClassesResultsField(topClassesResultsField); + } + if (numTopClasses != null) { + builder.setNumTopClasses(numTopClasses); + } + return builder.build(); + } + + @Override + public InferenceConfig toConfig() { + return apply(ClassificationConfig.EMPTY_PARAMS); + } + + @Override + public boolean isSupported(InferenceConfig inferenceConfig) { + return inferenceConfig instanceof ClassificationConfig; + } + + boolean isNoop(ClassificationConfig originalConfig) { + return (resultsField == null || resultsField.equals(originalConfig.getResultsField())) + && (numTopFeatureImportanceValues == null + || originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues) + && (topClassesResultsField == null || topClassesResultsField.equals(originalConfig.getTopClassesResultsField())) + && (numTopClasses == null || originalConfig.getNumTopClasses() == numTopClasses); + } + + public static class Builder { + private Integer numTopClasses; + private String topClassesResultsField; + private String resultsField; + private Integer numTopFeatureImportanceValues; + + public Builder setNumTopClasses(int numTopClasses) { + this.numTopClasses = numTopClasses; + return this; + } + + public Builder setTopClassesResultsField(String topClassesResultsField) { + this.topClassesResultsField = topClassesResultsField; + return this; + } + + public Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public Builder setNumTopFeatureImportanceValues(int numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public ClassificationConfigUpdate build() { + return new ClassificationConfigUpdate(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigUpdate.java new file mode 100644 index 00000000000..72c8cff3c24 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigUpdate.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + + +public interface InferenceConfigUpdate extends NamedXContentObject, NamedWriteable { + + T apply(T originalConfig); + + InferenceConfig toConfig(); + + boolean isSupported(InferenceConfig config); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedInferenceConfig.java new file mode 100644 index 00000000000..501a9dff1ad --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LenientlyParsedInferenceConfig.java @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +public interface LenientlyParsedInferenceConfig extends InferenceConfig { +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java index 4c8244c734c..ad5cfa4244e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -9,48 +9,42 @@ import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import java.util.Objects; -import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; - -public class RegressionConfig implements InferenceConfig { +public class RegressionConfig implements LenientlyParsedInferenceConfig, StrictlyParsedInferenceConfig { public static final ParseField NAME = new ParseField("regression"); private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0; public static final ParseField RESULTS_FIELD = new ParseField("results_field"); public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); - private static final String DEFAULT_RESULTS_FIELD = "predicted_value"; + public static final String DEFAULT_RESULTS_FIELD = "predicted_value"; public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null); - public static RegressionConfig fromMap(Map map) { - Map options = new HashMap<>(map); - String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); - Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); - if (options.isEmpty() == false) { - throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); - } - return new RegressionConfig(resultsField, featureImportance); + private static final ObjectParser LENIENT_PARSER = createParser(true); + private static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME.getPreferredName(), + lenient, + RegressionConfig.Builder::new); + parser.declareString(RegressionConfig.Builder::setResultsField, RESULTS_FIELD); + parser.declareInt(RegressionConfig.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); + return parser; } - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0], (Integer)args[1])); - - static { - PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD); - PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); + public static RegressionConfig fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); } - public static RegressionConfig fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); + public static RegressionConfig fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null).build(); } private final String resultsField; @@ -113,9 +107,7 @@ public class RegressionConfig implements InferenceConfig { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(RESULTS_FIELD.getPreferredName(), resultsField); - if (numTopFeatureImportanceValues > 0) { - builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); - } + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); builder.endObject(); return builder; } @@ -141,7 +133,36 @@ public class RegressionConfig implements InferenceConfig { @Override public Version getMinimalSupportedVersion() { - return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION; + return requestingImportance() ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION; } + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String resultsField; + private Integer numTopFeatureImportanceValues; + + Builder() {} + + Builder(RegressionConfig config) { + this.resultsField = config.resultsField; + this.numTopFeatureImportanceValues = config.numTopFeatureImportanceValues; + } + + public Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public RegressionConfig build() { + return new RegressionConfig(resultsField, numTopFeatureImportanceValues); + } + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java new file mode 100644 index 00000000000..b9def23de3b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java @@ -0,0 +1,178 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.DEFAULT_RESULTS_FIELD; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.RESULTS_FIELD; + +public class RegressionConfigUpdate implements InferenceConfigUpdate { + + public static final ParseField NAME = new ParseField("regression"); + + public static RegressionConfigUpdate EMPTY_PARAMS = new RegressionConfigUpdate(null, null); + + public static RegressionConfigUpdate fromMap(Map map) { + Map options = new HashMap<>(map); + String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName()); + Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + if (options.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); + } + return new RegressionConfigUpdate(resultsField, featureImportance); + } + + public static RegressionConfigUpdate fromConfig(RegressionConfig config) { + return new RegressionConfigUpdate(config.getResultsField(), config.getNumTopFeatureImportanceValues()); + } + + private static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean lenient) { + ObjectParser parser = new ObjectParser<>( + NAME.getPreferredName(), + lenient, + RegressionConfigUpdate.Builder::new); + parser.declareString(RegressionConfigUpdate.Builder::setResultsField, RESULTS_FIELD); + parser.declareInt(RegressionConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES); + return parser; + } + + public static RegressionConfigUpdate fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null).build(); + } + + private final String resultsField; + private final Integer numTopFeatureImportanceValues; + + public RegressionConfigUpdate(String resultsField, Integer numTopFeatureImportanceValues) { + this.resultsField = resultsField; + if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) { + throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() + + "] must be greater than or equal to 0"); + } + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + } + + public RegressionConfigUpdate(StreamInput in) throws IOException { + this.resultsField = in.readOptionalString(); + this.numTopFeatureImportanceValues = in.readOptionalVInt(); + } + + public int getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues; + } + + public String getResultsField() { + return resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField; + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(resultsField); + out.writeOptionalVInt(numTopFeatureImportanceValues); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RegressionConfigUpdate that = (RegressionConfigUpdate)o; + return Objects.equals(this.resultsField, that.resultsField) + && Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); + } + + @Override + public int hashCode() { + return Objects.hash(resultsField, numTopFeatureImportanceValues); + } + + @Override + public RegressionConfig apply(RegressionConfig originalConfig) { + if (isNoop(originalConfig)) { + return originalConfig; + } + RegressionConfig.Builder builder = new RegressionConfig.Builder(originalConfig); + if (resultsField != null) { + builder.setResultsField(resultsField); + } + if (numTopFeatureImportanceValues != null) { + builder.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues); + } + return builder.build(); + } + + @Override + public InferenceConfig toConfig() { + return apply(RegressionConfig.EMPTY_PARAMS); + } + + @Override + public boolean isSupported(InferenceConfig inferenceConfig) { + return inferenceConfig instanceof RegressionConfig; + } + + boolean isNoop(RegressionConfig originalConfig) { + return (resultsField == null || originalConfig.getResultsField().equals(resultsField)) + && (numTopFeatureImportanceValues == null + || originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues); + } + + public static class Builder { + private String resultsField; + private Integer numTopFeatureImportanceValues; + + public Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public RegressionConfigUpdate build() { + return new RegressionConfigUpdate(resultsField, numTopFeatureImportanceValues); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedInferenceConfig.java new file mode 100644 index 00000000000..e6ea5129b31 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/StrictlyParsedInferenceConfig.java @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +public interface StrictlyParsedInferenceConfig extends InferenceConfig { +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java index a7a6d22ae3e..d4906173674 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java @@ -41,4 +41,14 @@ public final class NamedXContentObjectHelper { } return builder; } + + public static XContentBuilder writeNamedObject(XContentBuilder builder, + ToXContent.Params params, + String namedObjectName, + NamedXContentObject namedObject) throws IOException { + builder.startObject(namedObjectName); + builder.field(namedObject.getName(), namedObject, params); + builder.endObject(); + return builder; + } } diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json index 2a1131ca557..a77a0119e95 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json @@ -2,7 +2,7 @@ "order" : 0, "version" : ${xpack.ml.version.id}, "index_patterns" : [ - ".ml-inference-000001" + ".ml-inference-000002" ], "settings" : { "index" : { @@ -67,6 +67,9 @@ }, "default_field_map": { "enabled": false + }, + "inference_config": { + "enabled": false } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java index f34f8fc008f..01351f055f1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCSerializationTestCase.java @@ -11,22 +11,12 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; -import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; -import static org.elasticsearch.Version.getDeclaredVersions; +import static org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase.DEFAULT_BWC_VERSIONS; public abstract class AbstractBWCSerializationTestCase extends AbstractSerializingTestCase { - private static final List ALL_VERSIONS = Collections.unmodifiableList(getDeclaredVersions(Version.class)); - - public static List getAllBWCVersions(Version version) { - return ALL_VERSIONS.stream().filter(v -> v.before(version) && version.isCompatible(v)).collect(Collectors.toList()); - } - - private static final List DEFAULT_BWC_VERSIONS = getAllBWCVersions(Version.CURRENT); - /** * Returns the expected instance if serialized from the given version. */ diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCWireSerializationTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCWireSerializationTestCase.java new file mode 100644 index 00000000000..300370c3fc4 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/AbstractBWCWireSerializationTestCase.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.elasticsearch.Version.getDeclaredVersions; + +public abstract class AbstractBWCWireSerializationTestCase extends AbstractWireSerializingTestCase { + + static final List ALL_VERSIONS = Collections.unmodifiableList(getDeclaredVersions(Version.class)); + + public static List getAllBWCVersions(Version version) { + return ALL_VERSIONS.stream().filter(v -> v.before(version) && version.isCompatible(v)).collect(Collectors.toList()); + } + + static final List DEFAULT_BWC_VERSIONS = getAllBWCVersions(Version.CURRENT); + + /** + * Returns the expected instance if serialized from the given version. + */ + protected abstract T mutateInstanceForVersion(T instance, Version version); + + /** + * The bwc versions to test serialization against + */ + protected List bwcVersions() { + return DEFAULT_BWC_VERSIONS; + } + + /** + * Test serialization and deserialization of the test instance across versions + */ + public final void testBwcSerialization() throws IOException { + for (int runs = 0; runs < NUMBER_OF_TEST_RUNS; runs++) { + T testInstance = createTestInstance(); + for (Version bwcVersion : bwcVersions()) { + assertBwcSerialization(testInstance, bwcVersion); + } + } + } + + /** + * Assert that instances copied at a particular version are equal. The version is useful + * for sanity checking the backwards compatibility of the wire. It isn't a substitute for + * real backwards compatibility tests but it is *so* much faster. + */ + protected final void assertBwcSerialization(T testInstance, Version version) throws IOException { + T deserializedInstance = copyWriteable(testInstance, getNamedWriteableRegistry(), instanceReader(), version); + assertOnBWCObject(deserializedInstance, mutateInstanceForVersion(testInstance, version), version); + } + + /** + * @param bwcSerializedObject The object deserialized from the previous version + * @param testInstance The original test instance + * @param version The version which serialized + */ + protected void assertOnBWCObject(T bwcSerializedObject, T testInstance, Version version) { + assertNotSame(version.toString(), bwcSerializedObject, testInstance); + assertEquals(version.toString(), bwcSerializedObject, testInstance); + assertEquals(version.toString(), bwcSerializedObject.hashCode(), testInstance.hashCode()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java index d743a579f85..80693c94be9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java @@ -5,14 +5,21 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests; import java.util.ArrayList; import java.util.List; @@ -22,25 +29,27 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class InternalInferModelActionRequestTests extends AbstractWireSerializingTestCase { +public class InternalInferModelActionRequestTests extends AbstractBWCWireSerializationTestCase { @Override + @SuppressWarnings("unchecked") protected Request createTestInstance() { return randomBoolean() ? new Request( randomAlphaOfLength(10), Stream.generate(InternalInferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), - randomInferenceConfig(), + randomInferenceConfigUpdate(), randomBoolean()) : new Request( randomAlphaOfLength(10), randomMap(), - randomInferenceConfig(), + randomInferenceConfigUpdate(), randomBoolean()); } - private static InferenceConfig randomInferenceConfig() { - return randomFrom(RegressionConfigTests.randomRegressionConfig(), ClassificationConfigTests.randomClassificationConfig()); + private static InferenceConfigUpdate randomInferenceConfigUpdate() { + return randomFrom(RegressionConfigUpdateTests.randomRegressionConfig(), + ClassificationConfigUpdateTests.randomClassificationConfig()); } private static Map randomMap() { @@ -60,4 +69,26 @@ public class InternalInferModelActionRequestTests extends AbstractWireSerializin entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); return new NamedWriteableRegistry(entries); } + + @Override + @SuppressWarnings("unchecked") + protected Request mutateInstanceForVersion(Request instance, Version version) { + if (version.before(Version.V_7_8_0)) { + InferenceConfigUpdate update = null; + if (instance.getUpdate() instanceof ClassificationConfigUpdate) { + update = ClassificationConfigUpdate.fromConfig( + ClassificationConfigTests.mutateForVersion((ClassificationConfig) instance.getUpdate().toConfig(), version)); + } + else if (instance.getUpdate() instanceof RegressionConfigUpdate) { + update = RegressionConfigUpdate.fromConfig( + RegressionConfigTests.mutateForVersion((RegressionConfig) instance.getUpdate().toConfig(), version)); + } + else { + fail("unknown update type " + instance.getUpdate().getName()); + } + return new Request(instance.getModelId(), instance.getObjectsToInfer(), update, instance.isPreviouslyLicensed()); + } + return instance; + } + } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index ace9cad261f..06e004850f7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -21,7 +21,9 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.license.License; import org.elasticsearch.search.SearchModule; -import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.MlStrings; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; @@ -46,7 +48,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; -public class TrainedModelConfigTests extends AbstractSerializingTestCase { +public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase { private boolean lenient; @@ -66,6 +68,8 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase randomAlphaOfLength(10)) .limit(randomIntBetween(1, 10)) - .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10)))); + .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))), + randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig())); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); assertThat(reference.utf8ToString(), containsString("\"compressed_definition\"")); @@ -182,7 +187,8 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase randomAlphaOfLength(10)) .limit(randomIntBetween(1, 10)) - .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10)))); + .collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))), + randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig())); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); Map objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2(); @@ -311,4 +317,12 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase { +public class ClassificationConfigTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; public static ClassificationConfig randomClassificationConfig() { return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10), @@ -26,23 +25,17 @@ public class ClassificationConfigTests extends AbstractSerializingTestCase configMap = new HashMap<>(); - configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3); - configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo"); - configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar"); - configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2); - assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected)); + public static ClassificationConfig mutateForVersion(ClassificationConfig instance, Version version) { + ClassificationConfig.Builder builder = new ClassificationConfig.Builder(instance); + if (version.before(Version.V_7_7_0)) { + builder.setNumTopFeatureImportanceValues(0); + } + return builder.build(); } - public void testFromMapWithUnknownField() { - ElasticsearchException ex = expectThrows(ElasticsearchException.class, - () -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1))); - assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); } @Override @@ -57,6 +50,16 @@ public class ClassificationConfigTests extends AbstractSerializingTestCase { + + public static ClassificationConfigUpdate randomClassificationConfig() { + return new ClassificationConfigUpdate(randomBoolean() ? null : randomIntBetween(-1, 10), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomIntBetween(0, 10) + ); + } + + public void testFromMap() { + ClassificationConfigUpdate expected = new ClassificationConfigUpdate(null, null, null, null); + assertThat(ClassificationConfigUpdate.fromMap(Collections.emptyMap()), equalTo(expected)); + + expected = new ClassificationConfigUpdate(3, "foo", "bar", 2); + Map configMap = new HashMap<>(); + configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3); + configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo"); + configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar"); + configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2); + assertThat(ClassificationConfigUpdate.fromMap(configMap), equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> ClassificationConfigUpdate.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + + @Override + protected ClassificationConfigUpdate createTestInstance() { + return randomClassificationConfig(); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationConfigUpdate::new; + } + + @Override + protected ClassificationConfigUpdate doParseInstance(XContentParser parser) throws IOException { + return ClassificationConfigUpdate.fromXContentStrict(parser); + } + + @Override + protected ClassificationConfigUpdate mutateInstanceForVersion(ClassificationConfigUpdate instance, Version version) { + return instance; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java index ba8b8b62cd5..57dfa18520a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java @@ -5,37 +5,32 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; -import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import static org.hamcrest.Matchers.equalTo; - -public class RegressionConfigTests extends AbstractSerializingTestCase { +public class RegressionConfigTests extends AbstractBWCSerializationTestCase { + private boolean lenient; public static RegressionConfig randomRegressionConfig() { return new RegressionConfig(randomBoolean() ? null : randomAlphaOfLength(10)); } - public void testFromMap() { - RegressionConfig expected = new RegressionConfig("foo", 3); - Map config = new HashMap(){{ - put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo"); - put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3); - }}; - assertThat(RegressionConfig.fromMap(config), equalTo(expected)); + public static RegressionConfig mutateForVersion(RegressionConfig instance, Version version) { + RegressionConfig.Builder builder = new RegressionConfig.Builder(instance); + if (version.before(Version.V_7_7_0)) { + builder.setNumTopFeatureImportanceValues(0); + } + return builder.build(); } - public void testFromMapWithUnknownField() { - ElasticsearchException ex = expectThrows(ElasticsearchException.class, - () -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1))); - assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); } @Override @@ -50,6 +45,16 @@ public class RegressionConfigTests extends AbstractSerializingTestCase { + + public static RegressionConfigUpdate randomRegressionConfig() { + return new RegressionConfigUpdate(randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : randomIntBetween(0, 10)); + } + + public void testFromMap() { + RegressionConfigUpdate expected = new RegressionConfigUpdate("foo", 3); + Map config = new HashMap(){{ + put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo"); + put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3); + }}; + assertThat(RegressionConfigUpdate.fromMap(config), equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> RegressionConfigUpdate.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + + @Override + protected RegressionConfigUpdate createTestInstance() { + return randomRegressionConfig(); + } + + @Override + protected Writeable.Reader instanceReader() { + return RegressionConfigUpdate::new; + } + + @Override + protected RegressionConfigUpdate doParseInstance(XContentParser parser) throws IOException { + return RegressionConfigUpdate.fromXContentStrict(parser); + } + + @Override + protected RegressionConfigUpdate mutateInstanceForVersion(RegressionConfigUpdate instance, Version version) { + return instance; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index e6c368fefa4..4e2a5398d79 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -65,6 +65,10 @@ public class EnsembleTests extends AbstractSerializingTestCase { } public static Ensemble createRandom() { + return createRandom(randomFrom(TargetType.values())); + } + + public static Ensemble createRandom(TargetType targetType) { int numberOfFeatures = randomIntBetween(1, 10); List featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList()); int numberOfModels = randomIntBetween(1, 10); @@ -74,7 +78,6 @@ public class EnsembleTests extends AbstractSerializingTestCase { double[] weights = randomBoolean() ? null : Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).mapToDouble(Double::valueOf).toArray(); - TargetType targetType = randomFrom(TargetType.values()); List categoryLabels = null; if (randomBoolean() && targetType == TargetType.CLASSIFICATION) { categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index d801dec470e..dd418c57cfa 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -64,16 +64,20 @@ public class TreeTests extends AbstractSerializingTestCase { return createRandom(); } - public static Tree createRandom() { + public static Tree createRandom(TargetType targetType) { int numberOfFeatures = randomIntBetween(1, 10); List featureNames = new ArrayList<>(); for (int i = 0; i < numberOfFeatures; i++) { featureNames.add(randomAlphaOfLength(10)); } - return buildRandomTree(featureNames, 6); + return buildRandomTree(targetType, featureNames, 6); } - public static Tree buildRandomTree(List featureNames, int depth) { + public static Tree createRandom() { + return createRandom(randomFrom(TargetType.values())); + } + + public static Tree buildRandomTree(TargetType targetType, List featureNames, int depth) { Tree.Builder builder = Tree.builder(); int maxFeatureIndex = featureNames.size() - 1; builder.setFeatureNames(featureNames); @@ -96,7 +100,6 @@ public class TreeTests extends AbstractSerializingTestCase { } childNodes = nextNodes; } - TargetType targetType = randomFrom(TargetType.values()); List categoryLabels = null; if (randomBoolean() && targetType == TargetType.CLASSIFICATION) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); @@ -105,6 +108,10 @@ public class TreeTests extends AbstractSerializingTestCase { return builder.setTargetType(targetType).setClassificationLabels(categoryLabels).build(); } + public static Tree buildRandomTree(List featureNames, int depth) { + return buildRandomTree(randomFrom(TargetType.values()), featureNames, depth); + } + @Override protected Writeable.Reader instanceReader() { return Tree::new; diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 1863bb096ee..5eebec9559d 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -141,9 +141,11 @@ integTest.runner { 'ml/inference_crud/Test put ensemble with empty models', 'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index', 'ml/inference_crud/Test put model with empty input.field_names', + 'ml/inference_crud/Test PUT model where target type and inference config mismatch', 'ml/inference_processor/Test create processor with missing mandatory fields', 'ml/inference_processor/Test create and delete pipeline with inference processor', 'ml/inference_processor/Test create processor with deprecated fields', + 'ml/inference_processor/Test simulate', 'ml/inference_stats_crud/Test get stats given missing trained model', 'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false', 'ml/jobs_crud/Test cannot create job with existing categorizer state document', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 12a0f158b34..2f79a540e08 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -424,6 +424,7 @@ public class InferenceIngestIT extends ESRestTestCase { private static final String REGRESSION_CONFIG = "{" + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for regression\",\n" + + " \"inference_config\": {\"regression\": {}},\n" + " \"definition\": " + REGRESSION_DEFINITION + "}"; @@ -564,6 +565,7 @@ public class InferenceIngestIT extends ESRestTestCase { " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for classification\",\n" + " \"default_field_map\": {\"col_1_alias\": \"col1\"},\n" + + " \"inference_config\": {\"classification\": {}},\n" + " \"definition\": " + CLASSIFICATION_DEFINITION + "}"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index afa5f1c6dcb..db27ce4aeeb 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -12,6 +12,7 @@ import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.ml.inference.TrainedModelConfig; import org.elasticsearch.client.ml.inference.TrainedModelDefinition; import org.elasticsearch.client.ml.inference.TrainedModelInput; +import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; @@ -193,6 +194,7 @@ public class TrainedModelIT extends ESRestTestCase { .setTrainedModel(buildRegression()); TrainedModelConfig.builder() .setDefinition(definition) + .setInferenceConfig(new RegressionConfig()) .setModelId(modelId) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3"))) .build().toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index a32367a6766..40a74d7405a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.ml.inference.loadingservice.Model; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -48,11 +49,12 @@ public class TransportInternalInferModelAction extends HandledTransportAction listener) { Response.Builder responseBuilder = Response.builder(); - ActionListener getModelListener = ActionListener.wrap( + ActionListener> getModelListener = ActionListener.wrap( model -> { TypedChainTaskExecutor typedChainTaskExecutor = new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), @@ -62,7 +64,9 @@ public class TransportInternalInferModelAction extends HandledTransportAction true); request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(chainedTask -> - model.infer(stringObjectMap, request.getConfig(), chainedTask))); + // The InferenceConfigUpdate here is unchecked, initially. + // It gets checked when it is applied + model.infer(stringObjectMap, request.getUpdate(), chainedTask))); typedChainTaskExecutor.execute(ActionListener.wrap( inferenceResultsInterfaces -> diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index fc6d5b92f36..503393a836a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -93,6 +93,22 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction inferenceConfig; private final Map fieldMap; private final InferenceAuditor auditor; private volatile boolean previouslyLicensed; @@ -82,7 +85,7 @@ public class InferenceProcessor extends AbstractProcessor { String tag, String targetField, String modelId, - InferenceConfig inferenceConfig, + InferenceConfigUpdate inferenceConfig, Map fieldMap) { super(tag); this.client = ExceptionsHelper.requireNonNull(client, "client"); @@ -245,7 +248,8 @@ public class InferenceProcessor extends AbstractProcessor { LoggingDeprecationHandler.INSTANCE.usedDeprecatedName(null, () -> null, FIELD_MAPPINGS, FIELD_MAP); } } - InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); + InferenceConfigUpdate inferenceConfig = + inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); return new InferenceProcessor(client, auditor, @@ -262,7 +266,7 @@ public class InferenceProcessor extends AbstractProcessor { this.maxIngestProcessors = maxIngestProcessors; } - InferenceConfig inferenceConfigFromMap(Map inferenceConfig) { + InferenceConfigUpdate inferenceConfigFromMap(Map inferenceConfig) { ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); if (inferenceConfig.size() != 1) { throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", @@ -279,12 +283,12 @@ public class InferenceProcessor extends AbstractProcessor { if (inferenceConfig.containsKey(ClassificationConfig.NAME.getPreferredName())) { checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS); - ClassificationConfig config = ClassificationConfig.fromMap(valueMap); + ClassificationConfigUpdate config = ClassificationConfigUpdate.fromMap(valueMap); checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField()); return config; } else if (inferenceConfig.containsKey(RegressionConfig.NAME.getPreferredName())) { checkSupportedVersion(RegressionConfig.EMPTY_PARAMS); - RegressionConfig config = RegressionConfig.fromMap(valueMap); + RegressionConfigUpdate config = RegressionConfigUpdate.fromMap(valueMap); checkFieldUniqueness(config.getResultsField()); return config; } else { @@ -298,6 +302,9 @@ public class InferenceProcessor extends AbstractProcessor { Set duplicatedFieldNames = new HashSet<>(); Set currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES); for(String fieldName : fieldNames) { + if (fieldName == null) { + continue; + } if (currentFieldNames.contains(fieldName)) { duplicatedFieldNames.add(fieldName); } else { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index f77188853ad..fa436e0c53f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -10,6 +10,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; @@ -24,21 +25,24 @@ import java.util.Set; import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING; -public class LocalModel implements Model { +public class LocalModel implements Model { private final TrainedModelDefinition trainedModelDefinition; private final String modelId; private final Set fieldNames; private final Map defaultFieldMap; + private final T inferenceConfig; public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition, TrainedModelInput input, - Map defaultFieldMap) { + Map defaultFieldMap, + T modelInferenceConfig) { this.trainedModelDefinition = trainedModelDefinition; this.modelId = modelId; this.fieldNames = new HashSet<>(input.getFieldNames()); this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap); + this.inferenceConfig = modelInferenceConfig; } long ramBytesUsed() { @@ -65,7 +69,15 @@ public class LocalModel implements Model { } @Override - public void infer(Map fields, InferenceConfig config, ActionListener listener) { + public void infer(Map fields, InferenceConfigUpdate update, ActionListener listener) { + if (update.isSupported(this.inferenceConfig) == false) { + listener.onFailure(ExceptionsHelper.badRequestException( + "Model [{}] has inference config of type [{}] which is not supported by inference request of type [{}]", + this.modelId, + this.inferenceConfig.getName(), + update.getName())); + return; + } try { Model.mapFieldsIfNecessary(fields, defaultFieldMap); if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) { @@ -73,7 +85,7 @@ public class LocalModel implements Model { return; } - listener.onResponse(trainedModelDefinition.infer(fields, config)); + listener.onResponse(trainedModelDefinition.infer(fields, update.apply(inferenceConfig))); } catch (Exception e) { listener.onFailure(e); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java index f1df6908d63..c11d735e26a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -8,15 +8,16 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.MapHelper; import java.util.Map; -public interface Model { +public interface Model { String getResultsType(); - void infer(Map fields, InferenceConfig inferenceConfig, ActionListener listener); + void infer(Map fields, InferenceConfigUpdate inferenceConfig, ActionListener listener); String getModelId(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 71cfe710821..b8ecfa424f6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -27,6 +27,11 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -78,9 +83,9 @@ public class ModelLoadingService implements ClusterStateListener { Setting.Property.NodeScope); private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); - private final Cache localModelCache; + private final Cache> localModelCache; private final Set referencedModels = new HashSet<>(); - private final Map>> loadingListeners = new HashMap<>(); + private final Map>>> loadingListeners = new HashMap<>(); private final TrainedModelProvider provider; private final Set shouldNotAudit; private final ThreadPool threadPool; @@ -100,7 +105,7 @@ public class ModelLoadingService implements ClusterStateListener { this.auditor = auditor; this.shouldNotAudit = new HashSet<>(); this.namedXContentRegistry = namedXContentRegistry; - this.localModelCache = CacheBuilder.builder() + this.localModelCache = CacheBuilder.>builder() .setMaximumWeight(this.maxCacheSize.getBytes()) .weigher((id, localModel) -> localModel.ramBytesUsed()) .removalListener(this::cacheEvictionListener) @@ -126,8 +131,8 @@ public class ModelLoadingService implements ClusterStateListener { * @param modelId the model to get * @param modelActionListener the listener to alert when the model has been retrieved. */ - public void getModel(String modelId, ActionListener modelActionListener) { - LocalModel cachedModel = localModelCache.get(modelId); + public void getModel(String modelId, ActionListener> modelActionListener) { + LocalModel cachedModel = localModelCache.get(modelId); if (cachedModel != null) { modelActionListener.onResponse(cachedModel); logger.trace("[{}] loaded from cache", modelId); @@ -138,12 +143,18 @@ public class ModelLoadingService implements ClusterStateListener { // by a simulated pipeline logger.trace("[{}] not actively loading, eager loading without cache", modelId); provider.getTrainedModel(modelId, true, ActionListener.wrap( - trainedModelConfig -> - modelActionListener.onResponse(new LocalModel( + trainedModelConfig -> { + trainedModelConfig.ensureParsedDefinition(namedXContentRegistry); + InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? + inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) : + trainedModelConfig.getInferenceConfig(); + modelActionListener.onResponse(new LocalModel<>( trainedModelConfig.getModelId(), - trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(), + trainedModelConfig.getModelDefinition(), trainedModelConfig.getInput(), - trainedModelConfig.getDefaultFieldMap())), + trainedModelConfig.getDefaultFieldMap(), + inferenceConfig)); + }, modelActionListener::onFailure )); } else { @@ -156,9 +167,9 @@ public class ModelLoadingService implements ClusterStateListener { * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded * Returns false if the model is not loaded or actively being loaded */ - private boolean loadModelIfNecessary(String modelId, ActionListener modelActionListener) { + private boolean loadModelIfNecessary(String modelId, ActionListener> modelActionListener) { synchronized (loadingListeners) { - Model cachedModel = localModelCache.get(modelId); + Model cachedModel = localModelCache.get(modelId); if (cachedModel != null) { modelActionListener.onResponse(cachedModel); return true; @@ -197,12 +208,17 @@ public class ModelLoadingService implements ClusterStateListener { } private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) throws IOException { - Queue> listeners; - LocalModel loadedModel = new LocalModel( + Queue>> listeners; + trainedModelConfig.ensureParsedDefinition(namedXContentRegistry); + InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? + inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) : + trainedModelConfig.getInferenceConfig(); + LocalModel loadedModel = new LocalModel<>( trainedModelConfig.getModelId(), - trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(), + trainedModelConfig.getModelDefinition(), trainedModelConfig.getInput(), - trainedModelConfig.getDefaultFieldMap()); + trainedModelConfig.getDefaultFieldMap(), + inferenceConfig); synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); // If there is no loadingListener that means the loading was canceled and the listener was already notified as such @@ -213,13 +229,13 @@ public class ModelLoadingService implements ClusterStateListener { localModelCache.put(modelId, loadedModel); shouldNotAudit.remove(modelId); } // synchronized (loadingListeners) - for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + for (ActionListener> listener = listeners.poll(); listener != null; listener = listeners.poll()) { listener.onResponse(loadedModel); } } private void handleLoadFailure(String modelId, Exception failure) { - Queue> listeners; + Queue>> listeners; synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); if (listeners == null) { @@ -228,12 +244,12 @@ public class ModelLoadingService implements ClusterStateListener { } // synchronized (loadingListeners) // If we failed to load and there were listeners present, that means that this model is referenced by a processor // Alert the listeners to the failure - for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + for (ActionListener> listener = listeners.poll(); listener != null; listener = listeners.poll()) { listener.onFailure(failure); } } - private void cacheEvictionListener(RemovalNotification notification) { + private void cacheEvictionListener(RemovalNotification> notification) { if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { String msg = new ParameterizedMessage( "model cache entry evicted." + @@ -263,7 +279,7 @@ public class ModelLoadingService implements ClusterStateListener { return; } // The listeners still waiting for a model and we are canceling the load? - List>>> drainWithFailure = new ArrayList<>(); + List>>>> drainWithFailure = new ArrayList<>(); Set referencedModelsBeforeClusterState = null; Set loadingModelBeforeClusterState = null; Set removedModels = null; @@ -306,11 +322,11 @@ public class ModelLoadingService implements ClusterStateListener { referencedModels); } } - for (Tuple>> modelAndListeners : drainWithFailure) { + for (Tuple>>> modelAndListeners : drainWithFailure) { final String msg = new ParameterizedMessage( "Cancelling load of model [{}] as it is no longer referenced by a pipeline", modelAndListeners.v1()).getFormat(); - for (ActionListener listener : modelAndListeners.v2()) { + for (ActionListener> listener : modelAndListeners.v2()) { listener.onFailure(new ElasticsearchException(msg)); } } @@ -379,4 +395,14 @@ public class ModelLoadingService implements ClusterStateListener { return allReferencedModelKeys; } + private static InferenceConfig inferenceConfigFromTargetType(TargetType targetType) { + switch(targetType) { + case REGRESSION: + return RegressionConfig.EMPTY_PARAMS; + case CLASSIFICATION: + return ClassificationConfig.EMPTY_PARAMS; + default: + throw ExceptionsHelper.badRequestException("unsupported target type [{}]", targetType); + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java index 23e32174361..80620c8eed2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java @@ -45,6 +45,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; @@ -694,7 +695,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request( modelId, Collections.singletonList(Collections.emptyMap()), - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, false ), inferModelSuccess); InternalInferModelAction.Response response = inferModelSuccess.actionGet(); @@ -711,7 +712,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request( modelId, Collections.singletonList(Collections.emptyMap()), - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, false )).actionGet(); }); @@ -724,7 +725,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request( modelId, Collections.singletonList(Collections.emptyMap()), - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, true ), inferModelSuccess); response = inferModelSuccess.actionGet(); @@ -740,7 +741,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request( modelId, Collections.singletonList(Collections.emptyMap()), - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, false ), listener); assertThat(listener.actionGet().getInferenceResults(), is(not(empty()))); @@ -760,6 +761,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase { .setModelId(modelId) .setDescription("test model for classification") .setInput(new TrainedModelInput(Arrays.asList("feature1"))) + .setInferenceConfig(RegressionConfig.EMPTY_PARAMS) .build(); client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index ee2b399ef2d..1d9c748f376 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; @@ -168,7 +169,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { extractedFieldList.add(new DocValueField("foo", Collections.emptySet())); extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet()))); extractedFieldList.add(new DocValueField("baz", Collections.emptySet())); - TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(); + TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; + TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); @@ -190,6 +192,11 @@ public class AnalyticsResultProcessorTests extends ESTestCase { assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar.keyword", "baz"))); assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed())); assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations())); + if (targetType.equals(TargetType.CLASSIFICATION)) { + assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification")); + } else { + assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression")); + } Map metadata = storedModel.getMetadata(); assertThat(metadata.size(), equalTo(1)); assertThat(metadata, hasKey("analytics_config")); @@ -213,7 +220,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase { return null; }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); - TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(); + TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; + TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 9b6dfd734d0..81b842047ff 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -14,7 +14,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.junit.Before; @@ -53,7 +55,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", targetField, "classification_model", - ClassificationConfig.EMPTY_PARAMS, + ClassificationConfigUpdate.EMPTY_PARAMS, Collections.emptyMap()); Map source = new HashMap<>(); @@ -75,13 +77,14 @@ public class InferenceProcessorTests extends ESTestCase { @SuppressWarnings("unchecked") public void testMutateDocumentClassificationTopNClasses() { - ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null); + ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, null); + ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, null); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", "ml.my_processor", "classification_model", - classificationConfig, + classificationConfigUpdate, Collections.emptyMap()); Map source = new HashMap<>(); @@ -105,12 +108,13 @@ public class InferenceProcessorTests extends ESTestCase { public void testMutateDocumentClassificationFeatureInfluence() { ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2); + ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, 2); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", "ml.my_processor", "classification_model", - classificationConfig, + classificationConfigUpdate, Collections.emptyMap()); Map source = new HashMap<>(); @@ -145,12 +149,13 @@ public class InferenceProcessorTests extends ESTestCase { @SuppressWarnings("unchecked") public void testMutateDocumentClassificationTopNClassesWithSpecificField() { ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops"); + ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, "result", "tops", null); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", "ml.my_processor", "classification_model", - classificationConfig, + classificationConfigUpdate, Collections.emptyMap()); Map source = new HashMap<>(); @@ -174,12 +179,13 @@ public class InferenceProcessorTests extends ESTestCase { public void testMutateDocumentRegression() { RegressionConfig regressionConfig = new RegressionConfig("foo"); + RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", null); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", "ml.my_processor", "regression_model", - regressionConfig, + regressionConfigUpdate, Collections.emptyMap()); Map source = new HashMap<>(); @@ -196,12 +202,13 @@ public class InferenceProcessorTests extends ESTestCase { public void testMutateDocumentRegressionWithTopFetures() { RegressionConfig regressionConfig = new RegressionConfig("foo", 2); + RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", 2); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, auditor, "my_processor", "ml.my_processor", "regression_model", - regressionConfig, + regressionConfigUpdate, Collections.emptyMap()); Map source = new HashMap<>(); @@ -233,7 +240,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", "my_field", modelId, - new ClassificationConfig(topNClasses, null, null), + new ClassificationConfigUpdate(topNClasses, null, null, null), Collections.emptyMap()); Map source = new HashMap(){{ @@ -262,7 +269,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", "my_field", modelId, - new ClassificationConfig(topNClasses, null, null), + new ClassificationConfigUpdate(topNClasses, null, null, null), fieldMapping); Map source = new HashMap(5){{ @@ -298,7 +305,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", "my_field", modelId, - new ClassificationConfig(topNClasses, null, null), + new ClassificationConfigUpdate(topNClasses, null, null, null), fieldMapping); Map source = new HashMap(5){{ @@ -326,7 +333,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", targetField, "regression_model", - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, Collections.emptyMap()); Map source = new HashMap<>(); @@ -369,7 +376,7 @@ public class InferenceProcessorTests extends ESTestCase { "my_processor", "ml", "regression_model", - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, Collections.emptyMap()); Map source = new HashMap<>(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 1f91f8724a6..a12f2821d97 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -13,8 +13,11 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -31,7 +34,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutionException; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.closeTo; @@ -48,22 +50,23 @@ public class LocalModelTests extends ESTestCase { .setTrainedModel(buildClassification(false)) .build(); - Model model = new LocalModel(modelId, + Model model = new LocalModel<>(modelId, definition, new TrainedModelInput(inputFields), - Collections.singletonMap("field.foo", "field.foo.keyword")); + Collections.singletonMap("field.foo", "field.foo.keyword"), + ClassificationConfig.EMPTY_PARAMS); Map fields = new HashMap() {{ put("field.foo", 1.0); put("field.bar", 0.5); put("categorical", "dog"); }}; - SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0)); + SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null)); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), is("0")); ClassificationInferenceResults classificationResult = - (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1)); + (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); @@ -72,22 +75,29 @@ public class LocalModelTests extends ESTestCase { .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(true)) .build(); - model = new LocalModel(modelId, + model = new LocalModel<>(modelId, definition, new TrainedModelInput(inputFields), - Collections.singletonMap("field.foo", "field.foo.keyword")); - result = getSingleValue(model, fields, new ClassificationConfig(0)); + Collections.singletonMap("field.foo", "field.foo.keyword"), + ClassificationConfig.EMPTY_PARAMS); + result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null)); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), equalTo("not_to_be")); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, + fields, + new ClassificationConfigUpdate(1, null, null, null)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(2)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, + fields, + new ClassificationConfigUpdate(2, null, null, null)); assertThat(classificationResult.getTopClasses(), hasSize(2)); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(-1)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, + fields, + new ClassificationConfigUpdate(-1, null, null, null)); assertThat(classificationResult.getTopClasses(), hasSize(2)); } @@ -97,10 +107,11 @@ public class LocalModelTests extends ESTestCase { .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildRegression()) .build(); - Model model = new LocalModel("regression_model", + Model model = new LocalModel<>("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields), - Collections.singletonMap("bar", "bar.keyword")); + Collections.singletonMap("bar", "bar.keyword"), + RegressionConfig.EMPTY_PARAMS); Map fields = new HashMap() {{ put("foo", 1.0); @@ -108,14 +119,8 @@ public class LocalModelTests extends ESTestCase { put("categorical", "dog"); }}; - SingleValueInferenceResults results = getSingleValue(model, fields, RegressionConfig.EMPTY_PARAMS); + SingleValueInferenceResults results = getSingleValue(model, fields, RegressionConfigUpdate.EMPTY_PARAMS); assertThat(results.value(), equalTo(1.3)); - - PlainActionFuture failedFuture = new PlainActionFuture<>(); - model.infer(fields, new ClassificationConfig(2), failedFuture); - ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); - assertThat(ex.getCause().getMessage(), - equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]")); } public void testAllFieldsMissing() throws Exception { @@ -124,7 +129,12 @@ public class LocalModelTests extends ESTestCase { .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildRegression()) .build(); - Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields), null); + Model model = new LocalModel<>( + "regression_model", + trainedModelDefinition, + new TrainedModelInput(inputFields), + null, + RegressionConfig.EMPTY_PARAMS); Map fields = new HashMap() {{ put("something", 1.0); @@ -132,18 +142,21 @@ public class LocalModelTests extends ESTestCase { put("baz", "dog"); }}; - WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfig.EMPTY_PARAMS); + WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfigUpdate.EMPTY_PARAMS); assertThat(results.getWarning(), equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model"))); } - private static SingleValueInferenceResults getSingleValue(Model model, - Map fields, - InferenceConfig config) throws Exception { + private static SingleValueInferenceResults getSingleValue(Model model, + Map fields, + InferenceConfigUpdate config) + throws Exception { return (SingleValueInferenceResults)getInferenceResult(model, fields, config); } - private static InferenceResults getInferenceResult(Model model, Map fields, InferenceConfig config) throws Exception { + private static InferenceResults getInferenceResult(Model model, + Map fields, + InferenceConfigUpdate config) throws Exception { PlainActionFuture future = new PlainActionFuture<>(); model.infer(fields, config, future); return future.get(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index b9e33815fcb..50adfe6d861 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -36,6 +36,8 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -111,7 +113,7 @@ public class ModelLoadingServiceTests extends ESTestCase { String[] modelIds = new String[]{model1, model2, model3}; for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -124,7 +126,7 @@ public class ModelLoadingServiceTests extends ESTestCase { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -164,7 +166,7 @@ public class ModelLoadingServiceTests extends ESTestCase { for(int i = 0; i < 10; i++) { // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) String model = modelIds[i%2]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -176,7 +178,7 @@ public class ModelLoadingServiceTests extends ESTestCase { // Load model 3, should invalidate 1 for(int i = 0; i < 10; i++) { - PlainActionFuture future3 = new PlainActionFuture<>(); + PlainActionFuture> future3 = new PlainActionFuture<>(); modelLoadingService.getModel(model3, future3); assertThat(future3.get(), is(not(nullValue()))); } @@ -184,7 +186,7 @@ public class ModelLoadingServiceTests extends ESTestCase { // Load model 1, should invalidate 2 for(int i = 0; i < 10; i++) { - PlainActionFuture future1 = new PlainActionFuture<>(); + PlainActionFuture> future1 = new PlainActionFuture<>(); modelLoadingService.getModel(model1, future1); assertThat(future1.get(), is(not(nullValue()))); } @@ -192,7 +194,7 @@ public class ModelLoadingServiceTests extends ESTestCase { // Load model 2, should invalidate 3 for(int i = 0; i < 10; i++) { - PlainActionFuture future2 = new PlainActionFuture<>(); + PlainActionFuture> future2 = new PlainActionFuture<>(); modelLoadingService.getModel(model2, future2); assertThat(future2.get(), is(not(nullValue()))); } @@ -204,7 +206,7 @@ public class ModelLoadingServiceTests extends ESTestCase { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -230,7 +232,7 @@ public class ModelLoadingServiceTests extends ESTestCase { modelLoadingService.clusterChanged(ingestChangedEvent(false, model1)); for(int i = 0; i < 10; i++) { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model1, future); assertThat(future.get(), is(not(nullValue()))); } @@ -250,7 +252,7 @@ public class ModelLoadingServiceTests extends ESTestCase { Settings.EMPTY); modelLoadingService.clusterChanged(ingestChangedEvent(model)); - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); try { @@ -274,7 +276,7 @@ public class ModelLoadingServiceTests extends ESTestCase { NamedXContentRegistry.EMPTY, Settings.EMPTY); - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); try { future.get(); @@ -296,7 +298,7 @@ public class ModelLoadingServiceTests extends ESTestCase { Settings.EMPTY); for(int i = 0; i < 3; i++) { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -310,6 +312,7 @@ public class ModelLoadingServiceTests extends ESTestCase { when(definition.ramBytesUsed()).thenReturn(size); TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); when(trainedModelConfig.getModelDefinition()).thenReturn(definition); + when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS); when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz"))); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 11037925735..1749d00f821 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -20,8 +20,8 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -146,20 +146,20 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { // Test regression InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId1, toInfer, - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, true); InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.3, 1.25)); - request = new InternalInferModelAction.Request(modelId1, toInfer2, RegressionConfig.EMPTY_PARAMS, true); + request = new InternalInferModelAction.Request(modelId1, toInfer2, RegressionConfigUpdate.EMPTY_PARAMS, true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.65, 1.55)); // Test classification - request = new InternalInferModelAction.Request(modelId2, toInfer, ClassificationConfig.EMPTY_PARAMS, true); + request = new InternalInferModelAction.Request(modelId2, toInfer, ClassificationConfigUpdate.EMPTY_PARAMS, true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults() .stream() @@ -168,7 +168,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { contains("not_to_be", "to_be")); // Get top classes - request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2, null, null), true); + request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null), true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); ClassificationInferenceResults classificationInferenceResults = @@ -187,7 +187,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); // Test that top classes restrict the number returned - request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1, null, null), true); + request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfigUpdate(1, null, null, null), true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); @@ -262,7 +262,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { // Test regression InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId, toInfer, - ClassificationConfig.EMPTY_PARAMS, + ClassificationConfigUpdate.EMPTY_PARAMS, true); InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults() @@ -271,7 +271,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { .collect(Collectors.toList()), contains("option_0", "option_2")); - request = new InternalInferModelAction.Request(modelId, toInfer2, ClassificationConfig.EMPTY_PARAMS, true); + request = new InternalInferModelAction.Request(modelId, toInfer2, ClassificationConfigUpdate.EMPTY_PARAMS, true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults() .stream() @@ -281,7 +281,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { // Get top classes - request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfig(3, null, null), true); + request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null), true); response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); ClassificationInferenceResults classificationInferenceResults = @@ -303,7 +303,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { InternalInferModelAction.Request request = new InternalInferModelAction.Request( model, Collections.emptyList(), - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, true); try { client().execute(InternalInferModelAction.INSTANCE, request).actionGet(); @@ -344,7 +344,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { InternalInferModelAction.Request request = new InternalInferModelAction.Request( modelId, toInferMissingField, - RegressionConfig.EMPTY_PARAMS, + RegressionConfigUpdate.EMPTY_PARAMS, true); try { InferenceResults result = diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 8c1f1e581de..dd49427f2d8 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -11,6 +11,7 @@ setup: "description": "empty model for tests", "tags": ["regression", "tag1"], "input": {"field_names": ["field1", "field2"]}, + "inference_config": {"regression": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -35,6 +36,7 @@ setup: "description": "empty model for tests", "input": {"field_names": ["field1", "field2"]}, "tags": ["regression", "tag2"], + "inference_config": {"regression": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -58,6 +60,7 @@ setup: "description": "empty model for tests", "input": {"field_names": ["field1", "field2"]}, "tags": ["classification", "tag2"], + "inference_config": {"classification": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -83,6 +86,7 @@ setup: "description": "empty model for tests", "input": {"field_names": ["field1", "field2"]}, "tags": ["classification", "tag3"], + "inference_config": {"classification": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -108,6 +112,7 @@ setup: "description": "empty model for tests", "input": {"field_names": ["field1", "field2"]}, "tags": ["classification", "tag3"], + "inference_config": {"classification": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -343,6 +348,7 @@ setup: "input": { "field_names": "fieldy_mc_fieldname" }, + "inference_config": {"regression": {}}, "definition": { "trained_model": { "ensemble": { @@ -377,6 +383,7 @@ setup: "input": { "field_names": "fieldy_mc_fieldname" }, + "inference_config": {"regression": {}}, "definition": { "trained_model": { "ensemble": { @@ -397,6 +404,7 @@ setup: "input": { "field_names": "fieldy_mc_fieldname" }, + "inference_config": {"regression": {}}, "definition": { "trained_model": { "ensemble": { @@ -434,6 +442,7 @@ setup: "input": { "field_names": [] }, + "inference_config": {"regression": {}}, "definition": { "trained_model": { "ensemble": { @@ -469,6 +478,7 @@ setup: { "description": "model for tests", "input": {"field_names": ["field1", "field2"]}, + "inference_config": {"regression": {}}, "definition": { "preprocessors": [], "trained_model": { @@ -510,3 +520,47 @@ setup: - is_true: create_time - is_true: version - is_true: estimated_heap_memory_usage_bytes +--- +"Test PUT model where target type and inference config mismatch": + - do: + catch: /Model \[my-regression-model\] inference config type \[classification\] does not support definition target type \[regression\]/ + ml.put_trained_model: + model_id: my-regression-model + body: > + { + "description": "model for tests", + "input": {"field_names": ["field1", "field2"]}, + "inference_config": {"classification": {}}, + "definition": { + "preprocessors": [], + "trained_model": { + "ensemble": { + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + }, + { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2}, + {"node_index": 1, "leaf_value": 0}, + {"node_index": 2, "leaf_value": 1} + ], + "target_type": "regression" + } + } + ] + } + } + } + } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_processor.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_processor.yml index 9c043eee3cd..08b6603b4f0 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_processor.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_processor.yml @@ -13,6 +13,7 @@ setup: "description": "empty model for tests", "tags": ["regression", "tag1"], "input": {"field_names": ["field1", "field2"]}, + "inference_config": { "regression": {"results_field": "my_regression"}}, "definition": { "preprocessors": [], "trained_model": { @@ -112,3 +113,42 @@ setup: - 'Deprecated field [field_mappings] used, expected [field_map] instead' ingest.delete_pipeline: id: "regression-model-pipeline" +--- +"Test simulate": + - do: + ingest.simulate: + body: > + { + "pipeline": { + "processors": [ + { + "inference" : { + "model_id" : "a-perfect-regression-model", + "inference_config": {"regression": {}}, + "target_field": "regression_field", + "field_map": {} + } + } + ]}, + "docs": [{"_source": {"field1": 1, "field2": 2}}] + } + - match: { docs.0.doc._source.regression_field.my_regression: 42.0 } + + - do: + ingest.simulate: + body: > + { + "pipeline": { + "processors": [ + { + "inference" : { + "model_id" : "a-perfect-regression-model", + "inference_config": {"regression": {"results_field": "value"}}, + "target_field": "regression_field", + "field_map": {} + } + } + ]}, + "docs": [{"_source": {"field1": 1, "field2": 2}}] + } + - match: { docs.0.doc._source.regression_field.value: 42.0 } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml index 40d0f6ba01a..805a50d04a5 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml @@ -6,12 +6,13 @@ setup: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: id: trained_model_config-a-unused-regression-model1-0 - index: .ml-inference-000001 + index: .ml-inference-000002 body: > { "model_id": "a-unused-regression-model1", "created_by": "ml_tests", "version": "8.0.0", + "inference_config": {"regression": {}}, "description": "empty model for tests", "create_time": 0, "doc_type": "trained_model_config" @@ -22,12 +23,13 @@ setup: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: id: trained_model_config-a-unused-regression-model-0 - index: .ml-inference-000001 + index: .ml-inference-000002 body: > { "model_id": "a-unused-regression-model", "created_by": "ml_tests", "version": "8.0.0", + "inference_config": {"regression": {}}, "description": "empty model for tests", "create_time": 0, "doc_type": "trained_model_config" @@ -37,12 +39,13 @@ setup: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: id: trained_model_config-a-used-regression-model-0 - index: .ml-inference-000001 + index: .ml-inference-000002 body: > { "model_id": "a-used-regression-model", "created_by": "ml_tests", "version": "8.0.0", + "inference_config": {"regression": {}}, "description": "empty model for tests", "create_time": 0, "doc_type": "trained_model_config" diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml index 888eaad4403..0d7ecee9571 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/trained_model_cat_apis.yml @@ -13,6 +13,7 @@ setup: { "description": "empty model for tests", "tags": ["regression", "tag1"], + "inference_config": {"regression":{}}, "input": {"field_names": ["field1", "field2"]}, "definition": { "preprocessors": [], @@ -36,6 +37,7 @@ setup: body: > { "description": "empty model for tests", + "inference_config": {"regression":{}}, "input": {"field_names": ["field1", "field2"]}, "definition": { "preprocessors": [],