From 4805d8ac7df5a65b3f66740a3ce6f59541d6a6ef Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 13 Dec 2019 10:39:51 -0500 Subject: [PATCH] [ML][Inference] Adding a warning_field for warning msgs. (#49838) (#50183) This adds a new field for the inference processor. `warning_field` is a place for us to write warnings provided from the inference call. When there are warnings we are not going to write an inference result. The goal of this is to indicate that the data provided was too poor or too different for the model to make an accurate prediction. The user could optionally include the `warning_field`. When it is not provided, it is assumed no warnings were desired to be written. The first of these warnings is when ALL of the input fields are missing. If none of the trained fields are present, we don't bother inferencing against the model and instead provide a warning stating that the fields were missing. Also, this adds checks to not allow duplicated fields during processor creation. --- .../results/WarningInferenceResults.java | 66 +++++++++++++++++++ .../xpack/core/ml/job/messages/Messages.java | 1 + .../results/WarningInferenceResultsTests.java | 39 +++++++++++ .../process/AnalyticsResultProcessor.java | 20 +++++- .../inference/ingest/InferenceProcessor.java | 43 ++++++++++-- .../inference/loadingservice/LocalModel.java | 17 ++++- .../loadingservice/ModelLoadingService.java | 6 +- .../AnalyticsResultProcessorTests.java | 2 +- .../InferenceProcessorFactoryTests.java | 23 +++++++ .../ingest/InferenceProcessorTests.java | 23 +++++++ .../loadingservice/LocalModelTests.java | 38 +++++++++-- .../ModelLoadingServiceTests.java | 3 + .../integration/ModelInferenceActionIT.java | 52 ++++++++++++++- 13 files changed, 316 insertions(+), 17 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java new file mode 100644 index 00000000000..a052c2b263d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class WarningInferenceResults implements InferenceResults { + + public static final String NAME = "warning"; + public static final ParseField WARNING = new ParseField("warning"); + + private final String warning; + + public WarningInferenceResults(String warning) { + this.warning = warning; + } + + public WarningInferenceResults(StreamInput in) throws IOException { + this.warning = in.readString(); + } + + public String getWarning() { + return warning; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(warning); + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + WarningInferenceResults that = (WarningInferenceResults) object; + return Objects.equals(warning, that.warning); + } + + @Override + public int hashCode() { + return Objects.hash(warning); + } + + @Override + public void writeResult(IngestDocument document, String parentResultField) { + ExceptionsHelper.requireNonNull(document, "document"); + ExceptionsHelper.requireNonNull(parentResultField, "resultField"); + document.setFieldValue(parentResultField + "." + "warning", warning); + } + + @Override + public String getWriteableName() { + return NAME; + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index febe9a8eb42..5e7d3ee3318 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -92,6 +92,7 @@ public final class Messages { public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]"; public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED = "Getting model definition is not supported when getting more than one model"; + public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java new file mode 100644 index 00000000000..da48a91cdde --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.HashMap; + +import static org.hamcrest.Matchers.equalTo; + +public class WarningInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static WarningInferenceResults createRandomResults() { + return new WarningInferenceResults(randomAlphaOfLength(10)); + } + + public void testWriteResults() { + WarningInferenceResults result = new WarningInferenceResults("foo"); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo")); + } + + @Override + protected WarningInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return WarningInferenceResults::new; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index b6ac9213472..00fea87a05b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -16,6 +16,8 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.license.License; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; @@ -34,6 +36,8 @@ import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static java.util.stream.Collectors.toList; + public class AnalyticsResultProcessor { private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class); @@ -163,6 +167,10 @@ public class AnalyticsResultProcessor { Instant createTime = Instant.now(); String modelId = analytics.getId() + "-" + createTime.toEpochMilli(); TrainedModelDefinition definition = inferenceModel.build(); + String dependentVariable = getDependentVariable(); + List fieldNamesWithoutDependentVariable = fieldNames.stream() + .filter(f -> f.equals(dependentVariable) == false) + .collect(toList()); return TrainedModelConfig.builder() .setModelId(modelId) .setCreatedBy("data-frame-analytics") @@ -175,11 +183,21 @@ public class AnalyticsResultProcessor { .setEstimatedHeapMemory(definition.ramBytesUsed()) .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations()) .setParsedDefinition(inferenceModel) - .setInput(new TrainedModelInput(fieldNames)) + .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable)) .setLicenseLevel(License.OperationMode.PLATINUM.description()) .build(); } + private String getDependentVariable() { + if (analytics.getAnalysis() instanceof Classification) { + return ((Classification)analytics.getAnalysis()).getDependentVariable(); + } + if (analytics.getAnalysis() instanceof Regression) { + return ((Regression)analytics.getAnalysis()).getDependentVariable(); + } + return null; + } + private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) { CountDownLatch latch = new CountDownLatch(1); ActionListener storeListener = ActionListener.wrap( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 805123cf53c..19c0054b522 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -28,6 +28,8 @@ import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.ingest.Processor; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; @@ -37,7 +39,9 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -146,7 +150,12 @@ public class InferenceProcessor extends AbstractProcessor { if (response.getInferenceResults().isEmpty()) { throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR); } - response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); + InferenceResults inferenceResults = response.getInferenceResults().get(0); + if (inferenceResults instanceof WarningInferenceResults) { + inferenceResults.writeResult(ingestDocument, this.targetField); + } else { + response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); + } ingestDocument.setFieldValue(targetField + "." + MODEL_ID, modelId); } @@ -164,6 +173,10 @@ public class InferenceProcessor extends AbstractProcessor { private static final Logger logger = LogManager.getLogger(Factory.class); + private static final Set RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList( + WarningInferenceResults.WARNING.getPreferredName(), + MODEL_ID)); + private final Client client; private final IngestService ingestService; private final InferenceAuditor auditor; @@ -235,6 +248,7 @@ public class InferenceProcessor extends AbstractProcessor { String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField); Map fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS); InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); + return new InferenceProcessor(client, auditor, tag, @@ -252,7 +266,6 @@ public class InferenceProcessor extends AbstractProcessor { InferenceConfig inferenceConfigFromMap(Map inferenceConfig) { ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); - if (inferenceConfig.size() != 1) { throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", INFERENCE_CONFIG); @@ -268,10 +281,14 @@ public class InferenceProcessor extends AbstractProcessor { if (inferenceConfig.containsKey(ClassificationConfig.NAME)) { checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS); - return ClassificationConfig.fromMap(valueMap); + ClassificationConfig config = ClassificationConfig.fromMap(valueMap); + checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField()); + return config; } else if (inferenceConfig.containsKey(RegressionConfig.NAME)) { checkSupportedVersion(RegressionConfig.EMPTY_PARAMS); - return RegressionConfig.fromMap(valueMap); + RegressionConfig config = RegressionConfig.fromMap(valueMap); + checkFieldUniqueness(config.getResultsField()); + return config; } else { throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}", inferenceConfig.keySet(), @@ -279,6 +296,23 @@ public class InferenceProcessor extends AbstractProcessor { } } + private static void checkFieldUniqueness(String... fieldNames) { + Set duplicatedFieldNames = new HashSet<>(); + Set currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES); + for(String fieldName : fieldNames) { + if (currentFieldNames.contains(fieldName)) { + duplicatedFieldNames.add(fieldName); + } else { + currentFieldNames.add(fieldName); + } + } + if (duplicatedFieldNames.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Cannot create processor as configured." + + " More than one field is configured as {}", + duplicatedFieldNames); + } + } + void checkSupportedVersion(InferenceConfig config) { if (config.getMinimalSupportedVersion().after(minNodeVersion)) { throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION, @@ -287,6 +321,5 @@ public class InferenceProcessor extends AbstractProcessor { minNodeVersion)); } } - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 403f10dd7d8..4e62c69336b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -6,23 +6,33 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import java.util.HashSet; import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING; public class LocalModel implements Model { private final TrainedModelDefinition trainedModelDefinition; private final String modelId; + private final Set fieldNames; - public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) { + public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition, TrainedModelInput input) { this.trainedModelDefinition = trainedModelDefinition; this.modelId = modelId; + this.fieldNames = new HashSet<>(input.getFieldNames()); } long ramBytesUsed() { @@ -51,6 +61,11 @@ public class LocalModel implements Model { @Override public void infer(Map fields, InferenceConfig config, ActionListener listener) { try { + if (Sets.haveEmptyIntersection(fieldNames, fields.keySet())) { + listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId))); + return; + } + listener.onResponse(trainedModelDefinition.infer(fields, config)); } catch (Exception e) { listener.onFailure(e); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 2355228f4c3..b5862acfefc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -141,7 +141,8 @@ public class ModelLoadingService implements ClusterStateListener { trainedModelConfig -> modelActionListener.onResponse(new LocalModel( trainedModelConfig.getModelId(), - trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition())), + trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(), + trainedModelConfig.getInput())), modelActionListener::onFailure )); } else { @@ -198,7 +199,8 @@ public class ModelLoadingService implements ClusterStateListener { Queue> listeners; LocalModel loadedModel = new LocalModel( trainedModelConfig.getModelId(), - trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition()); + trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(), + trainedModelConfig.getInput()); synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); // If there is no loadingListener that means the loading was canceled and the listener was already notified as such diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 15bd32da3c3..036023eb8c9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -171,7 +171,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { assertThat(storedModel.getTags(), contains(JOB_ID)); assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build())); - assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames)); + assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar", "baz"))); assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed())); assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations())); Map metadata = storedModel.getMetadata(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 03b5cf4fd08..04357c4e19b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -240,6 +240,29 @@ public class InferenceProcessorFactoryTests extends ESTestCase { } } + public void testCreateProcessorWithDuplicateFields() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + + Map regression = new HashMap() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "ml"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, + Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning"))); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression); + fail("should not have succeeded creating with duplicate fields"); + } catch (Exception ex) { + assertThat(ex.getMessage(), equalTo("Cannot create processor as configured. " + + "More than one field is configured as [warning]")); + } + } + private static ClusterState buildClusterState(MetaData metaData) { return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index ae455544fe3..81e5c79135c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -253,4 +254,26 @@ public class InferenceProcessorTests extends ESTestCase { verify(auditor, times(1)).warning(eq("regression_model"), any(String.class)); } + public void testMutateDocumentWithWarningResult() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + auditor, + "my_processor", + "ml", + "regression_model", + RegressionConfig.EMPTY_PARAMS, + Collections.emptyMap()); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InternalInferModelAction.Response response = new InternalInferModelAction.Response( + Collections.singletonList(new WarningInferenceResults("something broke")), true); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.hasField(targetField), is(false)); + assertThat(document.hasField("ml.warning"), is(true)); + assertThat(document.hasField("ml.my_processor"), is(false)); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 695b6d5ffac..b8cb652878f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -8,8 +8,10 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; @@ -22,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import java.util.Arrays; import java.util.HashMap; @@ -38,12 +41,13 @@ public class LocalModelTests extends ESTestCase { public void testClassificationInfer() throws Exception { String modelId = "classification_model"; + List inputFields = Arrays.asList("foo", "bar", "categorical"); TrainedModelDefinition definition = new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(false)) .build(); - Model model = new LocalModel(modelId, definition); + Model model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields)); Map fields = new HashMap() {{ put("foo", 1.0); put("bar", 0.5); @@ -64,7 +68,7 @@ public class LocalModelTests extends ESTestCase { .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(true)) .build(); - model = new LocalModel(modelId, definition); + model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields)); result = getSingleValue(model, fields, new ClassificationConfig(0)); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), equalTo("not_to_be")); @@ -81,11 +85,12 @@ public class LocalModelTests extends ESTestCase { } public void testRegression() throws Exception { + List inputFields = Arrays.asList("foo", "bar", "categorical"); TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildRegression()) .build(); - Model model = new LocalModel("regression_model", trainedModelDefinition); + Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields)); Map fields = new HashMap() {{ put("foo", 1.0); @@ -103,16 +108,39 @@ public class LocalModelTests extends ESTestCase { equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]")); } + public void testAllFieldsMissing() throws Exception { + List inputFields = Arrays.asList("foo", "bar", "categorical"); + TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildRegression()) + .build(); + Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields)); + + Map fields = new HashMap() {{ + put("something", 1.0); + put("other", 0.5); + put("baz", "dog"); + }}; + + WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfig.EMPTY_PARAMS); + assertThat(results.getWarning(), + equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model"))); + } + private static SingleValueInferenceResults getSingleValue(Model model, Map fields, InferenceConfig config) throws Exception { + return (SingleValueInferenceResults)getInferenceResult(model, fields, config); + } + + private static InferenceResults getInferenceResult(Model model, Map fields, InferenceConfig config) throws Exception { PlainActionFuture future = new PlainActionFuture<>(); model.infer(fields, config, future); - return (SingleValueInferenceResults)future.get(); + return future.get(); } private static Map oneHotMap() { - Map oneHotEncoding = new HashMap<>(); + Map oneHotEncoding = new HashMap(); oneHotEncoding.put("cat", "animal_cat"); oneHotEncoding.put("dog", "animal_dog"); return oneHotEncoding; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 462a9a90527..85c34c6f504 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -45,6 +46,7 @@ import org.mockito.Mockito; import java.io.IOException; import java.net.InetAddress; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -308,6 +310,7 @@ public class ModelLoadingServiceTests extends ESTestCase { when(definition.ramBytesUsed()).thenReturn(size); TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); when(trainedModelConfig.getModelDefinition()).thenReturn(definition); + when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz"))); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 1cce6bf35cd..0d4aefc26b5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -17,7 +17,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -39,6 +41,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildClassification; import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildRegression; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; @@ -63,7 +66,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { oneHotEncoding.put("cat", "animal_cat"); oneHotEncoding.put("dog", "animal_dog"); TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2) - .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) + .setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical"))) .setParsedDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setTrainedModel(buildClassification(true))) @@ -74,7 +77,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { .setEstimatedHeapMemory(0) .build(); TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1) - .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) + .setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical"))) .setParsedDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setTrainedModel(buildRegression())) @@ -184,6 +187,51 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { } } + public void testInferMissingFields() throws Exception { + String modelId = "test-load-models-regression-missing-fields"; + Map oneHotEncoding = new HashMap<>(); + oneHotEncoding.put("cat", "animal_cat"); + oneHotEncoding.put("dog", "animal_dog"); + TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId) + .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) + .setParsedDefinition(new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setTrainedModel(buildRegression())) + .setVersion(Version.CURRENT) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) + .setCreateTime(Instant.now()) + .build(); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + + List> toInferMissingField = new ArrayList<>(); + toInferMissingField.add(new HashMap() {{ + put("foo", 1.0); + put("bar", 0.5); + }}); + + InternalInferModelAction.Request request = new InternalInferModelAction.Request( + modelId, + toInferMissingField, + RegressionConfig.EMPTY_PARAMS, + true); + try { + InferenceResults result = + client().execute(InternalInferModelAction.INSTANCE, request).actionGet().getInferenceResults().get(0); + assertThat(result, is(instanceOf(WarningInferenceResults.class))); + assertThat(((WarningInferenceResults)result).getWarning(), + equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId))); + } catch (ElasticsearchException ex) { + fail("Should not have thrown. Ex: " + ex.getMessage()); + } + } + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { return TrainedModelConfig.builder() .setCreatedBy("ml_test")