[7.x] [ML] add new inference_config field to trained model config (#54421) (#54647)

* [ML] add new inference_config field to trained model config (#54421)

A new field called `inference_config` has been added to the trained model config object. It allows default inference settings to be provided by data frame analytics or by an external model builder.

The inference processor can still override whatever defaults are set in the trained model config.
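
For illustration, a minimal high-level REST client sketch based on the usage added in this change; the model id, input columns, and results field are illustrative values, and the definition is assumed to have a regression target type:

import java.util.Arrays;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;

public final class TrainedModelConfigExample {
    // Builds a trained model config whose default inference settings come from a RegressionConfig.
    // An inference processor that references the model can still override these defaults.
    public static TrainedModelConfig withDefaultRegressionConfig(TrainedModelDefinition definition) {
        return TrainedModelConfig.builder()
            .setModelId("my-regression-model")                              // illustrative id
            .setDefinition(definition)
            .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
            .setInferenceConfig(new RegressionConfig("predicted_value", 0)) // results_field, num_top_feature_importance_values
            .setDescription("test model")
            .build();
    }
}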

* fixing for backport
Benjamin Trent 2020-04-02 12:25:10 -04:00 committed by GitHub
parent 2113c1ffb6
commit 4a1610265f
61 changed files with 1950 additions and 303 deletions


@ -19,6 +19,9 @@
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
@ -61,6 +64,14 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
new ParseField(LangIdentNeuralNetwork.NAME),
LangIdentNeuralNetwork::fromXContent));
// Inference Config
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class,
ClassificationConfig.NAME,
ClassificationConfig::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class,
RegressionConfig.NAME,
RegressionConfig::fromXContent));
// Aggregating output
namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
new ParseField(WeightedMode.NAME),


@ -54,4 +54,14 @@ public final class NamedXContentObjectHelper {
}
return builder;
}
public static XContentBuilder writeNamedObject(XContentBuilder builder,
ToXContent.Params params,
String namedObjectName,
NamedXContentObject namedObject) throws IOException {
builder.startObject(namedObjectName);
builder.field(namedObject.getName(), namedObject, params);
builder.endObject();
return builder;
}
}


@ -20,6 +20,7 @@ package org.elasticsearch.client.ml.inference;
import org.elasticsearch.Version;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeValue;
@ -36,6 +37,8 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.client.ml.inference.NamedXContentObjectHelper.writeNamedObject;
public class TrainedModelConfig implements ToXContentObject {
public static final String NAME = "trained_model_config";
@ -54,6 +57,7 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
@ -78,6 +82,9 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
PARSER.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP);
PARSER.declareNamedObject(TrainedModelConfig.Builder::setInferenceConfig,
(p, c, n) -> p.namedObject(InferenceConfig.class, n, null),
INFERENCE_CONFIG);
}
public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException {
@ -98,6 +105,7 @@ public class TrainedModelConfig implements ToXContentObject {
private final Long estimatedOperations;
private final String licenseLevel;
private final Map<String, String> defaultFieldMap;
private final InferenceConfig inferenceConfig;
TrainedModelConfig(String modelId,
String createdBy,
@ -112,7 +120,8 @@ public class TrainedModelConfig implements ToXContentObject {
Long estimatedHeapMemory,
Long estimatedOperations,
String licenseLevel,
Map<String, String> defaultFieldMap) {
Map<String, String> defaultFieldMap,
InferenceConfig inferenceConfig) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
@ -127,6 +136,7 @@ public class TrainedModelConfig implements ToXContentObject {
this.estimatedOperations = estimatedOperations;
this.licenseLevel = licenseLevel;
this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap);
this.inferenceConfig = inferenceConfig;
}
public String getModelId() {
@ -189,6 +199,10 @@ public class TrainedModelConfig implements ToXContentObject {
return defaultFieldMap;
}
public InferenceConfig getInferenceConfig() {
return inferenceConfig;
}
public static Builder builder() {
return new Builder();
}
@ -238,6 +252,9 @@ public class TrainedModelConfig implements ToXContentObject {
if (defaultFieldMap != null) {
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
}
if (inferenceConfig != null) {
writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig);
}
builder.endObject();
return builder;
}
@ -265,6 +282,7 @@ public class TrainedModelConfig implements ToXContentObject {
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(defaultFieldMap, that.defaultFieldMap) &&
Objects.equals(inferenceConfig, that.inferenceConfig) &&
Objects.equals(metadata, that.metadata);
}
@ -283,6 +301,7 @@ public class TrainedModelConfig implements ToXContentObject {
metadata,
licenseLevel,
input,
inferenceConfig,
defaultFieldMap);
}
@ -303,6 +322,7 @@ public class TrainedModelConfig implements ToXContentObject {
private Long estimatedOperations;
private String licenseLevel;
private Map<String, String> defaultFieldMap;
private InferenceConfig inferenceConfig;
public Builder setModelId(String modelId) {
this.modelId = modelId;
@ -387,6 +407,11 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
public Builder setInferenceConfig(InferenceConfig inferenceConfig) {
this.inferenceConfig = inferenceConfig;
return this;
}
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
@ -402,7 +427,8 @@ public class TrainedModelConfig implements ToXContentObject {
estimatedHeapMemory,
estimatedOperations,
licenseLevel,
defaultFieldMap);
defaultFieldMap,
inferenceConfig);
}
}


@ -0,0 +1,129 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class ClassificationConfig implements InferenceConfig {
public static final ParseField NAME = new ParseField("classification");
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
private final Integer numTopClasses;
private final String topClassesResultsField;
private final String resultsField;
private final Integer numTopFeatureImportanceValues;
private static final ConstructingObjectParser<ClassificationConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new ClassificationConfig(
(Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3]));
static {
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
}
public static ClassificationConfig fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public ClassificationConfig() {
this(null, null, null, null);
}
public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) {
this.numTopClasses = numTopClasses;
this.topClassesResultsField = topClassesResultsField;
this.resultsField = resultsField;
this.numTopFeatureImportanceValues = featureImportance;
}
public Integer getNumTopClasses() {
return numTopClasses;
}
public String getTopClassesResultsField() {
return topClassesResultsField;
}
public String getResultsField() {
return resultsField;
}
public Integer getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassificationConfig that = (ClassificationConfig) o;
return Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(topClassesResultsField, that.topClassesResultsField)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
}
@Override
public int hashCode() {
return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
if (numTopClasses != null) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
if (topClassesResultsField != null) {
builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField);
}
if (resultsField != null) {
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
}
if (numTopFeatureImportanceValues != null) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.endObject();
return builder;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
}


@ -0,0 +1,26 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel;
import org.elasticsearch.client.ml.inference.NamedXContentObject;
public interface InferenceConfig extends NamedXContentObject {
}


@ -0,0 +1,104 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class RegressionConfig implements InferenceConfig {
public static final ParseField NAME = new ParseField("regression");
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
private static final ConstructingObjectParser<RegressionConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(),
true,
args -> new RegressionConfig((String) args[0], (Integer)args[1]));
static {
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
}
public static RegressionConfig fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final String resultsField;
private final Integer numTopFeatureImportanceValues;
public RegressionConfig() {
this(null, null);
}
public RegressionConfig(String resultsField, Integer numTopFeatureImportanceValues) {
this.resultsField = resultsField;
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
}
public Integer getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}
public String getResultsField() {
return resultsField;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
if (resultsField != null) {
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
}
if (numTopFeatureImportanceValues != null) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RegressionConfig that = (RegressionConfig)o;
return Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
}
@Override
public int hashCode() {
return Objects.hash(resultsField, numTopFeatureImportanceValues);
}
}


@ -158,6 +158,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
@ -2272,6 +2273,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
.setDefinition(definition)
.setModelId(modelId)
.setInferenceConfig(new RegressionConfig())
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
.setDescription("test model")
.build();
@ -2285,6 +2287,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
trainedModelConfig = TrainedModelConfig.builder()
.setCompressedDefinition(InferenceToXContentCompressor.deflate(definition))
.setModelId(modelIdCompressed)
.setInferenceConfig(new RegressionConfig())
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
.setDescription("test model")
.build();
@ -2591,6 +2594,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
.setDefinition(definition)
.setModelId(modelId)
.setInferenceConfig(new RegressionConfig())
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
.setDescription("test model")
.build();


@ -75,6 +75,8 @@ import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
@ -699,7 +701,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(62, namedXContents.size());
assertEquals(64, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -709,7 +711,7 @@ public class RestHighLevelClientTests extends ESTestCase {
categories.put(namedXContent.categoryClass, counter + 1);
}
}
assertEquals("Had: " + categories, 13, categories.size());
assertEquals("Had: " + categories, 14, categories.size());
assertEquals(Integer.valueOf(3), categories.get(Aggregation.class));
assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
@ -783,6 +785,9 @@ public class RestHighLevelClientTests extends ESTestCase {
assertEquals(Integer.valueOf(3),
categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME));
assertEquals(Integer.valueOf(2),
categories.get(org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig.class));
assertThat(names, hasItems(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName()));
}
public void testApiNamingConventions() throws Exception {


@ -174,6 +174,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.AnalysisLimits;
@ -3646,11 +3647,13 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setDescription("test model") // <5>
.setMetadata(new HashMap<>()) // <6>
.setTags("my_regression_models") // <7>
.setInferenceConfig(new RegressionConfig("value", 0)) // <8>
.build();
// end::put-trained-model-config
trainedModelConfig = TrainedModelConfig.builder()
.setDefinition(definition)
.setInferenceConfig(new RegressionConfig(null, null))
.setModelId("my-new-trained-model")
.setInput(new TrainedModelInput("col1", "col2", "col3", "col4"))
.setDescription("test model")
@ -4234,6 +4237,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
.setDefinition(definition)
.setModelId(modelId)
.setInferenceConfig(new RegressionConfig("value", 0))
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
.setDescription("test model")
.build();


@ -19,6 +19,9 @@
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.Version;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfigTests;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfigTests;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
@ -39,13 +42,14 @@ import java.util.stream.Stream;
public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedModelConfig> {
public static TrainedModelConfig createTestTrainedModelConfig() {
TargetType targetType = randomFrom(TargetType.values());
return new TrainedModelConfig(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()),
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder(targetType).build(),
randomBoolean() ? null : randomAlphaOfLength(100),
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
@ -57,7 +61,10 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))),
targetType.equals(TargetType.CLASSIFICATION) ?
ClassificationConfigTests.randomClassificationConfig() :
RegressionConfigTests.randomRegressionConfig());
}
@Override


@ -0,0 +1,51 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class ClassificationConfigTests extends AbstractXContentTestCase<ClassificationConfig> {
public static ClassificationConfig randomClassificationConfig() {
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomIntBetween(0, 10)
);
}
@Override
protected ClassificationConfig createTestInstance() {
return randomClassificationConfig();
}
@Override
protected ClassificationConfig doParseInstance(XContentParser parser) throws IOException {
return ClassificationConfig.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}


@ -0,0 +1,50 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class RegressionConfigTests extends AbstractXContentTestCase<RegressionConfig> {
public static RegressionConfig randomRegressionConfig() {
return new RegressionConfig(
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomIntBetween(0, 10));
}
@Override
protected RegressionConfig createTestInstance() {
return randomRegressionConfig();
}
@Override
protected RegressionConfig doParseInstance(XContentParser parser) throws IOException {
return RegressionConfig.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}


@ -39,6 +39,8 @@ include-tagged::{doc-tests-file}[{api}-config]
<5> Optionally, a human-readable description
<6> Optionally, an object map containing metadata about the model
<7> Optionally, an array of tags to organize the model
<8> The default inference config to use with the model. It must match the
`target_type` of the underlying trained model definition.
include::../execution.asciidoc[]


@ -38,44 +38,38 @@ include::common-options.asciidoc[]
[[inference-processor-regression-opt]]
==== {regression-cap} configuration options
Regression configuration for inference.
`results_field`::
(Optional, string)
Specifies the field to which the inference prediction is written. Defaults to
`predicted_value`.
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-results-field]
`num_top_feature_importance_values`::::
`num_top_feature_importance_values`::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values]
[discrete]
[[inference-processor-classification-opt]]
==== {classification-cap} configuration options
`results_field`::
(Optional, string)
The field that is added to incoming documents to contain the inference prediction. Defaults to
`predicted_value`.
Classification configuration for inference.
`num_top_classes`::
(Optional, integer)
Specifies the number of top class predictions to return. Defaults to 0.
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes]
`num_top_feature_importance_values`::
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values]
`results_field`::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-results-field]
`top_classes_results_field`::
(Optional, string)
Specifies the field to which the top classes are written. Defaults to
`top_classes`.
`num_top_feature_importance_values`::::
(Optional, integer)
Specifies the maximum number of
{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature
importance] values per document. By default, it is zero and no feature
importance calculation occurs.
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field]
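
For orientation, these four options map one-to-one onto the constructor of the new high-level REST client `ClassificationConfig` added in this change. A minimal sketch with illustrative values (argument order is num_top_classes, results_field, top_classes_results_field, num_top_feature_importance_values; passing `null` leaves the corresponding default in place):

import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;

ClassificationConfig config = new ClassificationConfig(
    2,              // num_top_classes
    "prediction",   // results_field (illustrative)
    "top_classes",  // top_classes_results_field
    0);             // num_top_feature_importance_values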
[discrete]
[[inference-processor-config-example]]
@ -178,4 +172,4 @@ You can also specify a target field as follows:
// NOTCONSOLE
In this case, {feat-imp} is exposed in the
`my_field.foo.feature_importance` field.


@ -43,7 +43,7 @@ is not created by {dfanalytics}.
(Required, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=model-id]
[role="child_attributes"]
[[ml-put-inference-request-body]]
==== {api-request-body-title}
@ -52,32 +52,96 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=model-id]
The compressed (GZipped and Base64 encoded) {infer} definition of the model.
If `compressed_definition` is specified, then `definition` cannot be specified.
//Begin definition
`definition`::
(Required, object)
The {infer} definition for the model. If `definition` is specified, then
`compressed_definition` cannot be specified.
`definition`.`preprocessors`:::
+
.Properties of `definition`
[%collapsible%open]
====
`preprocessors`:::
(Optional, object)
Collection of preprocessors. See <<ml-put-inference-preprocessors>> for the full
list of available preprocessors.
`definition`.`trained_model`:::
`trained_model`:::
(Required, object)
The definition of the trained model. See <<ml-put-inference-trained-model>> for
details.
====
//End definition
`description`::
(Optional, string)
A human-readable description of the {infer} trained model.
//Begin inference_config
`inference_config`::
(Required, object)
The default configuration for inference. This can be either a `regression`
or `classification` configuration. It must match the underlying
`definition.trained_model`'s `target_type`.
+
.Properties of `inference_config`
[%collapsible%open]
====
`regression`:::
(Optional, object)
Regression configuration for inference.
+
.Properties of regression inference
[%collapsible%open]
=====
`num_top_feature_importance_values`::::
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values]
`results_field`::::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-regression-results-field]
=====
`classification`:::
(Optional, object)
Classification configuration for inference.
+
.Properties of classification inference
[%collapsible%open]
=====
`num_top_classes`::::
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes]
`num_top_feature_importance_values`::::
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values]
`results_field`::::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-results-field]
`top_classes_results_field`::::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=inference-config-classification-top-classes-results-field]
=====
====
//End of inference_config
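
As a sketch of how this structure is serialized, the `writeNamedObject` helper added in this change nests the chosen configuration object under the `inference_config` field, keyed by the configuration's own name. A minimal high-level REST client example (values illustrative):

import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;

XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
// Wraps the config under its own name ("classification"), nested inside "inference_config".
NamedXContentObjectHelper.writeNamedObject(builder, ToXContent.EMPTY_PARAMS,
    "inference_config", new ClassificationConfig(2, null, null, null));
builder.endObject();
// The builder now contains: {"inference_config":{"classification":{"num_top_classes":2}}}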
//Begin input
`input`::
(Required, object)
The input field names for the model definition.
`input`.`field_names`:::
+
.Properties of `input`
[%collapsible%open]
====
`field_names`:::
(Required, string)
An array of input field names for the model.
====
//End input
`metadata`::
(Optional, object)
@ -87,7 +151,6 @@ An object map that contains metadata about the model.
(Optional, string)
An array of tags to organize the model.
[[ml-put-inference-preprocessors]]
===== {infer-cap} preprocessor definitions
@ -491,4 +554,4 @@ Example of a `weighted_mode` object:
===== {infer-cap} JSON schema
For the full JSON schema of model {infer},
https://github.com/elastic/ml-json-schemas[click here].


@ -1213,6 +1213,39 @@ For more information about these options, see <<multi-index>>.
--
end::indices-options[]
tag::inference-config-classification-num-top-classes[]
Specifies the number of top class predictions to return. Defaults to 0.
end::inference-config-classification-num-top-classes[]
tag::inference-config-classification-num-top-feature-importance-values[]
Specifies the maximum number of
{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature
importance] values per document. By default, it is zero and no feature
importance calculation occurs.
end::inference-config-classification-num-top-feature-importance-values[]
tag::inference-config-classification-results-field[]
The field that is added to incoming documents to contain the inference
prediction. Defaults to `predicted_value`.
end::inference-config-classification-results-field[]
tag::inference-config-classification-top-classes-results-field[]
Specifies the field to which the top classes are written. Defaults to
`top_classes`.
end::inference-config-classification-top-classes-results-field[]
tag::inference-config-regression-num-top-feature-importance-values[]
Specifies the maximum number of
{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature
importance] values per document. By default, it is zero and no feature importance
calculation occurs.
end::inference-config-regression-num-top-feature-importance-values[]
tag::inference-config-regression-results-field[]
Specifies the field to which the inference prediction is written. Defaults to
`predicted_value`.
end::inference-config-regression-results-field[]
tag::influencers[]
A comma separated list of influencer field names. Typically these can be the by,
over, or partition fields that are used in the detector configuration. You might


@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
@ -13,8 +14,12 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -37,27 +42,30 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
private final String modelId;
private final List<Map<String, Object>> objectsToInfer;
private final InferenceConfig config;
private final InferenceConfigUpdate<? extends InferenceConfig> update;
private final boolean previouslyLicensed;
public Request(String modelId, boolean previouslyLicensed) {
this(modelId, Collections.emptyList(), RegressionConfig.EMPTY_PARAMS, previouslyLicensed);
this(modelId, Collections.emptyList(), RegressionConfigUpdate.EMPTY_PARAMS, previouslyLicensed);
}
public Request(String modelId,
List<Map<String, Object>> objectsToInfer,
InferenceConfig inferenceConfig,
InferenceConfigUpdate<? extends InferenceConfig> inferenceConfig,
boolean previouslyLicensed) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer"));
this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
this.update = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
this.previouslyLicensed = previouslyLicensed;
}
public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfig config, boolean previouslyLicensed) {
public Request(String modelId,
Map<String, Object> objectToInfer,
InferenceConfigUpdate<? extends InferenceConfig> update,
boolean previouslyLicensed) {
this(modelId,
Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")),
config,
update,
previouslyLicensed);
}
@ -65,7 +73,18 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
super(in);
this.modelId = in.readString();
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
this.config = in.readNamedWriteable(InferenceConfig.class);
// Nodes on or after 7.8.0 send an InferenceConfigUpdate; older nodes send a concrete
// InferenceConfig, which is converted to the equivalent update here.
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.update = (InferenceConfigUpdate<? extends InferenceConfig>)in.readNamedWriteable(InferenceConfigUpdate.class);
} else {
InferenceConfig oldConfig = in.readNamedWriteable(InferenceConfig.class);
if (oldConfig instanceof RegressionConfig) {
this.update = RegressionConfigUpdate.fromConfig((RegressionConfig)oldConfig);
} else if (oldConfig instanceof ClassificationConfig) {
this.update = ClassificationConfigUpdate.fromConfig((ClassificationConfig) oldConfig);
} else {
throw new IOException("Unexpected configuration type [" + oldConfig.getName() + "]");
}
}
this.previouslyLicensed = in.readBoolean();
}
@ -77,8 +96,8 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
return objectsToInfer;
}
public InferenceConfig getConfig() {
return config;
public InferenceConfigUpdate getUpdate() {
return update;
}
public boolean isPreviouslyLicensed() {
@ -95,7 +114,11 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
super.writeTo(out);
out.writeString(modelId);
out.writeCollection(objectsToInfer, StreamOutput::writeMap);
out.writeNamedWriteable(config);
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeNamedWriteable(update);
} else {
out.writeNamedWriteable(update.toConfig());
}
out.writeBoolean(previouslyLicensed);
}
@ -105,14 +128,14 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
if (o == null || getClass() != o.getClass()) return false;
InternalInferModelAction.Request that = (InternalInferModelAction.Request) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(config, that.config)
&& Objects.equals(update, that.update)
&& Objects.equals(previouslyLicensed, that.previouslyLicensed)
&& Objects.equals(objectsToInfer, that.objectsToInfer);
}
@Override
public int hashCode() {
return Objects.hash(modelId, objectsToInfer, config, previouslyLicensed);
return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed);
}
}


@ -99,6 +99,30 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
}
}
public Double getLambda() {
return lambda;
}
public Double getGamma() {
return gamma;
}
public Double getEta() {
return eta;
}
public Integer getMaxTrees() {
return maxTrees;
}
public Double getFeatureBagFraction() {
return featureBagFraction;
}
public Integer getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalDouble(lambda);


@ -13,9 +13,14 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
@ -100,10 +105,19 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
LogisticRegression::fromXContentStrict));
// Inference Configs
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME,
ClassificationConfig::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME,
RegressionConfig::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class, ClassificationConfig.NAME,
ClassificationConfig::fromXContentLenient));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, ClassificationConfig.NAME,
ClassificationConfig::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class, RegressionConfig.NAME,
RegressionConfig::fromXContentLenient));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, RegressionConfig.NAME,
RegressionConfig::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ClassificationConfigUpdate.NAME,
ClassificationConfigUpdate::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
RegressionConfigUpdate::fromXContentStrict));
return namedXContent;
}
@ -149,9 +163,14 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
// Inference Configs
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
RegressionConfig.NAME.getPreferredName(), RegressionConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
ClassificationConfigUpdate.NAME.getPreferredName(), ClassificationConfigUpdate::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
RegressionConfigUpdate.NAME.getPreferredName(), RegressionConfigUpdate::new));
return namedWriteables;
}


@ -23,6 +23,9 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
@ -38,6 +41,7 @@ import java.util.Objects;
import java.util.stream.Collectors;
import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper.writeNamedObject;
public class TrainedModelConfig implements ToXContentObject, Writeable {
@ -62,6 +66,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
public static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
@ -93,6 +98,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
parser.declareString(TrainedModelConfig.Builder::setLazyDefinition, COMPRESSED_DEFINITION);
parser.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
parser.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP);
parser.declareNamedObject(TrainedModelConfig.Builder::setInferenceConfig, (p, c, n) -> ignoreUnknownFields ?
p.namedObject(LenientlyParsedInferenceConfig.class, n, null) :
p.namedObject(StrictlyParsedInferenceConfig.class, n, null),
INFERENCE_CONFIG);
return parser;
}
@ -112,6 +121,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private final long estimatedOperations;
private final License.OperationMode licenseLevel;
private final Map<String, String> defaultFieldMap;
private final InferenceConfig inferenceConfig;
private final LazyModelDefinition definition;
@ -127,7 +137,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Long estimatedHeapMemory,
Long estimatedOperations,
String licenseLevel,
Map<String, String> defaultFieldMap) {
Map<String, String> defaultFieldMap,
InferenceConfig inferenceConfig) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
@ -148,6 +159,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
this.estimatedOperations = estimatedOperations;
this.licenseLevel = License.OperationMode.parse(ExceptionsHelper.requireNonNull(licenseLevel, LICENSE_LEVEL));
this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap);
this.inferenceConfig = inferenceConfig;
}
public TrainedModelConfig(StreamInput in) throws IOException {
@ -170,6 +182,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
} else {
this.defaultFieldMap = null;
}
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
this.inferenceConfig = in.readOptionalNamedWriteable(InferenceConfig.class);
} else {
this.inferenceConfig = null;
}
}
public String getModelId() {
@ -204,6 +221,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return defaultFieldMap;
}
@Nullable
public InferenceConfig getInferenceConfig() {
return inferenceConfig;
}
@Nullable
public String getCompressedDefinition() throws IOException {
if (definition == null) {
@ -274,6 +296,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
out.writeBoolean(false);
}
}
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeOptionalNamedWriteable(inferenceConfig);
}
}
@Override
@ -311,6 +336,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) {
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
}
if (inferenceConfig != null) {
writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), inferenceConfig);
}
builder.endObject();
return builder;
}
@ -337,6 +365,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(defaultFieldMap, that.defaultFieldMap) &&
Objects.equals(inferenceConfig, that.inferenceConfig) &&
Objects.equals(metadata, that.metadata);
}
@ -354,6 +383,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
estimatedOperations,
input,
licenseLevel,
inferenceConfig,
defaultFieldMap);
}
@ -372,6 +402,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private LazyModelDefinition definition;
private String licenseLevel;
private Map<String, String> defaultFieldMap;
private InferenceConfig inferenceConfig;
public Builder() {}
@ -389,6 +420,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
this.estimatedHeapMemory = config.estimatedHeapMemory;
this.licenseLevel = config.licenseLevel.description();
this.defaultFieldMap = config.defaultFieldMap == null ? null : new HashMap<>(config.defaultFieldMap);
this.inferenceConfig = config.inferenceConfig;
}
public Builder setModelId(String modelId) {
@ -512,6 +544,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
public Builder setInferenceConfig(InferenceConfig inferenceConfig) {
this.inferenceConfig = inferenceConfig;
return this;
}
public Builder validate() {
return validate(false);
}
@ -530,6 +567,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
if (modelId == null) {
validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException);
}
if (inferenceConfig == null && forCreation) {
validationException = addValidationError("[" + INFERENCE_CONFIG.getPreferredName() + "] must not be null.",
validationException);
}
if (modelId != null && MlStrings.isValidId(modelId) == false) {
validationException = addValidationError(Messages.getMessage(Messages.INVALID_ID,
@ -605,7 +646,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
estimatedHeapMemory == null ? 0 : estimatedHeapMemory,
estimatedOperations == null ? 0 : estimatedOperations,
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel,
defaultFieldMap);
defaultFieldMap,
inferenceConfig);
}
}


@ -5,19 +5,31 @@
*/
package org.elasticsearch.xpack.core.ml.inference.persistence;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.xpack.core.template.TemplateUtils;
/**
* Class containing the index constants so that the index version, name, and prefix are available to a wider audience.
*/
public final class InferenceIndexConstants {
public static final String INDEX_VERSION = "000001";
/**
* version: 7.8.0:
* - adds inference_config definition to trained model config
*
*/
public static final String INDEX_VERSION = "000002";
public static final String INDEX_NAME_PREFIX = ".ml-inference-";
public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*";
public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION;
public static final ParseField DOC_TYPE = new ParseField("doc_type");
private InferenceIndexConstants() {}
private static final String MAPPINGS_VERSION_VARIABLE = "xpack.ml.version";
public static String mapping() {
return TemplateUtils.loadTemplate("/org/elasticsearch/xpack/core/ml/inference_index_mappings.json",
Version.CURRENT.toString(), MAPPINGS_VERSION_VARIABLE);
}
}


@ -9,24 +9,19 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class ClassificationConfig implements InferenceConfig {
public class ClassificationConfig implements LenientlyParsedInferenceConfig, StrictlyParsedInferenceConfig {
public static final ParseField NAME = new ParseField("classification");
public static final String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes";
private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
public static final String DEFAULT_RESULTS_FIELD = "predicted_value";
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
@ -42,32 +37,27 @@ public class ClassificationConfig implements InferenceConfig {
private final String resultsField;
private final int numTopFeatureImportanceValues;
public static ClassificationConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName());
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
private static final ObjectParser<ClassificationConfig.Builder, Void> LENIENT_PARSER = createParser(true);
private static final ObjectParser<ClassificationConfig.Builder, Void> STRICT_PARSER = createParser(false);
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, featureImportance);
private static ObjectParser<ClassificationConfig.Builder, Void> createParser(boolean lenient) {
ObjectParser<ClassificationConfig.Builder, Void> parser = new ObjectParser<>(
NAME.getPreferredName(),
lenient,
ClassificationConfig.Builder::new);
parser.declareInt(ClassificationConfig.Builder::setNumTopClasses, NUM_TOP_CLASSES);
parser.declareString(ClassificationConfig.Builder::setResultsField, RESULTS_FIELD);
parser.declareString(ClassificationConfig.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
parser.declareInt(ClassificationConfig.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
return parser;
}
private static final ConstructingObjectParser<ClassificationConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig(
(Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3]));
static {
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
public static ClassificationConfig fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null).build();
}
public static ClassificationConfig fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
public static ClassificationConfig fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null).build();
}
public ClassificationConfig(Integer numTopClasses) {
@ -150,14 +140,10 @@ public class ClassificationConfig implements InferenceConfig {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (numTopClasses != 0) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField);
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
if (numTopFeatureImportanceValues > 0) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
builder.endObject();
return builder;
}
@ -179,7 +165,50 @@ public class ClassificationConfig implements InferenceConfig {
@Override
public Version getMinimalSupportedVersion() {
return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
return requestingImportance() ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private Integer numTopClasses;
private String topClassesResultsField;
private String resultsField;
private Integer numTopFeatureImportanceValues;
Builder() {}
Builder(ClassificationConfig config) {
this.numTopClasses = config.numTopClasses;
this.topClassesResultsField = config.topClassesResultsField;
this.resultsField = config.resultsField;
this.numTopFeatureImportanceValues = config.numTopFeatureImportanceValues;
}
public Builder setNumTopClasses(Integer numTopClasses) {
this.numTopClasses = numTopClasses;
return this;
}
public Builder setTopClassesResultsField(String topClassesResultsField) {
this.topClassesResultsField = topClassesResultsField;
return this;
}
public Builder setResultsField(String resultsField) {
this.resultsField = resultsField;
return this;
}
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
return this;
}
public ClassificationConfig build() {
return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues);
}
}
}
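For orientation, here is a minimal sketch (not part of the commit) of how the builder introduced above might be used. The field values are illustrative only, and the snippet assumes the x-pack core ML classes from this change are on the classpath.

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;

public class ClassificationConfigBuilderSketch {
    public static void main(String[] args) {
        // Illustrative values: request the top three classes plus per-document feature importance.
        ClassificationConfig config = ClassificationConfig.builder()
            .setNumTopClasses(3)
            .setTopClassesResultsField("top_classes")
            .setResultsField("prediction")
            .setNumTopFeatureImportanceValues(2)
            .build();
        // Requesting feature importance bumps the minimal supported version to 7.7.0.
        System.out.println(config.getMinimalSupportedVersion());
    }
}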

View File

@ -0,0 +1,235 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.NUM_TOP_CLASSES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.RESULTS_FIELD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig.TOP_CLASSES_RESULTS_FIELD;
public class ClassificationConfigUpdate implements InferenceConfigUpdate<ClassificationConfig> {
public static final ParseField NAME = new ParseField("classification");
public static ClassificationConfigUpdate EMPTY_PARAMS =
new ClassificationConfigUpdate(null, null, null, null);
private final Integer numTopClasses;
private final String topClassesResultsField;
private final String resultsField;
private final Integer numTopFeatureImportanceValues;
public static ClassificationConfigUpdate fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName());
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new ClassificationConfigUpdate(numTopClasses, resultsField, topClassesResultsField, featureImportance);
}
public static ClassificationConfigUpdate fromConfig(ClassificationConfig config) {
return new ClassificationConfigUpdate(config.getNumTopClasses(),
config.getResultsField(),
config.getTopClassesResultsField(),
config.getNumTopFeatureImportanceValues());
}
private static final ObjectParser<ClassificationConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
private static ObjectParser<ClassificationConfigUpdate.Builder, Void> createParser(boolean lenient) {
ObjectParser<ClassificationConfigUpdate.Builder, Void> parser = new ObjectParser<>(
NAME.getPreferredName(),
lenient,
ClassificationConfigUpdate.Builder::new);
parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopClasses, NUM_TOP_CLASSES);
parser.declareString(ClassificationConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
parser.declareString(ClassificationConfigUpdate.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
return parser;
}
public static ClassificationConfigUpdate fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null).build();
}
public ClassificationConfigUpdate(Integer numTopClasses,
String resultsField,
String topClassesResultsField,
Integer featureImportance) {
this.numTopClasses = numTopClasses;
this.topClassesResultsField = topClassesResultsField;
this.resultsField = resultsField;
if (featureImportance != null && featureImportance < 0) {
throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() +
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = featureImportance;
}
public ClassificationConfigUpdate(StreamInput in) throws IOException {
this.numTopClasses = in.readOptionalInt();
this.topClassesResultsField = in.readOptionalString();
this.resultsField = in.readOptionalString();
this.numTopFeatureImportanceValues = in.readOptionalVInt();
}
public Integer getNumTopClasses() {
return numTopClasses;
}
public String getTopClassesResultsField() {
return topClassesResultsField;
}
public String getResultsField() {
return resultsField;
}
public Integer getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(numTopClasses);
out.writeOptionalString(topClassesResultsField);
out.writeOptionalString(resultsField);
out.writeOptionalVInt(numTopFeatureImportanceValues);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassificationConfigUpdate that = (ClassificationConfigUpdate) o;
return Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(topClassesResultsField, that.topClassesResultsField)
&& Objects.equals(resultsField, that.resultsField)
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
}
@Override
public int hashCode() {
return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (numTopClasses != null) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
if (topClassesResultsField != null) {
builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField);
}
if (resultsField != null) {
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
}
if (numTopFeatureImportanceValues != null) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.endObject();
return builder;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public ClassificationConfig apply(ClassificationConfig originalConfig) {
if (isNoop(originalConfig)) {
return originalConfig;
}
ClassificationConfig.Builder builder = new ClassificationConfig.Builder(originalConfig);
if (resultsField != null) {
builder.setResultsField(resultsField);
}
if (numTopFeatureImportanceValues != null) {
builder.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues);
}
if (topClassesResultsField != null) {
builder.setTopClassesResultsField(topClassesResultsField);
}
if (numTopClasses != null) {
builder.setNumTopClasses(numTopClasses);
}
return builder.build();
}
@Override
public InferenceConfig toConfig() {
return apply(ClassificationConfig.EMPTY_PARAMS);
}
@Override
public boolean isSupported(InferenceConfig inferenceConfig) {
return inferenceConfig instanceof ClassificationConfig;
}
boolean isNoop(ClassificationConfig originalConfig) {
return (resultsField == null || resultsField.equals(originalConfig.getResultsField()))
&& (numTopFeatureImportanceValues == null
|| originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues)
&& (topClassesResultsField == null || topClassesResultsField.equals(originalConfig.getTopClassesResultsField()))
&& (numTopClasses == null || originalConfig.getNumTopClasses() == numTopClasses);
}
public static class Builder {
private Integer numTopClasses;
private String topClassesResultsField;
private String resultsField;
private Integer numTopFeatureImportanceValues;
public Builder setNumTopClasses(int numTopClasses) {
this.numTopClasses = numTopClasses;
return this;
}
public Builder setTopClassesResultsField(String topClassesResultsField) {
this.topClassesResultsField = topClassesResultsField;
return this;
}
public Builder setResultsField(String resultsField) {
this.resultsField = resultsField;
return this;
}
public Builder setNumTopFeatureImportanceValues(int numTopFeatureImportanceValues) {
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
return this;
}
public ClassificationConfigUpdate build() {
return new ClassificationConfigUpdate(numTopClasses, resultsField, topClassesResultsField, numTopFeatureImportanceValues);
}
}
}
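As a rough usage sketch (again, not part of the commit), an update built from a processor-style options map can be overlaid on a stored config; per apply() above, only non-null update fields override. The option key follows the results_field ParseField defined in ClassificationConfig.

import java.util.HashMap;
import java.util.Map;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;

public class ClassificationConfigUpdateSketch {
    public static void main(String[] args) {
        // A stored default, e.g. the inference_config persisted with the trained model.
        ClassificationConfig stored = ClassificationConfig.builder()
            .setNumTopClasses(2)
            .build();

        // An override as it might arrive from an ingest processor definition.
        Map<String, Object> options = new HashMap<>();
        options.put("results_field", "my_prediction");
        ClassificationConfigUpdate update = ClassificationConfigUpdate.fromMap(options);

        // apply() copies the stored config and overlays only the non-null update fields,
        // so num_top_classes stays at 2 while the results field changes.
        ClassificationConfig effective = update.apply(stored);
    }
}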

View File

@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
public interface InferenceConfigUpdate<T extends InferenceConfig> extends NamedXContentObject, NamedWriteable {
T apply(T originalConfig);
InferenceConfig toConfig();
boolean isSupported(InferenceConfig config);
}

View File

@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
public interface LenientlyParsedInferenceConfig extends InferenceConfig {
}

View File

@ -9,48 +9,42 @@ import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class RegressionConfig implements InferenceConfig {
public class RegressionConfig implements LenientlyParsedInferenceConfig, StrictlyParsedInferenceConfig {
public static final ParseField NAME = new ParseField("regression");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static final ParseField RESULTS_FIELD = new ParseField("results_field");
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
public static final String DEFAULT_RESULTS_FIELD = "predicted_value";
public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null);
public static RegressionConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
}
return new RegressionConfig(resultsField, featureImportance);
private static final ObjectParser<RegressionConfig.Builder, Void> LENIENT_PARSER = createParser(true);
private static final ObjectParser<RegressionConfig.Builder, Void> STRICT_PARSER = createParser(false);
private static ObjectParser<RegressionConfig.Builder, Void> createParser(boolean lenient) {
ObjectParser<RegressionConfig.Builder, Void> parser = new ObjectParser<>(
NAME.getPreferredName(),
lenient,
RegressionConfig.Builder::new);
parser.declareString(RegressionConfig.Builder::setResultsField, RESULTS_FIELD);
parser.declareInt(RegressionConfig.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
return parser;
}
private static final ConstructingObjectParser<RegressionConfig, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0], (Integer)args[1]));
static {
PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
public static RegressionConfig fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null).build();
}
public static RegressionConfig fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
public static RegressionConfig fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null).build();
}
private final String resultsField;
@ -113,9 +107,7 @@ public class RegressionConfig implements InferenceConfig {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
if (numTopFeatureImportanceValues > 0) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
builder.endObject();
return builder;
}
@ -141,7 +133,36 @@ public class RegressionConfig implements InferenceConfig {
@Override
public Version getMinimalSupportedVersion() {
return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
return requestingImportance() ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String resultsField;
private Integer numTopFeatureImportanceValues;
Builder() {}
Builder(RegressionConfig config) {
this.resultsField = config.resultsField;
this.numTopFeatureImportanceValues = config.numTopFeatureImportanceValues;
}
public Builder setResultsField(String resultsField) {
this.resultsField = resultsField;
return this;
}
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
return this;
}
public RegressionConfig build() {
return new RegressionConfig(resultsField, numTopFeatureImportanceValues);
}
}
}
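A small sketch of how the strict and lenient entry points above might be driven from raw JSON. The XContent plumbing (JsonXContent, NamedXContentRegistry.EMPTY, DeprecationHandler) is standard 7.x usage and is assumed here rather than taken from this diff.

import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;

public class RegressionConfigParsingSketch {
    public static void main(String[] args) throws Exception {
        String json = "{\"results_field\":\"value\",\"num_top_feature_importance_values\":2}";
        try (XContentParser parser = JsonXContent.jsonXContent.createParser(
                 NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json)) {
            // The strict parser rejects unknown fields; fromXContentLenient would ignore them instead.
            RegressionConfig config = RegressionConfig.fromXContentStrict(parser);
            System.out.println(config.getResultsField());
        }
    }
}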

View File

@ -0,0 +1,178 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.DEFAULT_RESULTS_FIELD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.RESULTS_FIELD;
public class RegressionConfigUpdate implements InferenceConfigUpdate<RegressionConfig> {
public static final ParseField NAME = new ParseField("regression");
public static RegressionConfigUpdate EMPTY_PARAMS = new RegressionConfigUpdate(null, null);
public static RegressionConfigUpdate fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new RegressionConfigUpdate(resultsField, featureImportance);
}
public static RegressionConfigUpdate fromConfig(RegressionConfig config) {
return new RegressionConfigUpdate(config.getResultsField(), config.getNumTopFeatureImportanceValues());
}
private static final ObjectParser<RegressionConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
private static ObjectParser<RegressionConfigUpdate.Builder, Void> createParser(boolean lenient) {
ObjectParser<RegressionConfigUpdate.Builder, Void> parser = new ObjectParser<>(
NAME.getPreferredName(),
lenient,
RegressionConfigUpdate.Builder::new);
parser.declareString(RegressionConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
parser.declareInt(RegressionConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
return parser;
}
public static RegressionConfigUpdate fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null).build();
}
private final String resultsField;
private final Integer numTopFeatureImportanceValues;
public RegressionConfigUpdate(String resultsField, Integer numTopFeatureImportanceValues) {
this.resultsField = resultsField;
if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) {
throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() +
"] must be greater than or equal to 0");
}
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
}
public RegressionConfigUpdate(StreamInput in) throws IOException {
this.resultsField = in.readOptionalString();
this.numTopFeatureImportanceValues = in.readOptionalVInt();
}
public int getNumTopFeatureImportanceValues() {
return numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues;
}
public String getResultsField() {
return resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(resultsField);
out.writeOptionalVInt(numTopFeatureImportanceValues);
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (resultsField != null) {
builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
}
if (numTopFeatureImportanceValues != null) {
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RegressionConfigUpdate that = (RegressionConfigUpdate)o;
return Objects.equals(this.resultsField, that.resultsField)
&& Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
}
@Override
public int hashCode() {
return Objects.hash(resultsField, numTopFeatureImportanceValues);
}
@Override
public RegressionConfig apply(RegressionConfig originalConfig) {
if (isNoop(originalConfig)) {
return originalConfig;
}
RegressionConfig.Builder builder = new RegressionConfig.Builder(originalConfig);
if (resultsField != null) {
builder.setResultsField(resultsField);
}
if (numTopFeatureImportanceValues != null) {
builder.setNumTopFeatureImportanceValues(numTopFeatureImportanceValues);
}
return builder.build();
}
@Override
public InferenceConfig toConfig() {
return apply(RegressionConfig.EMPTY_PARAMS);
}
@Override
public boolean isSupported(InferenceConfig inferenceConfig) {
return inferenceConfig instanceof RegressionConfig;
}
boolean isNoop(RegressionConfig originalConfig) {
return (resultsField == null || originalConfig.getResultsField().equals(resultsField))
&& (numTopFeatureImportanceValues == null
|| originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues);
}
public static class Builder {
private String resultsField;
private Integer numTopFeatureImportanceValues;
public Builder setResultsField(String resultsField) {
this.resultsField = resultsField;
return this;
}
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
return this;
}
public RegressionConfigUpdate build() {
return new RegressionConfigUpdate(resultsField, numTopFeatureImportanceValues);
}
}
}
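Another sketch, this time for the regression update: isSupported() guards against applying an update to the wrong config type, and toConfig() is defined above as apply() over RegressionConfig.EMPTY_PARAMS.

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;

public class RegressionConfigUpdateSketch {
    public static void main(String[] args) {
        // Only the results field is overridden; feature importance is left untouched (null).
        RegressionConfigUpdate update = new RegressionConfigUpdate("my_value", null);

        InferenceConfig stored = RegressionConfig.EMPTY_PARAMS;
        if (update.isSupported(stored)) {
            // Equivalent to update.apply(RegressionConfig.EMPTY_PARAMS).
            InferenceConfig effective = update.toConfig();
            System.out.println(effective.getName()); // "regression"
        }
    }
}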

View File

@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
public interface StrictlyParsedInferenceConfig extends InferenceConfig {
}

View File

@ -41,4 +41,14 @@ public final class NamedXContentObjectHelper {
}
return builder;
}
public static XContentBuilder writeNamedObject(XContentBuilder builder,
ToXContent.Params params,
String namedObjectName,
NamedXContentObject namedObject) throws IOException {
builder.startObject(namedObjectName);
builder.field(namedObject.getName(), namedObject, params);
builder.endObject();
return builder;
}
}
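For reference, a sketch of the JSON shape the new helper produces when a trained model config serializes its inference_config. It assumes the x-pack core variant of the helper shown above and the standard XContent builder APIs; the field name "inference_config" matches the new ParseField on the trained model config.

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

public class WriteNamedObjectSketch {
    public static void main(String[] args) throws Exception {
        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        // Wraps the config under its own name, e.g. {"inference_config":{"regression":{...}}}
        NamedXContentObjectHelper.writeNamedObject(
            builder, ToXContent.EMPTY_PARAMS, "inference_config", RegressionConfig.EMPTY_PARAMS);
        builder.endObject();
        System.out.println(Strings.toString(builder));
    }
}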

View File

@ -2,7 +2,7 @@
"order" : 0,
"version" : ${xpack.ml.version.id},
"index_patterns" : [
".ml-inference-000001"
".ml-inference-000002"
],
"settings" : {
"index" : {
@ -67,6 +67,9 @@
},
"default_field_map": {
"enabled": false
},
"inference_config": {
"enabled": false
}
}
}

View File

@ -11,22 +11,12 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static org.elasticsearch.Version.getDeclaredVersions;
import static org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase.DEFAULT_BWC_VERSIONS;
public abstract class AbstractBWCSerializationTestCase<T extends Writeable & ToXContent> extends AbstractSerializingTestCase<T> {
private static final List<Version> ALL_VERSIONS = Collections.unmodifiableList(getDeclaredVersions(Version.class));
public static List<Version> getAllBWCVersions(Version version) {
return ALL_VERSIONS.stream().filter(v -> v.before(version) && version.isCompatible(v)).collect(Collectors.toList());
}
private static final List<Version> DEFAULT_BWC_VERSIONS = getAllBWCVersions(Version.CURRENT);
/**
* Returns the expected instance if serialized from the given version.
*/

View File

@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static org.elasticsearch.Version.getDeclaredVersions;
public abstract class AbstractBWCWireSerializationTestCase<T extends Writeable> extends AbstractWireSerializingTestCase<T> {
static final List<Version> ALL_VERSIONS = Collections.unmodifiableList(getDeclaredVersions(Version.class));
public static List<Version> getAllBWCVersions(Version version) {
return ALL_VERSIONS.stream().filter(v -> v.before(version) && version.isCompatible(v)).collect(Collectors.toList());
}
static final List<Version> DEFAULT_BWC_VERSIONS = getAllBWCVersions(Version.CURRENT);
/**
* Returns the expected instance if serialized from the given version.
*/
protected abstract T mutateInstanceForVersion(T instance, Version version);
/**
* The bwc versions to test serialization against
*/
protected List<Version> bwcVersions() {
return DEFAULT_BWC_VERSIONS;
}
/**
* Test serialization and deserialization of the test instance across versions
*/
public final void testBwcSerialization() throws IOException {
for (int runs = 0; runs < NUMBER_OF_TEST_RUNS; runs++) {
T testInstance = createTestInstance();
for (Version bwcVersion : bwcVersions()) {
assertBwcSerialization(testInstance, bwcVersion);
}
}
}
/**
* Assert that instances copied at a particular version are equal. The version is useful
 * for sanity checking the backwards compatibility of the wire format. It isn't a substitute for
 * real backwards compatibility tests, but it is much faster.
*/
protected final void assertBwcSerialization(T testInstance, Version version) throws IOException {
T deserializedInstance = copyWriteable(testInstance, getNamedWriteableRegistry(), instanceReader(), version);
assertOnBWCObject(deserializedInstance, mutateInstanceForVersion(testInstance, version), version);
}
/**
* @param bwcSerializedObject The object deserialized from the previous version
* @param testInstance The original test instance
* @param version The version which serialized
*/
protected void assertOnBWCObject(T bwcSerializedObject, T testInstance, Version version) {
assertNotSame(version.toString(), bwcSerializedObject, testInstance);
assertEquals(version.toString(), bwcSerializedObject, testInstance);
assertEquals(version.toString(), bwcSerializedObject.hashCode(), testInstance.hashCode());
}
}

View File

@ -5,14 +5,21 @@
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests;
import java.util.ArrayList;
import java.util.List;
@ -22,25 +29,27 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
public class InternalInferModelActionRequestTests extends AbstractWireSerializingTestCase<Request> {
public class InternalInferModelActionRequestTests extends AbstractBWCWireSerializationTestCase<Request> {
@Override
@SuppressWarnings("unchecked")
protected Request createTestInstance() {
return randomBoolean() ?
new Request(
randomAlphaOfLength(10),
Stream.generate(InternalInferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()),
randomInferenceConfig(),
randomInferenceConfigUpdate(),
randomBoolean()) :
new Request(
randomAlphaOfLength(10),
randomMap(),
randomInferenceConfig(),
randomInferenceConfigUpdate(),
randomBoolean());
}
private static InferenceConfig randomInferenceConfig() {
return randomFrom(RegressionConfigTests.randomRegressionConfig(), ClassificationConfigTests.randomClassificationConfig());
private static InferenceConfigUpdate randomInferenceConfigUpdate() {
return randomFrom(RegressionConfigUpdateTests.randomRegressionConfig(),
ClassificationConfigUpdateTests.randomClassificationConfig());
}
private static Map<String, Object> randomMap() {
@ -60,4 +69,26 @@ public class InternalInferModelActionRequestTests extends AbstractWireSerializin
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
}
@Override
@SuppressWarnings("unchecked")
protected Request mutateInstanceForVersion(Request instance, Version version) {
if (version.before(Version.V_7_8_0)) {
InferenceConfigUpdate update = null;
if (instance.getUpdate() instanceof ClassificationConfigUpdate) {
update = ClassificationConfigUpdate.fromConfig(
ClassificationConfigTests.mutateForVersion((ClassificationConfig) instance.getUpdate().toConfig(), version));
}
else if (instance.getUpdate() instanceof RegressionConfigUpdate) {
update = RegressionConfigUpdate.fromConfig(
RegressionConfigTests.mutateForVersion((RegressionConfig) instance.getUpdate().toConfig(), version));
}
else {
fail("unknown update type " + instance.getUpdate().getName());
}
return new Request(instance.getModelId(), instance.getObjectsToInfer(), update, instance.isPreviouslyLicensed());
}
return instance;
}
}

View File

@ -21,7 +21,9 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.license.License;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
@ -46,7 +48,7 @@ import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
public class TrainedModelConfigTests extends AbstractSerializingTestCase<TrainedModelConfig> {
public class TrainedModelConfigTests extends AbstractBWCSerializationTestCase<TrainedModelConfig> {
private boolean lenient;
@ -66,6 +68,8 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
License.OperationMode.ENTERPRISE.description(),
License.OperationMode.GOLD.description(),
License.OperationMode.BASIC.description()))
.setInferenceConfig(randomFrom(ClassificationConfigTests.randomClassificationConfig(),
RegressionConfigTests.randomRegressionConfig()))
.setTags(tags);
}
@ -143,7 +147,8 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))),
randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig()));
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
assertThat(reference.utf8ToString(), containsString("\"compressed_definition\""));
@ -182,7 +187,8 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))),
randomFrom(ClassificationConfigTests.randomClassificationConfig(), RegressionConfigTests.randomRegressionConfig()));
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();
@ -311,4 +317,12 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
.assertToXContentEquivalence(true)
.test();
}
@Override
protected TrainedModelConfig mutateInstanceForVersion(TrainedModelConfig instance, Version version) {
if (version.before(Version.V_7_8_0)) {
return new TrainedModelConfig.Builder(instance).setInferenceConfig(null).build();
}
return instance;
}
}

View File

@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTes
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
@ -61,7 +62,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
return false;
}
public static TrainedModelDefinition.Builder createRandomBuilder() {
public static TrainedModelDefinition.Builder createRandomBuilder(TargetType targetType) {
int numberOfProcessors = randomIntBetween(1, 10);
return new TrainedModelDefinition.Builder()
.setPreProcessors(
@ -71,7 +72,11 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
TargetMeanEncodingTests.createRandom()))
.limit(numberOfProcessors)
.collect(Collectors.toList()))
.setTrainedModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom()));
.setTrainedModel(randomFrom(TreeTests.createRandom(targetType), EnsembleTests.createRandom(targetType)));
}
public static TrainedModelDefinition.Builder createRandomBuilder() {
return createRandomBuilder(randomFrom(TargetType.values()));
}
private static final String ENSEMBLE_MODEL = "" +

View File

@ -5,19 +5,18 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
public class ClassificationConfigTests extends AbstractSerializingTestCase<ClassificationConfig> {
public class ClassificationConfigTests extends AbstractBWCSerializationTestCase<ClassificationConfig> {
private boolean lenient;
public static ClassificationConfig randomClassificationConfig() {
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10),
@ -26,23 +25,17 @@ public class ClassificationConfigTests extends AbstractSerializingTestCase<Class
);
}
public void testFromMap() {
ClassificationConfig expected = ClassificationConfig.EMPTY_PARAMS;
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));
expected = new ClassificationConfig(3, "foo", "bar", 2);
Map<String, Object> configMap = new HashMap<>();
configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo");
configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar");
configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2);
assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected));
public static ClassificationConfig mutateForVersion(ClassificationConfig instance, Version version) {
ClassificationConfig.Builder builder = new ClassificationConfig.Builder(instance);
if (version.before(Version.V_7_7_0)) {
builder.setNumTopFeatureImportanceValues(0);
}
return builder.build();
}
public void testFromMapWithUnknownField() {
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1)));
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
@ -57,6 +50,16 @@ public class ClassificationConfigTests extends AbstractSerializingTestCase<Class
@Override
protected ClassificationConfig doParseInstance(XContentParser parser) throws IOException {
return ClassificationConfig.fromXContent(parser);
return lenient ? ClassificationConfig.fromXContentLenient(parser) : ClassificationConfig.fromXContentStrict(parser);
}
@Override
protected boolean supportsUnknownFields() {
return lenient;
}
@Override
protected ClassificationConfig mutateInstanceForVersion(ClassificationConfig instance, Version version) {
return mutateForVersion(instance, version);
}
}

View File

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<ClassificationConfigUpdate> {
public static ClassificationConfigUpdate randomClassificationConfig() {
return new ClassificationConfigUpdate(randomBoolean() ? null : randomIntBetween(-1, 10),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomIntBetween(0, 10)
);
}
public void testFromMap() {
ClassificationConfigUpdate expected = new ClassificationConfigUpdate(null, null, null, null);
assertThat(ClassificationConfigUpdate.fromMap(Collections.emptyMap()), equalTo(expected));
expected = new ClassificationConfigUpdate(3, "foo", "bar", 2);
Map<String, Object> configMap = new HashMap<>();
configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo");
configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar");
configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2);
assertThat(ClassificationConfigUpdate.fromMap(configMap), equalTo(expected));
}
public void testFromMapWithUnknownField() {
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> ClassificationConfigUpdate.fromMap(Collections.singletonMap("some_key", 1)));
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
}
@Override
protected ClassificationConfigUpdate createTestInstance() {
return randomClassificationConfig();
}
@Override
protected Writeable.Reader<ClassificationConfigUpdate> instanceReader() {
return ClassificationConfigUpdate::new;
}
@Override
protected ClassificationConfigUpdate doParseInstance(XContentParser parser) throws IOException {
return ClassificationConfigUpdate.fromXContentStrict(parser);
}
@Override
protected ClassificationConfigUpdate mutateInstanceForVersion(ClassificationConfigUpdate instance, Version version) {
return instance;
}
}

View File

@ -5,37 +5,32 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.junit.Before;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
public class RegressionConfigTests extends AbstractSerializingTestCase<RegressionConfig> {
public class RegressionConfigTests extends AbstractBWCSerializationTestCase<RegressionConfig> {
private boolean lenient;
public static RegressionConfig randomRegressionConfig() {
return new RegressionConfig(randomBoolean() ? null : randomAlphaOfLength(10));
}
public void testFromMap() {
RegressionConfig expected = new RegressionConfig("foo", 3);
Map<String, Object> config = new HashMap<String, Object>(){{
put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo");
put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3);
}};
assertThat(RegressionConfig.fromMap(config), equalTo(expected));
public static RegressionConfig mutateForVersion(RegressionConfig instance, Version version) {
RegressionConfig.Builder builder = new RegressionConfig.Builder(instance);
if (version.before(Version.V_7_7_0)) {
builder.setNumTopFeatureImportanceValues(0);
}
return builder.build();
}
public void testFromMapWithUnknownField() {
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1)));
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
@ -50,6 +45,16 @@ public class RegressionConfigTests extends AbstractSerializingTestCase<Regressio
@Override
protected RegressionConfig doParseInstance(XContentParser parser) throws IOException {
return RegressionConfig.fromXContent(parser);
return lenient ? RegressionConfig.fromXContentLenient(parser) : RegressionConfig.fromXContentStrict(parser);
}
@Override
protected boolean supportsUnknownFields() {
return lenient;
}
@Override
protected RegressionConfig mutateInstanceForVersion(RegressionConfig instance, Version version) {
return mutateForVersion(instance, version);
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCase<RegressionConfigUpdate> {
public static RegressionConfigUpdate randomRegressionConfig() {
return new RegressionConfigUpdate(randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomIntBetween(0, 10));
}
public void testFromMap() {
RegressionConfigUpdate expected = new RegressionConfigUpdate("foo", 3);
Map<String, Object> config = new HashMap<String, Object>(){{
put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo");
put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3);
}};
assertThat(RegressionConfigUpdate.fromMap(config), equalTo(expected));
}
public void testFromMapWithUnknownField() {
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> RegressionConfigUpdate.fromMap(Collections.singletonMap("some_key", 1)));
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
}
@Override
protected RegressionConfigUpdate createTestInstance() {
return randomRegressionConfig();
}
@Override
protected Writeable.Reader<RegressionConfigUpdate> instanceReader() {
return RegressionConfigUpdate::new;
}
@Override
protected RegressionConfigUpdate doParseInstance(XContentParser parser) throws IOException {
return RegressionConfigUpdate.fromXContentStrict(parser);
}
@Override
protected RegressionConfigUpdate mutateInstanceForVersion(RegressionConfigUpdate instance, Version version) {
return instance;
}
}

View File

@ -65,6 +65,10 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
}
public static Ensemble createRandom() {
return createRandom(randomFrom(TargetType.values()));
}
public static Ensemble createRandom(TargetType targetType) {
int numberOfFeatures = randomIntBetween(1, 10);
List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList());
int numberOfModels = randomIntBetween(1, 10);
@ -74,7 +78,6 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
double[] weights = randomBoolean() ?
null :
Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).mapToDouble(Double::valueOf).toArray();
TargetType targetType = randomFrom(TargetType.values());
List<String> categoryLabels = null;
if (randomBoolean() && targetType == TargetType.CLASSIFICATION) {
categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));

View File

@ -64,16 +64,20 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
return createRandom();
}
public static Tree createRandom() {
public static Tree createRandom(TargetType targetType) {
int numberOfFeatures = randomIntBetween(1, 10);
List<String> featureNames = new ArrayList<>();
for (int i = 0; i < numberOfFeatures; i++) {
featureNames.add(randomAlphaOfLength(10));
}
return buildRandomTree(featureNames, 6);
return buildRandomTree(targetType, featureNames, 6);
}
public static Tree buildRandomTree(List<String> featureNames, int depth) {
public static Tree createRandom() {
return createRandom(randomFrom(TargetType.values()));
}
public static Tree buildRandomTree(TargetType targetType, List<String> featureNames, int depth) {
Tree.Builder builder = Tree.builder();
int maxFeatureIndex = featureNames.size() - 1;
builder.setFeatureNames(featureNames);
@ -96,7 +100,6 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
}
childNodes = nextNodes;
}
TargetType targetType = randomFrom(TargetType.values());
List<String> categoryLabels = null;
if (randomBoolean() && targetType == TargetType.CLASSIFICATION) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
@ -105,6 +108,10 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
return builder.setTargetType(targetType).setClassificationLabels(categoryLabels).build();
}
public static Tree buildRandomTree(List<String> featureNames, int depth) {
return buildRandomTree(randomFrom(TargetType.values()), featureNames, depth);
}
@Override
protected Writeable.Reader<Tree> instanceReader() {
return Tree::new;

View File

@ -141,9 +141,11 @@ integTest.runner {
'ml/inference_crud/Test put ensemble with empty models',
'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index',
'ml/inference_crud/Test put model with empty input.field_names',
'ml/inference_crud/Test PUT model where target type and inference config mismatch',
'ml/inference_processor/Test create processor with missing mandatory fields',
'ml/inference_processor/Test create and delete pipeline with inference processor',
'ml/inference_processor/Test create processor with deprecated fields',
'ml/inference_processor/Test simulate',
'ml/inference_stats_crud/Test get stats given missing trained model',
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
'ml/jobs_crud/Test cannot create job with existing categorizer state document',

View File

@ -424,6 +424,7 @@ public class InferenceIngestIT extends ESRestTestCase {
private static final String REGRESSION_CONFIG = "{" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for regression\",\n" +
" \"inference_config\": {\"regression\": {}},\n" +
" \"definition\": " + REGRESSION_DEFINITION +
"}";
@ -564,6 +565,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
" \"default_field_map\": {\"col_1_alias\": \"col1\"},\n" +
" \"inference_config\": {\"classification\": {}},\n" +
" \"definition\": " + CLASSIFICATION_DEFINITION +
"}";

View File

@ -12,6 +12,7 @@ import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
@ -193,6 +194,7 @@ public class TrainedModelIT extends ESRestTestCase {
.setTrainedModel(buildRegression());
TrainedModelConfig.builder()
.setDefinition(definition)
.setInferenceConfig(new RegressionConfig())
.setModelId(modelId)
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3")))
.build().toXContent(builder, ToXContent.EMPTY_PARAMS);

View File

@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -48,11 +49,12 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
}
@Override
@SuppressWarnings("unchecked")
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
Response.Builder responseBuilder = Response.builder();
ActionListener<Model> getModelListener = ActionListener.wrap(
ActionListener<Model<? extends InferenceConfig>> getModelListener = ActionListener.wrap(
model -> {
TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor =
new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME),
@ -62,7 +64,9 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
ex -> true);
request.getObjectsToInfer().forEach(stringObjectMap ->
typedChainTaskExecutor.add(chainedTask ->
model.infer(stringObjectMap, request.getConfig(), chainedTask)));
// The InferenceConfigUpdate here is unchecked, initially.
// It gets checked when it is applied
model.infer(stringObjectMap, request.getUpdate(), chainedTask)));
typedChainTaskExecutor.execute(ActionListener.wrap(
inferenceResultsInterfaces ->

View File

@ -93,6 +93,22 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
request.getTrainedModelConfig().getModelId()));
return;
}
if (request.getTrainedModelConfig()
.getInferenceConfig()
.isTargetTypeSupported(request.getTrainedModelConfig()
.getModelDefinition()
.getTrainedModel()
.targetType()) == false) {
listener.onFailure(ExceptionsHelper.badRequestException(
"Model [{}] inference config type [{}] does not support definition target type [{}]",
request.getTrainedModelConfig().getModelId(),
request.getTrainedModelConfig().getInferenceConfig().getName(),
request.getTrainedModelConfig()
.getModelDefinition()
.getTrainedModel()
.targetType()));
return;
}
Version minCompatibilityVersion = request.getTrainedModelConfig()
.getModelDefinition()
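To make the new validation concrete, here is a hedged sketch of the rule it enforces: the inference config's type must accept the model definition's target type, via the isTargetTypeSupported method referenced above (its implementation is not part of this excerpt).

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

public class TargetTypeValidationSketch {
    public static void main(String[] args) {
        InferenceConfig config = ClassificationConfig.builder().build();
        // Pairing a classification config with a regression definition is the mismatch
        // the PUT action above rejects with a bad request; expected to print false.
        System.out.println(config.isTargetTypeSupported(TargetType.REGRESSION));
    }
}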

View File

@ -25,6 +25,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStat
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.security.user.XPackUser;
@ -230,9 +234,34 @@ public class AnalyticsResultProcessor {
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setDefaultFieldMap(defaultFieldMapping)
.setInferenceConfig(buildInferenceConfig(definition.getTrainedModel().targetType()))
.build();
}
private InferenceConfig buildInferenceConfig(TargetType targetType) {
switch (targetType) {
case CLASSIFICATION:
assert analytics.getAnalysis() instanceof Classification;
Classification classification = ((Classification)analytics.getAnalysis());
return ClassificationConfig.builder()
.setNumTopClasses(classification.getNumTopClasses())
.setNumTopFeatureImportanceValues(classification.getBoostedTreeParams().getNumTopFeatureImportanceValues())
.build();
case REGRESSION:
assert analytics.getAnalysis() instanceof Regression;
Regression regression = ((Regression)analytics.getAnalysis());
return RegressionConfig.builder()
.setNumTopFeatureImportanceValues(regression.getBoostedTreeParams().getNumTopFeatureImportanceValues())
.build();
default:
setAndReportFailure(ExceptionsHelper.serverError(
"process created a model with an unsupported target type [{}]",
null,
targetType));
return null;
}
}
private String getDependentVariable() {
if (analytics.getAnalysis() instanceof Classification) {
return ((Classification)analytics.getAnalysis()).getDependentVariable();

View File

@ -31,8 +31,11 @@ import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
@ -71,7 +74,7 @@ public class InferenceProcessor extends AbstractProcessor {
private final String modelId;
private final String targetField;
private final InferenceConfig inferenceConfig;
private final InferenceConfigUpdate<? extends InferenceConfig> inferenceConfig;
private final Map<String, String> fieldMap;
private final InferenceAuditor auditor;
private volatile boolean previouslyLicensed;
@ -82,7 +85,7 @@ public class InferenceProcessor extends AbstractProcessor {
String tag,
String targetField,
String modelId,
InferenceConfig inferenceConfig,
InferenceConfigUpdate<? extends InferenceConfig> inferenceConfig,
Map<String, String> fieldMap) {
super(tag);
this.client = ExceptionsHelper.requireNonNull(client, "client");
@ -245,7 +248,8 @@ public class InferenceProcessor extends AbstractProcessor {
LoggingDeprecationHandler.INSTANCE.usedDeprecatedName(null, () -> null, FIELD_MAPPINGS, FIELD_MAP);
}
}
InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));
InferenceConfigUpdate<? extends InferenceConfig> inferenceConfig =
inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));
return new InferenceProcessor(client,
auditor,
@ -262,7 +266,7 @@ public class InferenceProcessor extends AbstractProcessor {
this.maxIngestProcessors = maxIngestProcessors;
}
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
InferenceConfigUpdate<? extends InferenceConfig> inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
if (inferenceConfig.size() != 1) {
throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
@ -279,12 +283,12 @@ public class InferenceProcessor extends AbstractProcessor {
if (inferenceConfig.containsKey(ClassificationConfig.NAME.getPreferredName())) {
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
ClassificationConfig config = ClassificationConfig.fromMap(valueMap);
ClassificationConfigUpdate config = ClassificationConfigUpdate.fromMap(valueMap);
checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
return config;
} else if (inferenceConfig.containsKey(RegressionConfig.NAME.getPreferredName())) {
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
RegressionConfig config = RegressionConfig.fromMap(valueMap);
RegressionConfigUpdate config = RegressionConfigUpdate.fromMap(valueMap);
checkFieldUniqueness(config.getResultsField());
return config;
} else {
@ -298,6 +302,9 @@ public class InferenceProcessor extends AbstractProcessor {
Set<String> duplicatedFieldNames = new HashSet<>();
Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
for(String fieldName : fieldNames) {
if (fieldName == null) {
continue;
}
if (currentFieldNames.contains(fieldName)) {
duplicatedFieldNames.add(fieldName);
} else {
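With the processor now carrying an `InferenceConfigUpdate`, its `inference_config` option is parsed by dispatching on the single key of the supplied map ("classification" or "regression"), and the configured results fields are checked for clashes, skipping null entries as the added guard above does. The sketch below shows just the single-key dispatch under simplified assumptions; the string results and the error message for an unknown type are hypothetical stand-ins for the real `*ConfigUpdate.fromMap` calls.

```java
import java.util.Map;

// Hypothetical stand-in for InferenceProcessor.Factory#inferenceConfigFromMap; the real code
// builds ClassificationConfigUpdate / RegressionConfigUpdate objects from the nested map.
public class InferenceConfigFromMapSketch {

    static String inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
        if (inferenceConfig.size() != 1) {
            throw new IllegalArgumentException(
                "[inference_config] must be an object with one inference type mapped to an object.");
        }
        String type = inferenceConfig.keySet().iterator().next();
        Object valueMap = inferenceConfig.get(type);
        switch (type) {
            case "classification":
                return "ClassificationConfigUpdate.fromMap(" + valueMap + ")";
            case "regression":
                return "RegressionConfigUpdate.fromMap(" + valueMap + ")";
            default:
                // Hypothetical message for this sketch only.
                throw new IllegalArgumentException("unsupported inference type [" + type + "]");
        }
    }

    public static void main(String[] args) {
        System.out.println(inferenceConfigFromMap(Map.of("regression", Map.of("results_field", "value"))));
    }
}
```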

View File

@ -10,6 +10,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
@ -24,21 +25,24 @@ import java.util.Set;
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
public class LocalModel implements Model {
public class LocalModel<T extends InferenceConfig> implements Model<T> {
private final TrainedModelDefinition trainedModelDefinition;
private final String modelId;
private final Set<String> fieldNames;
private final Map<String, String> defaultFieldMap;
private final T inferenceConfig;
public LocalModel(String modelId,
TrainedModelDefinition trainedModelDefinition,
TrainedModelInput input,
Map<String, String> defaultFieldMap) {
Map<String, String> defaultFieldMap,
T modelInferenceConfig) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
this.fieldNames = new HashSet<>(input.getFieldNames());
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
this.inferenceConfig = modelInferenceConfig;
}
long ramBytesUsed() {
@ -65,7 +69,15 @@ public class LocalModel implements Model {
}
@Override
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
public void infer(Map<String, Object> fields, InferenceConfigUpdate<T> update, ActionListener<InferenceResults> listener) {
if (update.isSupported(this.inferenceConfig) == false) {
listener.onFailure(ExceptionsHelper.badRequestException(
"Model [{}] has inference config of type [{}] which is not supported by inference request of type [{}]",
this.modelId,
this.inferenceConfig.getName(),
update.getName()));
return;
}
try {
Model.mapFieldsIfNecessary(fields, defaultFieldMap);
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
@ -73,7 +85,7 @@ public class LocalModel implements Model {
return;
}
listener.onResponse(trainedModelDefinition.infer(fields, config));
listener.onResponse(trainedModelDefinition.infer(fields, update.apply(inferenceConfig)));
} catch (Exception e) {
listener.onFailure(e);
}
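`LocalModel` now stores the model's own inference config and receives a per-request `InferenceConfigUpdate`: an update of the wrong type is rejected up front, and a compatible one is applied over the stored defaults before the definition runs inference (the `isSupported`/`apply` calls in the diff above). Here is a minimal, self-contained sketch of that check-then-apply flow; the two classes are drastically simplified stand-ins, not the real Elasticsearch types.

```java
// Hypothetical, heavily simplified stand-ins for InferenceConfig / InferenceConfigUpdate.
public class UpdateApplySketch {

    static class RegressionConfig {
        final String resultsField;
        RegressionConfig(String resultsField) { this.resultsField = resultsField; }
    }

    static class RegressionConfigUpdate {
        final String resultsField; // null means "keep the model's default"

        RegressionConfigUpdate(String resultsField) { this.resultsField = resultsField; }

        boolean isSupported(Object config) {
            return config instanceof RegressionConfig;
        }

        RegressionConfig apply(RegressionConfig defaults) {
            // Request-time settings win; anything left null falls back to the stored default.
            return new RegressionConfig(resultsField != null ? resultsField : defaults.resultsField);
        }
    }

    static RegressionConfig effectiveConfig(Object storedConfig, RegressionConfigUpdate update) {
        if (update.isSupported(storedConfig) == false) {
            throw new IllegalArgumentException("inference config type is not supported by the request update");
        }
        return update.apply((RegressionConfig) storedConfig);
    }

    public static void main(String[] args) {
        RegressionConfig stored = new RegressionConfig("my_regression");
        System.out.println(effectiveConfig(stored, new RegressionConfigUpdate(null)).resultsField);    // my_regression
        System.out.println(effectiveConfig(stored, new RegressionConfigUpdate("value")).resultsField); // value
    }
}
```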

View File

@ -8,15 +8,16 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.util.Map;
public interface Model {
public interface Model<T extends InferenceConfig> {
String getResultsType();
void infer(Map<String, Object> fields, InferenceConfig inferenceConfig, ActionListener<InferenceResults> listener);
void infer(Map<String, Object> fields, InferenceConfigUpdate<T> inferenceConfig, ActionListener<InferenceResults> listener);
String getModelId();

View File

@ -27,6 +27,11 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -78,9 +83,9 @@ public class ModelLoadingService implements ClusterStateListener {
Setting.Property.NodeScope);
private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
private final Cache<String, LocalModel> localModelCache;
private final Cache<String, LocalModel<? extends InferenceConfig>> localModelCache;
private final Set<String> referencedModels = new HashSet<>();
private final Map<String, Queue<ActionListener<Model>>> loadingListeners = new HashMap<>();
private final Map<String, Queue<ActionListener<Model<? extends InferenceConfig>>>> loadingListeners = new HashMap<>();
private final TrainedModelProvider provider;
private final Set<String> shouldNotAudit;
private final ThreadPool threadPool;
@ -100,7 +105,7 @@ public class ModelLoadingService implements ClusterStateListener {
this.auditor = auditor;
this.shouldNotAudit = new HashSet<>();
this.namedXContentRegistry = namedXContentRegistry;
this.localModelCache = CacheBuilder.<String, LocalModel>builder()
this.localModelCache = CacheBuilder.<String, LocalModel<? extends InferenceConfig>>builder()
.setMaximumWeight(this.maxCacheSize.getBytes())
.weigher((id, localModel) -> localModel.ramBytesUsed())
.removalListener(this::cacheEvictionListener)
@ -126,8 +131,8 @@ public class ModelLoadingService implements ClusterStateListener {
* @param modelId the model to get
* @param modelActionListener the listener to alert when the model has been retrieved.
*/
public void getModel(String modelId, ActionListener<Model> modelActionListener) {
LocalModel cachedModel = localModelCache.get(modelId);
public void getModel(String modelId, ActionListener<Model<? extends InferenceConfig>> modelActionListener) {
LocalModel<? extends InferenceConfig> cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
modelActionListener.onResponse(cachedModel);
logger.trace("[{}] loaded from cache", modelId);
@ -138,12 +143,18 @@ public class ModelLoadingService implements ClusterStateListener {
// by a simulated pipeline
logger.trace("[{}] not actively loading, eager loading without cache", modelId);
provider.getTrainedModel(modelId, true, ActionListener.wrap(
trainedModelConfig ->
modelActionListener.onResponse(new LocalModel(
trainedModelConfig -> {
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry);
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) :
trainedModelConfig.getInferenceConfig();
modelActionListener.onResponse(new LocalModel<>(
trainedModelConfig.getModelId(),
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
trainedModelConfig.getModelDefinition(),
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap())),
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig));
},
modelActionListener::onFailure
));
} else {
@ -156,9 +167,9 @@ public class ModelLoadingService implements ClusterStateListener {
* Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded
* Returns false if the model is not loaded or actively being loaded
*/
private boolean loadModelIfNecessary(String modelId, ActionListener<Model> modelActionListener) {
private boolean loadModelIfNecessary(String modelId, ActionListener<Model<? extends InferenceConfig>> modelActionListener) {
synchronized (loadingListeners) {
Model cachedModel = localModelCache.get(modelId);
Model<? extends InferenceConfig> cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
modelActionListener.onResponse(cachedModel);
return true;
@ -197,12 +208,17 @@ public class ModelLoadingService implements ClusterStateListener {
}
private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) throws IOException {
Queue<ActionListener<Model>> listeners;
LocalModel loadedModel = new LocalModel(
Queue<ActionListener<Model<? extends InferenceConfig>>> listeners;
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry);
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) :
trainedModelConfig.getInferenceConfig();
LocalModel<? extends InferenceConfig> loadedModel = new LocalModel<>(
trainedModelConfig.getModelId(),
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
trainedModelConfig.getModelDefinition(),
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap());
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig);
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such
@ -213,13 +229,13 @@ public class ModelLoadingService implements ClusterStateListener {
localModelCache.put(modelId, loadedModel);
shouldNotAudit.remove(modelId);
} // synchronized (loadingListeners)
for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
for (ActionListener<Model<? extends InferenceConfig>> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
listener.onResponse(loadedModel);
}
}
private void handleLoadFailure(String modelId, Exception failure) {
Queue<ActionListener<Model>> listeners;
Queue<ActionListener<Model<? extends InferenceConfig>>> listeners;
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);
if (listeners == null) {
@ -228,12 +244,12 @@ public class ModelLoadingService implements ClusterStateListener {
} // synchronized (loadingListeners)
// If we failed to load and there were listeners present, that means that this model is referenced by a processor
// Alert the listeners to the failure
for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
for (ActionListener<Model<? extends InferenceConfig>> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
listener.onFailure(failure);
}
}
private void cacheEvictionListener(RemovalNotification<String, LocalModel> notification) {
private void cacheEvictionListener(RemovalNotification<String, LocalModel<? extends InferenceConfig>> notification) {
if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
String msg = new ParameterizedMessage(
"model cache entry evicted." +
@ -263,7 +279,7 @@ public class ModelLoadingService implements ClusterStateListener {
return;
}
// The listeners still waiting for a model and we are canceling the load?
List<Tuple<String, List<ActionListener<Model>>>> drainWithFailure = new ArrayList<>();
List<Tuple<String, List<ActionListener<Model<? extends InferenceConfig>>>>> drainWithFailure = new ArrayList<>();
Set<String> referencedModelsBeforeClusterState = null;
Set<String> loadingModelBeforeClusterState = null;
Set<String> removedModels = null;
@ -306,11 +322,11 @@ public class ModelLoadingService implements ClusterStateListener {
referencedModels);
}
}
for (Tuple<String, List<ActionListener<Model>>> modelAndListeners : drainWithFailure) {
for (Tuple<String, List<ActionListener<Model<? extends InferenceConfig>>>> modelAndListeners : drainWithFailure) {
final String msg = new ParameterizedMessage(
"Cancelling load of model [{}] as it is no longer referenced by a pipeline",
modelAndListeners.v1()).getFormat();
for (ActionListener<Model> listener : modelAndListeners.v2()) {
for (ActionListener<Model<? extends InferenceConfig>> listener : modelAndListeners.v2()) {
listener.onFailure(new ElasticsearchException(msg));
}
}
@ -379,4 +395,14 @@ public class ModelLoadingService implements ClusterStateListener {
return allReferencedModelKeys;
}
private static InferenceConfig inferenceConfigFromTargetType(TargetType targetType) {
switch(targetType) {
case REGRESSION:
return RegressionConfig.EMPTY_PARAMS;
case CLASSIFICATION:
return ClassificationConfig.EMPTY_PARAMS;
default:
throw ExceptionsHelper.badRequestException("unsupported target type [{}]", targetType);
}
}
}
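The `inferenceConfigFromTargetType` helper above keeps older stored models working: if a `TrainedModelConfig` was indexed before it had an `inference_config`, the loading service derives an empty default from the definition's target type instead. A small stand-alone sketch of that fallback, with strings standing in for the real `RegressionConfig.EMPTY_PARAMS` / `ClassificationConfig.EMPTY_PARAMS`:

```java
// Simplified stand-in for ModelLoadingService#inferenceConfigFromTargetType and the null-check fallback.
public class InferenceConfigFallbackSketch {

    enum TargetType { REGRESSION, CLASSIFICATION }

    static String inferenceConfigFromTargetType(TargetType targetType) {
        switch (targetType) {
            case REGRESSION:
                return "RegressionConfig.EMPTY_PARAMS";
            case CLASSIFICATION:
                return "ClassificationConfig.EMPTY_PARAMS";
            default:
                throw new IllegalArgumentException("unsupported target type [" + targetType + "]");
        }
    }

    static String effectiveConfig(String storedInferenceConfig, TargetType definitionTargetType) {
        // Older configs were stored without inference_config; derive an empty default from the definition.
        return storedInferenceConfig == null
            ? inferenceConfigFromTargetType(definitionTargetType)
            : storedInferenceConfig;
    }

    public static void main(String[] args) {
        System.out.println(effectiveConfig(null, TargetType.REGRESSION));            // falls back to the empty default
        System.out.println(effectiveConfig("stored config", TargetType.REGRESSION)); // stored config wins
    }
}
```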

View File

@ -45,6 +45,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
@ -694,7 +695,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
RegressionConfig.EMPTY_PARAMS,
RegressionConfigUpdate.EMPTY_PARAMS,
false
), inferModelSuccess);
InternalInferModelAction.Response response = inferModelSuccess.actionGet();
@ -711,7 +712,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
RegressionConfig.EMPTY_PARAMS,
RegressionConfigUpdate.EMPTY_PARAMS,
false
)).actionGet();
});
@ -724,7 +725,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
RegressionConfig.EMPTY_PARAMS,
RegressionConfigUpdate.EMPTY_PARAMS,
true
), inferModelSuccess);
response = inferModelSuccess.actionGet();
@ -740,7 +741,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
client().execute(InternalInferModelAction.INSTANCE, new InternalInferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
RegressionConfig.EMPTY_PARAMS,
RegressionConfigUpdate.EMPTY_PARAMS,
false
), listener);
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
@ -760,6 +761,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
.setModelId(modelId)
.setDescription("test model for classification")
.setInput(new TrainedModelInput(Arrays.asList("feature1")))
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
.build();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
}

View File

@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
@ -168,7 +169,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
extractedFieldList.add(new DocValueField("foo", Collections.emptySet()));
extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet())));
extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);
@ -190,6 +192,11 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar.keyword", "baz")));
assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
if (targetType.equals(TargetType.CLASSIFICATION)) {
assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification"));
} else {
assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression"));
}
Map<String, Object> metadata = storedModel.getMetadata();
assertThat(metadata.size(), equalTo(1));
assertThat(metadata, hasKey("analytics_config"));
@ -213,7 +220,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
return null;
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();

View File

@ -14,7 +14,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.junit.Before;
@ -53,7 +55,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
targetField,
"classification_model",
ClassificationConfig.EMPTY_PARAMS,
ClassificationConfigUpdate.EMPTY_PARAMS,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
@ -75,13 +77,14 @@ public class InferenceProcessorTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void testMutateDocumentClassificationTopNClasses() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null);
ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, null);
ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, null);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"classification_model",
classificationConfig,
classificationConfigUpdate,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
@ -105,12 +108,13 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentClassificationFeatureInfluence() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2);
ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, 2);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"classification_model",
classificationConfig,
classificationConfigUpdate,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
@ -145,12 +149,13 @@ public class InferenceProcessorTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void testMutateDocumentClassificationTopNClassesWithSpecificField() {
ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops");
ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, "result", "tops", null);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"classification_model",
classificationConfig,
classificationConfigUpdate,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
@ -174,12 +179,13 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentRegression() {
RegressionConfig regressionConfig = new RegressionConfig("foo");
RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", null);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"regression_model",
regressionConfig,
regressionConfigUpdate,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
@ -196,12 +202,13 @@ public class InferenceProcessorTests extends ESTestCase {
public void testMutateDocumentRegressionWithTopFetures() {
RegressionConfig regressionConfig = new RegressionConfig("foo", 2);
RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", 2);
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml.my_processor",
"regression_model",
regressionConfig,
regressionConfigUpdate,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
@ -233,7 +240,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
"my_field",
modelId,
new ClassificationConfig(topNClasses, null, null),
new ClassificationConfigUpdate(topNClasses, null, null, null),
Collections.emptyMap());
Map<String, Object> source = new HashMap<String, Object>(){{
@ -262,7 +269,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
"my_field",
modelId,
new ClassificationConfig(topNClasses, null, null),
new ClassificationConfigUpdate(topNClasses, null, null, null),
fieldMapping);
Map<String, Object> source = new HashMap<String, Object>(5){{
@ -298,7 +305,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
"my_field",
modelId,
new ClassificationConfig(topNClasses, null, null),
new ClassificationConfigUpdate(topNClasses, null, null, null),
fieldMapping);
Map<String, Object> source = new HashMap<String, Object>(5){{
@ -326,7 +333,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
targetField,
"regression_model",
RegressionConfig.EMPTY_PARAMS,
RegressionConfigUpdate.EMPTY_PARAMS,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
@ -369,7 +376,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
"ml",
"regression_model",
RegressionConfig.EMPTY_PARAMS,
RegressionConfigUpdate.EMPTY_PARAMS,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();

View File

@ -13,8 +13,11 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
@ -31,7 +34,6 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo;
@ -48,22 +50,23 @@ public class LocalModelTests extends ESTestCase {
.setTrainedModel(buildClassification(false))
.build();
Model model = new LocalModel(modelId,
Model<ClassificationConfig> model = new LocalModel<>(modelId,
definition,
new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"));
Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
put("field.bar", 0.5);
put("categorical", "dog");
}};
SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0));
SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), is("0"));
ClassificationInferenceResults classificationResult =
(ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));
(ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null));
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0"));
@ -72,22 +75,29 @@ public class LocalModelTests extends ESTestCase {
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(true))
.build();
model = new LocalModel(modelId,
model = new LocalModel<>(modelId,
definition,
new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"));
result = getSingleValue(model, fields, new ClassificationConfig(0));
Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS);
result = getSingleValue(model, fields, new ClassificationConfigUpdate(0, null, null, null));
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), equalTo("not_to_be"));
classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));
classificationResult = (ClassificationInferenceResults)getSingleValue(model,
fields,
new ClassificationConfigUpdate(1, null, null, null));
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(2));
classificationResult = (ClassificationInferenceResults)getSingleValue(model,
fields,
new ClassificationConfigUpdate(2, null, null, null));
assertThat(classificationResult.getTopClasses(), hasSize(2));
classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(-1));
classificationResult = (ClassificationInferenceResults)getSingleValue(model,
fields,
new ClassificationConfigUpdate(-1, null, null, null));
assertThat(classificationResult.getTopClasses(), hasSize(2));
}
@ -97,10 +107,11 @@ public class LocalModelTests extends ESTestCase {
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildRegression())
.build();
Model model = new LocalModel("regression_model",
Model<RegressionConfig> model = new LocalModel<>("regression_model",
trainedModelDefinition,
new TrainedModelInput(inputFields),
Collections.singletonMap("bar", "bar.keyword"));
Collections.singletonMap("bar", "bar.keyword"),
RegressionConfig.EMPTY_PARAMS);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("foo", 1.0);
@ -108,14 +119,8 @@ public class LocalModelTests extends ESTestCase {
put("categorical", "dog");
}};
SingleValueInferenceResults results = getSingleValue(model, fields, RegressionConfig.EMPTY_PARAMS);
SingleValueInferenceResults results = getSingleValue(model, fields, RegressionConfigUpdate.EMPTY_PARAMS);
assertThat(results.value(), equalTo(1.3));
PlainActionFuture<InferenceResults> failedFuture = new PlainActionFuture<>();
model.infer(fields, new ClassificationConfig(2), failedFuture);
ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get);
assertThat(ex.getCause().getMessage(),
equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]"));
}
public void testAllFieldsMissing() throws Exception {
@ -124,7 +129,12 @@ public class LocalModelTests extends ESTestCase {
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildRegression())
.build();
Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields), null);
Model<RegressionConfig> model = new LocalModel<>(
"regression_model",
trainedModelDefinition,
new TrainedModelInput(inputFields),
null,
RegressionConfig.EMPTY_PARAMS);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("something", 1.0);
@ -132,18 +142,21 @@ public class LocalModelTests extends ESTestCase {
put("baz", "dog");
}};
WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfig.EMPTY_PARAMS);
WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfigUpdate.EMPTY_PARAMS);
assertThat(results.getWarning(),
equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model")));
}
private static SingleValueInferenceResults getSingleValue(Model model,
Map<String, Object> fields,
InferenceConfig config) throws Exception {
private static <T extends InferenceConfig> SingleValueInferenceResults getSingleValue(Model<T> model,
Map<String, Object> fields,
InferenceConfigUpdate<T> config)
throws Exception {
return (SingleValueInferenceResults)getInferenceResult(model, fields, config);
}
private static InferenceResults getInferenceResult(Model model, Map<String, Object> fields, InferenceConfig config) throws Exception {
private static <T extends InferenceConfig> InferenceResults getInferenceResult(Model<T> model,
Map<String, Object> fields,
InferenceConfigUpdate<T> config) throws Exception {
PlainActionFuture<InferenceResults> future = new PlainActionFuture<>();
model.infer(fields, config, future);
return future.get();

View File

@ -36,6 +36,8 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -111,7 +113,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
String[] modelIds = new String[]{model1, model2, model3};
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@ -124,7 +126,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@ -164,7 +166,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
for(int i = 0; i < 10; i++) {
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
String model = modelIds[i%2];
PlainActionFuture<Model> future = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@ -176,7 +178,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Load model 3, should invalidate 1
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future3 = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future3 = new PlainActionFuture<>();
modelLoadingService.getModel(model3, future3);
assertThat(future3.get(), is(not(nullValue())));
}
@ -184,7 +186,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Load model 1, should invalidate 2
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future1 = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future1 = new PlainActionFuture<>();
modelLoadingService.getModel(model1, future1);
assertThat(future1.get(), is(not(nullValue())));
}
@ -192,7 +194,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Load model 2, should invalidate 3
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future2 = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future2 = new PlainActionFuture<>();
modelLoadingService.getModel(model2, future2);
assertThat(future2.get(), is(not(nullValue())));
}
@ -204,7 +206,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@ -230,7 +232,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
modelLoadingService.clusterChanged(ingestChangedEvent(false, model1));
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
modelLoadingService.getModel(model1, future);
assertThat(future.get(), is(not(nullValue())));
}
@ -250,7 +252,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
Settings.EMPTY);
modelLoadingService.clusterChanged(ingestChangedEvent(model));
PlainActionFuture<Model> future = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
try {
@ -274,7 +276,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
NamedXContentRegistry.EMPTY,
Settings.EMPTY);
PlainActionFuture<Model> future = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
try {
future.get();
@ -296,7 +298,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
Settings.EMPTY);
for(int i = 0; i < 3; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
PlainActionFuture<Model<? extends InferenceConfig>> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@ -310,6 +312,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
when(definition.ramBytesUsed()).thenReturn(size);
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getModelDefinition()).thenReturn(definition);
when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS);
when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz")));
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")

View File

@ -20,8 +20,8 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
@ -146,20 +146,20 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
// Test regression
InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId1,
toInfer,
RegressionConfig.EMPTY_PARAMS,
RegressionConfigUpdate.EMPTY_PARAMS,
true);
InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
contains(1.3, 1.25));
request = new InternalInferModelAction.Request(modelId1, toInfer2, RegressionConfig.EMPTY_PARAMS, true);
request = new InternalInferModelAction.Request(modelId1, toInfer2, RegressionConfigUpdate.EMPTY_PARAMS, true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
contains(1.65, 1.55));
// Test classification
request = new InternalInferModelAction.Request(modelId2, toInfer, ClassificationConfig.EMPTY_PARAMS, true);
request = new InternalInferModelAction.Request(modelId2, toInfer, ClassificationConfigUpdate.EMPTY_PARAMS, true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults()
.stream()
@ -168,7 +168,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
contains("not_to_be", "to_be"));
// Get top classes
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2, null, null), true);
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
ClassificationInferenceResults classificationInferenceResults =
@ -187,7 +187,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
// Test that top classes restrict the number returned
request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1, null, null), true);
request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfigUpdate(1, null, null, null), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
@ -262,7 +262,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
// Test regression
InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId,
toInfer,
ClassificationConfig.EMPTY_PARAMS,
ClassificationConfigUpdate.EMPTY_PARAMS,
true);
InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults()
@ -271,7 +271,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
.collect(Collectors.toList()),
contains("option_0", "option_2"));
request = new InternalInferModelAction.Request(modelId, toInfer2, ClassificationConfig.EMPTY_PARAMS, true);
request = new InternalInferModelAction.Request(modelId, toInfer2, ClassificationConfigUpdate.EMPTY_PARAMS, true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults()
.stream()
@ -281,7 +281,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
// Get top classes
request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfig(3, null, null), true);
request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
ClassificationInferenceResults classificationInferenceResults =
@ -303,7 +303,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
InternalInferModelAction.Request request = new InternalInferModelAction.Request(
model,
Collections.emptyList(),
RegressionConfig.EMPTY_PARAMS,
RegressionConfigUpdate.EMPTY_PARAMS,
true);
try {
client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
@ -344,7 +344,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
InternalInferModelAction.Request request = new InternalInferModelAction.Request(
modelId,
toInferMissingField,
RegressionConfig.EMPTY_PARAMS,
RegressionConfigUpdate.EMPTY_PARAMS,
true);
try {
InferenceResults result =

View File

@ -11,6 +11,7 @@ setup:
"description": "empty model for tests",
"tags": ["regression", "tag1"],
"input": {"field_names": ["field1", "field2"]},
"inference_config": {"regression": {}},
"definition": {
"preprocessors": [],
"trained_model": {
@ -35,6 +36,7 @@ setup:
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"tags": ["regression", "tag2"],
"inference_config": {"regression": {}},
"definition": {
"preprocessors": [],
"trained_model": {
@ -58,6 +60,7 @@ setup:
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"tags": ["classification", "tag2"],
"inference_config": {"classification": {}},
"definition": {
"preprocessors": [],
"trained_model": {
@ -83,6 +86,7 @@ setup:
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"tags": ["classification", "tag3"],
"inference_config": {"classification": {}},
"definition": {
"preprocessors": [],
"trained_model": {
@ -108,6 +112,7 @@ setup:
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"tags": ["classification", "tag3"],
"inference_config": {"classification": {}},
"definition": {
"preprocessors": [],
"trained_model": {
@ -343,6 +348,7 @@ setup:
"input": {
"field_names": "fieldy_mc_fieldname"
},
"inference_config": {"regression": {}},
"definition": {
"trained_model": {
"ensemble": {
@ -377,6 +383,7 @@ setup:
"input": {
"field_names": "fieldy_mc_fieldname"
},
"inference_config": {"regression": {}},
"definition": {
"trained_model": {
"ensemble": {
@ -397,6 +404,7 @@ setup:
"input": {
"field_names": "fieldy_mc_fieldname"
},
"inference_config": {"regression": {}},
"definition": {
"trained_model": {
"ensemble": {
@ -434,6 +442,7 @@ setup:
"input": {
"field_names": []
},
"inference_config": {"regression": {}},
"definition": {
"trained_model": {
"ensemble": {
@ -469,6 +478,7 @@ setup:
{
"description": "model for tests",
"input": {"field_names": ["field1", "field2"]},
"inference_config": {"regression": {}},
"definition": {
"preprocessors": [],
"trained_model": {
@ -510,3 +520,47 @@ setup:
- is_true: create_time
- is_true: version
- is_true: estimated_heap_memory_usage_bytes
---
"Test PUT model where target type and inference config mismatch":
- do:
catch: /Model \[my-regression-model\] inference config type \[classification\] does not support definition target type \[regression\]/
ml.put_trained_model:
model_id: my-regression-model
body: >
{
"description": "model for tests",
"input": {"field_names": ["field1", "field2"]},
"inference_config": {"classification": {}},
"definition": {
"preprocessors": [],
"trained_model": {
"ensemble": {
"target_type": "regression",
"trained_models": [
{
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
{"node_index": 1, "leaf_value": 0},
{"node_index": 2, "leaf_value": 1}
],
"target_type": "regression"
}
},
{
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
{"node_index": 1, "leaf_value": 0},
{"node_index": 2, "leaf_value": 1}
],
"target_type": "regression"
}
}
]
}
}
}
}

View File

@ -13,6 +13,7 @@ setup:
"description": "empty model for tests",
"tags": ["regression", "tag1"],
"input": {"field_names": ["field1", "field2"]},
"inference_config": { "regression": {"results_field": "my_regression"}},
"definition": {
"preprocessors": [],
"trained_model": {
@ -112,3 +113,42 @@ setup:
- 'Deprecated field [field_mappings] used, expected [field_map] instead'
ingest.delete_pipeline:
id: "regression-model-pipeline"
---
"Test simulate":
- do:
ingest.simulate:
body: >
{
"pipeline": {
"processors": [
{
"inference" : {
"model_id" : "a-perfect-regression-model",
"inference_config": {"regression": {}},
"target_field": "regression_field",
"field_map": {}
}
}
]},
"docs": [{"_source": {"field1": 1, "field2": 2}}]
}
- match: { docs.0.doc._source.regression_field.my_regression: 42.0 }
- do:
ingest.simulate:
body: >
{
"pipeline": {
"processors": [
{
"inference" : {
"model_id" : "a-perfect-regression-model",
"inference_config": {"regression": {"results_field": "value"}},
"target_field": "regression_field",
"field_map": {}
}
}
]},
"docs": [{"_source": {"field1": 1, "field2": 2}}]
}
- match: { docs.0.doc._source.regression_field.value: 42.0 }

View File

@ -6,12 +6,13 @@ setup:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
id: trained_model_config-a-unused-regression-model1-0
index: .ml-inference-000001
index: .ml-inference-000002
body: >
{
"model_id": "a-unused-regression-model1",
"created_by": "ml_tests",
"version": "8.0.0",
"inference_config": {"regression": {}},
"description": "empty model for tests",
"create_time": 0,
"doc_type": "trained_model_config"
@ -22,12 +23,13 @@ setup:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
id: trained_model_config-a-unused-regression-model-0
index: .ml-inference-000001
index: .ml-inference-000002
body: >
{
"model_id": "a-unused-regression-model",
"created_by": "ml_tests",
"version": "8.0.0",
"inference_config": {"regression": {}},
"description": "empty model for tests",
"create_time": 0,
"doc_type": "trained_model_config"
@ -37,12 +39,13 @@ setup:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
id: trained_model_config-a-used-regression-model-0
index: .ml-inference-000001
index: .ml-inference-000002
body: >
{
"model_id": "a-used-regression-model",
"created_by": "ml_tests",
"version": "8.0.0",
"inference_config": {"regression": {}},
"description": "empty model for tests",
"create_time": 0,
"doc_type": "trained_model_config"

View File

@ -13,6 +13,7 @@ setup:
{
"description": "empty model for tests",
"tags": ["regression", "tag1"],
"inference_config": {"regression":{}},
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
@ -36,6 +37,7 @@ setup:
body: >
{
"description": "empty model for tests",
"inference_config": {"regression":{}},
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],