[ML][Inference] Adding a warning_field for warning msgs. (#49838) (#50183)

This adds a new field for the inference processor.

`warning_field` is a place for us to write warnings provided from the inference call. When there are warnings we are not going to write an inference result. The goal of this is to indicate that the data provided was too poor or too different for the model to make an accurate prediction.

The user could optionally include the `warning_field`. When it is not provided, it is assumed no warnings were desired to be written.

The first of these warnings is when ALL of the input fields are missing. If none of the trained fields are present, we don't bother inferencing against the model and instead provide a warning stating that the fields were missing.

Also, this adds checks to not allow duplicated fields during processor creation.
This commit is contained in:
Benjamin Trent 2019-12-13 10:39:51 -05:00 committed by GitHub
parent 41736dd6c3
commit 4805d8ac7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 316 additions and 17 deletions

View File

@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Objects;
public class WarningInferenceResults implements InferenceResults {
public static final String NAME = "warning";
public static final ParseField WARNING = new ParseField("warning");
private final String warning;
public WarningInferenceResults(String warning) {
this.warning = warning;
}
public WarningInferenceResults(StreamInput in) throws IOException {
this.warning = in.readString();
}
public String getWarning() {
return warning;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(warning);
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
WarningInferenceResults that = (WarningInferenceResults) object;
return Objects.equals(warning, that.warning);
}
@Override
public int hashCode() {
return Objects.hash(warning);
}
@Override
public void writeResult(IngestDocument document, String parentResultField) {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(parentResultField, "resultField");
document.setFieldValue(parentResultField + "." + "warning", warning);
}
@Override
public String getWriteableName() {
return NAME;
}
}

View File

@ -92,6 +92,7 @@ public final class Messages {
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
"Getting model definition is not supported when getting more than one model";
public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created";

View File

@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.util.HashMap;
import static org.hamcrest.Matchers.equalTo;
public class WarningInferenceResultsTests extends AbstractWireSerializingTestCase<WarningInferenceResults> {
public static WarningInferenceResults createRandomResults() {
return new WarningInferenceResults(randomAlphaOfLength(10));
}
public void testWriteResults() {
WarningInferenceResults result = new WarningInferenceResults("foo");
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo"));
}
@Override
protected WarningInferenceResults createTestInstance() {
return createRandomResults();
}
@Override
protected Writeable.Reader<WarningInferenceResults> instanceReader() {
return WarningInferenceResults::new;
}
}

View File

@ -16,6 +16,8 @@ import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
@ -34,6 +36,8 @@ import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static java.util.stream.Collectors.toList;
public class AnalyticsResultProcessor {
private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);
@ -163,6 +167,10 @@ public class AnalyticsResultProcessor {
Instant createTime = Instant.now();
String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
TrainedModelDefinition definition = inferenceModel.build();
String dependentVariable = getDependentVariable();
List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
.filter(f -> f.equals(dependentVariable) == false)
.collect(toList());
return TrainedModelConfig.builder()
.setModelId(modelId)
.setCreatedBy("data-frame-analytics")
@ -175,11 +183,21 @@ public class AnalyticsResultProcessor {
.setEstimatedHeapMemory(definition.ramBytesUsed())
.setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
.setParsedDefinition(inferenceModel)
.setInput(new TrainedModelInput(fieldNames))
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.build();
}
private String getDependentVariable() {
if (analytics.getAnalysis() instanceof Classification) {
return ((Classification)analytics.getAnalysis()).getDependentVariable();
}
if (analytics.getAnalysis() instanceof Regression) {
return ((Regression)analytics.getAnalysis()).getDependentVariable();
}
return null;
}
private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
CountDownLatch latch = new CountDownLatch(1);
ActionListener<Boolean> storeListener = ActionListener.wrap(

View File

@ -28,6 +28,8 @@ import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
@ -37,7 +39,9 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@ -146,7 +150,12 @@ public class InferenceProcessor extends AbstractProcessor {
if (response.getInferenceResults().isEmpty()) {
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
}
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
InferenceResults inferenceResults = response.getInferenceResults().get(0);
if (inferenceResults instanceof WarningInferenceResults) {
inferenceResults.writeResult(ingestDocument, this.targetField);
} else {
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
}
ingestDocument.setFieldValue(targetField + "." + MODEL_ID, modelId);
}
@ -164,6 +173,10 @@ public class InferenceProcessor extends AbstractProcessor {
private static final Logger logger = LogManager.getLogger(Factory.class);
private static final Set<String> RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList(
WarningInferenceResults.WARNING.getPreferredName(),
MODEL_ID));
private final Client client;
private final IngestService ingestService;
private final InferenceAuditor auditor;
@ -235,6 +248,7 @@ public class InferenceProcessor extends AbstractProcessor {
String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField);
Map<String, String> fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));
return new InferenceProcessor(client,
auditor,
tag,
@ -252,7 +266,6 @@ public class InferenceProcessor extends AbstractProcessor {
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
if (inferenceConfig.size() != 1) {
throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
INFERENCE_CONFIG);
@ -268,10 +281,14 @@ public class InferenceProcessor extends AbstractProcessor {
if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
return ClassificationConfig.fromMap(valueMap);
ClassificationConfig config = ClassificationConfig.fromMap(valueMap);
checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
return config;
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
return RegressionConfig.fromMap(valueMap);
RegressionConfig config = RegressionConfig.fromMap(valueMap);
checkFieldUniqueness(config.getResultsField());
return config;
} else {
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
inferenceConfig.keySet(),
@ -279,6 +296,23 @@ public class InferenceProcessor extends AbstractProcessor {
}
}
private static void checkFieldUniqueness(String... fieldNames) {
Set<String> duplicatedFieldNames = new HashSet<>();
Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
for(String fieldName : fieldNames) {
if (currentFieldNames.contains(fieldName)) {
duplicatedFieldNames.add(fieldName);
} else {
currentFieldNames.add(fieldName);
}
}
if (duplicatedFieldNames.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Cannot create processor as configured." +
" More than one field is configured as {}",
duplicatedFieldNames);
}
}
void checkSupportedVersion(InferenceConfig config) {
if (config.getMinimalSupportedVersion().after(minNodeVersion)) {
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION,
@ -287,6 +321,5 @@ public class InferenceProcessor extends AbstractProcessor {
minNodeVersion));
}
}
}
}

View File

@ -6,23 +6,33 @@
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
public class LocalModel implements Model {
private final TrainedModelDefinition trainedModelDefinition;
private final String modelId;
private final Set<String> fieldNames;
public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) {
public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition, TrainedModelInput input) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
this.fieldNames = new HashSet<>(input.getFieldNames());
}
long ramBytesUsed() {
@ -51,6 +61,11 @@ public class LocalModel implements Model {
@Override
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
try {
if (Sets.haveEmptyIntersection(fieldNames, fields.keySet())) {
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
return;
}
listener.onResponse(trainedModelDefinition.infer(fields, config));
} catch (Exception e) {
listener.onFailure(e);

View File

@ -141,7 +141,8 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelConfig ->
modelActionListener.onResponse(new LocalModel(
trainedModelConfig.getModelId(),
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition())),
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
trainedModelConfig.getInput())),
modelActionListener::onFailure
));
} else {
@ -198,7 +199,8 @@ public class ModelLoadingService implements ClusterStateListener {
Queue<ActionListener<Model>> listeners;
LocalModel loadedModel = new LocalModel(
trainedModelConfig.getModelId(),
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition());
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
trainedModelConfig.getInput());
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such

View File

@ -171,7 +171,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
assertThat(storedModel.getTags(), contains(JOB_ID));
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build()));
assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames));
assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar", "baz")));
assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
Map<String, Object> metadata = storedModel.getMetadata();

View File

@ -240,6 +240,29 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
}
}
public void testCreateProcessorWithDuplicateFields() {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService);
Map<String, Object> regression = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "ml");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME,
Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")));
}};
try {
processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression);
fail("should not have succeeded creating with duplicate fields");
} catch (Exception ex) {
assertThat(ex.getMessage(), equalTo("Cannot create processor as configured. " +
"More than one field is configured as [warning]"));
}
}
private static ClusterState buildClusterState(MetaData metaData) {
return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build();
}

View File

@ -11,6 +11,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
@ -253,4 +254,26 @@ public class InferenceProcessorTests extends ESTestCase {
verify(auditor, times(1)).warning(eq("regression_model"), any(String.class));
}
public void testMutateDocumentWithWarningResult() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
"ml",
"regression_model",
RegressionConfig.EMPTY_PARAMS,
Collections.emptyMap());
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new WarningInferenceResults("something broke")), true);
inferenceProcessor.mutateDocument(response, document);
assertThat(document.hasField(targetField), is(false));
assertThat(document.hasField("ml.warning"), is(true));
assertThat(document.hasField("ml.my_processor"), is(false));
}
}

View File

@ -8,8 +8,10 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
@ -22,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import java.util.Arrays;
import java.util.HashMap;
@ -38,12 +41,13 @@ public class LocalModelTests extends ESTestCase {
public void testClassificationInfer() throws Exception {
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(false))
.build();
Model model = new LocalModel(modelId, definition);
Model model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("foo", 1.0);
put("bar", 0.5);
@ -64,7 +68,7 @@ public class LocalModelTests extends ESTestCase {
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(true))
.build();
model = new LocalModel(modelId, definition);
model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields));
result = getSingleValue(model, fields, new ClassificationConfig(0));
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), equalTo("not_to_be"));
@ -81,11 +85,12 @@ public class LocalModelTests extends ESTestCase {
}
public void testRegression() throws Exception {
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildRegression())
.build();
Model model = new LocalModel("regression_model", trainedModelDefinition);
Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("foo", 1.0);
@ -103,16 +108,39 @@ public class LocalModelTests extends ESTestCase {
equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]"));
}
public void testAllFieldsMissing() throws Exception {
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildRegression())
.build();
Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("something", 1.0);
put("other", 0.5);
put("baz", "dog");
}};
WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfig.EMPTY_PARAMS);
assertThat(results.getWarning(),
equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model")));
}
private static SingleValueInferenceResults getSingleValue(Model model,
Map<String, Object> fields,
InferenceConfig config) throws Exception {
return (SingleValueInferenceResults)getInferenceResult(model, fields, config);
}
private static InferenceResults getInferenceResult(Model model, Map<String, Object> fields, InferenceConfig config) throws Exception {
PlainActionFuture<InferenceResults> future = new PlainActionFuture<>();
model.infer(fields, config, future);
return (SingleValueInferenceResults)future.get();
return future.get();
}
private static Map<String, String> oneHotMap() {
Map<String, String> oneHotEncoding = new HashMap<>();
Map<String, String> oneHotEncoding = new HashMap<String, String>();
oneHotEncoding.put("cat", "animal_cat");
oneHotEncoding.put("dog", "animal_dog");
return oneHotEncoding;

View File

@ -35,6 +35,7 @@ import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -45,6 +46,7 @@ import org.mockito.Mockito;
import java.io.IOException;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@ -308,6 +310,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
when(definition.ramBytesUsed()).thenReturn(size);
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getModelDefinition()).thenReturn(definition);
when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz")));
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];

View File

@ -17,7 +17,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
@ -39,6 +41,7 @@ import java.util.stream.Collectors;
import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildClassification;
import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildRegression;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
@ -63,7 +66,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
oneHotEncoding.put("cat", "animal_cat");
oneHotEncoding.put("dog", "animal_dog");
TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2)
.setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
.setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical")))
.setParsedDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
.setTrainedModel(buildClassification(true)))
@ -74,7 +77,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
.setEstimatedHeapMemory(0)
.build();
TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1)
.setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
.setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical")))
.setParsedDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
.setTrainedModel(buildRegression()))
@ -184,6 +187,51 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
}
}
public void testInferMissingFields() throws Exception {
String modelId = "test-load-models-regression-missing-fields";
Map<String, String> oneHotEncoding = new HashMap<>();
oneHotEncoding.put("cat", "animal_cat");
oneHotEncoding.put("dog", "animal_dog");
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId)
.setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
.setParsedDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
.setTrainedModel(buildRegression()))
.setVersion(Version.CURRENT)
.setEstimatedOperations(0)
.setEstimatedHeapMemory(0)
.setCreateTime(Instant.now())
.build();
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
assertThat(putConfigHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
List<Map<String, Object>> toInferMissingField = new ArrayList<>();
toInferMissingField.add(new HashMap<String, Object>() {{
put("foo", 1.0);
put("bar", 0.5);
}});
InternalInferModelAction.Request request = new InternalInferModelAction.Request(
modelId,
toInferMissingField,
RegressionConfig.EMPTY_PARAMS,
true);
try {
InferenceResults result =
client().execute(InternalInferModelAction.INSTANCE, request).actionGet().getInferenceResults().get(0);
assertThat(result, is(instanceOf(WarningInferenceResults.class)));
assertThat(((WarningInferenceResults)result).getWarning(),
equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
} catch (ElasticsearchException ex) {
fail("Should not have thrown. Ex: " + ex.getMessage());
}
}
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
return TrainedModelConfig.builder()
.setCreatedBy("ml_test")