diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java new file mode 100644 index 00000000000..a052c2b263d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class WarningInferenceResults implements InferenceResults { + + public static final String NAME = "warning"; + public static final ParseField WARNING = new ParseField("warning"); + + private final String warning; + + public WarningInferenceResults(String warning) { + this.warning = warning; + } + + public WarningInferenceResults(StreamInput in) throws IOException { + this.warning = in.readString(); + } + + public String getWarning() { + return warning; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(warning); + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + WarningInferenceResults that = (WarningInferenceResults) object; + return Objects.equals(warning, that.warning); + } + + @Override + public int hashCode() { + return Objects.hash(warning); + } + + @Override + public void writeResult(IngestDocument document, String parentResultField) { + ExceptionsHelper.requireNonNull(document, "document"); + ExceptionsHelper.requireNonNull(parentResultField, "resultField"); + document.setFieldValue(parentResultField + "." + "warning", warning); + } + + @Override + public String getWriteableName() { + return NAME; + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index febe9a8eb42..5e7d3ee3318 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -92,6 +92,7 @@ public final class Messages { public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]"; public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED = "Getting model definition is not supported when getting more than one model"; + public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java new file mode 100644 index 00000000000..da48a91cdde --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.HashMap; + +import static org.hamcrest.Matchers.equalTo; + +public class WarningInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static WarningInferenceResults createRandomResults() { + return new WarningInferenceResults(randomAlphaOfLength(10)); + } + + public void testWriteResults() { + WarningInferenceResults result = new WarningInferenceResults("foo"); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo")); + } + + @Override + protected WarningInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return WarningInferenceResults::new; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index b6ac9213472..00fea87a05b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -16,6 +16,8 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.license.License; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; @@ -34,6 +36,8 @@ import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static java.util.stream.Collectors.toList; + public class AnalyticsResultProcessor { private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class); @@ -163,6 +167,10 @@ public class AnalyticsResultProcessor { Instant createTime = Instant.now(); String modelId = analytics.getId() + "-" + createTime.toEpochMilli(); TrainedModelDefinition definition = inferenceModel.build(); + String dependentVariable = getDependentVariable(); + List fieldNamesWithoutDependentVariable = fieldNames.stream() + .filter(f -> f.equals(dependentVariable) == false) + .collect(toList()); return TrainedModelConfig.builder() .setModelId(modelId) .setCreatedBy("data-frame-analytics") @@ -175,11 +183,21 @@ public class AnalyticsResultProcessor { .setEstimatedHeapMemory(definition.ramBytesUsed()) .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations()) .setParsedDefinition(inferenceModel) - .setInput(new TrainedModelInput(fieldNames)) + .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable)) .setLicenseLevel(License.OperationMode.PLATINUM.description()) .build(); } + private String getDependentVariable() { + if (analytics.getAnalysis() instanceof Classification) { + return ((Classification)analytics.getAnalysis()).getDependentVariable(); + } + if (analytics.getAnalysis() instanceof Regression) { + return ((Regression)analytics.getAnalysis()).getDependentVariable(); + } + return null; + } + private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) { CountDownLatch latch = new CountDownLatch(1); ActionListener storeListener = ActionListener.wrap( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 805123cf53c..19c0054b522 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -28,6 +28,8 @@ import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.ingest.Processor; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; @@ -37,7 +39,9 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -146,7 +150,12 @@ public class InferenceProcessor extends AbstractProcessor { if (response.getInferenceResults().isEmpty()) { throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR); } - response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); + InferenceResults inferenceResults = response.getInferenceResults().get(0); + if (inferenceResults instanceof WarningInferenceResults) { + inferenceResults.writeResult(ingestDocument, this.targetField); + } else { + response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); + } ingestDocument.setFieldValue(targetField + "." + MODEL_ID, modelId); } @@ -164,6 +173,10 @@ public class InferenceProcessor extends AbstractProcessor { private static final Logger logger = LogManager.getLogger(Factory.class); + private static final Set RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList( + WarningInferenceResults.WARNING.getPreferredName(), + MODEL_ID)); + private final Client client; private final IngestService ingestService; private final InferenceAuditor auditor; @@ -235,6 +248,7 @@ public class InferenceProcessor extends AbstractProcessor { String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField); Map fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS); InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); + return new InferenceProcessor(client, auditor, tag, @@ -252,7 +266,6 @@ public class InferenceProcessor extends AbstractProcessor { InferenceConfig inferenceConfigFromMap(Map inferenceConfig) { ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); - if (inferenceConfig.size() != 1) { throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", INFERENCE_CONFIG); @@ -268,10 +281,14 @@ public class InferenceProcessor extends AbstractProcessor { if (inferenceConfig.containsKey(ClassificationConfig.NAME)) { checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS); - return ClassificationConfig.fromMap(valueMap); + ClassificationConfig config = ClassificationConfig.fromMap(valueMap); + checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField()); + return config; } else if (inferenceConfig.containsKey(RegressionConfig.NAME)) { checkSupportedVersion(RegressionConfig.EMPTY_PARAMS); - return RegressionConfig.fromMap(valueMap); + RegressionConfig config = RegressionConfig.fromMap(valueMap); + checkFieldUniqueness(config.getResultsField()); + return config; } else { throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}", inferenceConfig.keySet(), @@ -279,6 +296,23 @@ public class InferenceProcessor extends AbstractProcessor { } } + private static void checkFieldUniqueness(String... fieldNames) { + Set duplicatedFieldNames = new HashSet<>(); + Set currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES); + for(String fieldName : fieldNames) { + if (currentFieldNames.contains(fieldName)) { + duplicatedFieldNames.add(fieldName); + } else { + currentFieldNames.add(fieldName); + } + } + if (duplicatedFieldNames.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Cannot create processor as configured." + + " More than one field is configured as {}", + duplicatedFieldNames); + } + } + void checkSupportedVersion(InferenceConfig config) { if (config.getMinimalSupportedVersion().after(minNodeVersion)) { throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION, @@ -287,6 +321,5 @@ public class InferenceProcessor extends AbstractProcessor { minNodeVersion)); } } - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 403f10dd7d8..4e62c69336b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -6,23 +6,33 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import java.util.HashSet; import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING; public class LocalModel implements Model { private final TrainedModelDefinition trainedModelDefinition; private final String modelId; + private final Set fieldNames; - public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) { + public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition, TrainedModelInput input) { this.trainedModelDefinition = trainedModelDefinition; this.modelId = modelId; + this.fieldNames = new HashSet<>(input.getFieldNames()); } long ramBytesUsed() { @@ -51,6 +61,11 @@ public class LocalModel implements Model { @Override public void infer(Map fields, InferenceConfig config, ActionListener listener) { try { + if (Sets.haveEmptyIntersection(fieldNames, fields.keySet())) { + listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId))); + return; + } + listener.onResponse(trainedModelDefinition.infer(fields, config)); } catch (Exception e) { listener.onFailure(e); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 2355228f4c3..b5862acfefc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -141,7 +141,8 @@ public class ModelLoadingService implements ClusterStateListener { trainedModelConfig -> modelActionListener.onResponse(new LocalModel( trainedModelConfig.getModelId(), - trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition())), + trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(), + trainedModelConfig.getInput())), modelActionListener::onFailure )); } else { @@ -198,7 +199,8 @@ public class ModelLoadingService implements ClusterStateListener { Queue> listeners; LocalModel loadedModel = new LocalModel( trainedModelConfig.getModelId(), - trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition()); + trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(), + trainedModelConfig.getInput()); synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); // If there is no loadingListener that means the loading was canceled and the listener was already notified as such diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 15bd32da3c3..036023eb8c9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -171,7 +171,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { assertThat(storedModel.getTags(), contains(JOB_ID)); assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build())); - assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames)); + assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar", "baz"))); assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed())); assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations())); Map metadata = storedModel.getMetadata(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 03b5cf4fd08..04357c4e19b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -240,6 +240,29 @@ public class InferenceProcessorFactoryTests extends ESTestCase { } } + public void testCreateProcessorWithDuplicateFields() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + + Map regression = new HashMap() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "ml"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, + Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning"))); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression); + fail("should not have succeeded creating with duplicate fields"); + } catch (Exception ex) { + assertThat(ex.getMessage(), equalTo("Cannot create processor as configured. " + + "More than one field is configured as [warning]")); + } + } + private static ClusterState buildClusterState(MetaData metaData) { return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index ae455544fe3..81e5c79135c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -253,4 +254,26 @@ public class InferenceProcessorTests extends ESTestCase { verify(auditor, times(1)).warning(eq("regression_model"), any(String.class)); } + public void testMutateDocumentWithWarningResult() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + auditor, + "my_processor", + "ml", + "regression_model", + RegressionConfig.EMPTY_PARAMS, + Collections.emptyMap()); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InternalInferModelAction.Response response = new InternalInferModelAction.Response( + Collections.singletonList(new WarningInferenceResults("something broke")), true); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.hasField(targetField), is(false)); + assertThat(document.hasField("ml.warning"), is(true)); + assertThat(document.hasField("ml.my_processor"), is(false)); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 695b6d5ffac..b8cb652878f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -8,8 +8,10 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; @@ -22,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import java.util.Arrays; import java.util.HashMap; @@ -38,12 +41,13 @@ public class LocalModelTests extends ESTestCase { public void testClassificationInfer() throws Exception { String modelId = "classification_model"; + List inputFields = Arrays.asList("foo", "bar", "categorical"); TrainedModelDefinition definition = new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(false)) .build(); - Model model = new LocalModel(modelId, definition); + Model model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields)); Map fields = new HashMap() {{ put("foo", 1.0); put("bar", 0.5); @@ -64,7 +68,7 @@ public class LocalModelTests extends ESTestCase { .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(true)) .build(); - model = new LocalModel(modelId, definition); + model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields)); result = getSingleValue(model, fields, new ClassificationConfig(0)); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), equalTo("not_to_be")); @@ -81,11 +85,12 @@ public class LocalModelTests extends ESTestCase { } public void testRegression() throws Exception { + List inputFields = Arrays.asList("foo", "bar", "categorical"); TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildRegression()) .build(); - Model model = new LocalModel("regression_model", trainedModelDefinition); + Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields)); Map fields = new HashMap() {{ put("foo", 1.0); @@ -103,16 +108,39 @@ public class LocalModelTests extends ESTestCase { equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]")); } + public void testAllFieldsMissing() throws Exception { + List inputFields = Arrays.asList("foo", "bar", "categorical"); + TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildRegression()) + .build(); + Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields)); + + Map fields = new HashMap() {{ + put("something", 1.0); + put("other", 0.5); + put("baz", "dog"); + }}; + + WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfig.EMPTY_PARAMS); + assertThat(results.getWarning(), + equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model"))); + } + private static SingleValueInferenceResults getSingleValue(Model model, Map fields, InferenceConfig config) throws Exception { + return (SingleValueInferenceResults)getInferenceResult(model, fields, config); + } + + private static InferenceResults getInferenceResult(Model model, Map fields, InferenceConfig config) throws Exception { PlainActionFuture future = new PlainActionFuture<>(); model.infer(fields, config, future); - return (SingleValueInferenceResults)future.get(); + return future.get(); } private static Map oneHotMap() { - Map oneHotEncoding = new HashMap<>(); + Map oneHotEncoding = new HashMap(); oneHotEncoding.put("cat", "animal_cat"); oneHotEncoding.put("dog", "animal_dog"); return oneHotEncoding; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 462a9a90527..85c34c6f504 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -45,6 +46,7 @@ import org.mockito.Mockito; import java.io.IOException; import java.net.InetAddress; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -308,6 +310,7 @@ public class ModelLoadingServiceTests extends ESTestCase { when(definition.ramBytesUsed()).thenReturn(size); TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); when(trainedModelConfig.getModelDefinition()).thenReturn(definition); + when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz"))); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 1cce6bf35cd..0d4aefc26b5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -17,7 +17,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -39,6 +41,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildClassification; import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildRegression; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; @@ -63,7 +66,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { oneHotEncoding.put("cat", "animal_cat"); oneHotEncoding.put("dog", "animal_dog"); TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2) - .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) + .setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical"))) .setParsedDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setTrainedModel(buildClassification(true))) @@ -74,7 +77,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { .setEstimatedHeapMemory(0) .build(); TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1) - .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) + .setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical"))) .setParsedDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setTrainedModel(buildRegression())) @@ -184,6 +187,51 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { } } + public void testInferMissingFields() throws Exception { + String modelId = "test-load-models-regression-missing-fields"; + Map oneHotEncoding = new HashMap<>(); + oneHotEncoding.put("cat", "animal_cat"); + oneHotEncoding.put("dog", "animal_dog"); + TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId) + .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) + .setParsedDefinition(new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setTrainedModel(buildRegression())) + .setVersion(Version.CURRENT) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) + .setCreateTime(Instant.now()) + .build(); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + + List> toInferMissingField = new ArrayList<>(); + toInferMissingField.add(new HashMap() {{ + put("foo", 1.0); + put("bar", 0.5); + }}); + + InternalInferModelAction.Request request = new InternalInferModelAction.Request( + modelId, + toInferMissingField, + RegressionConfig.EMPTY_PARAMS, + true); + try { + InferenceResults result = + client().execute(InternalInferModelAction.INSTANCE, request).actionGet().getInferenceResults().get(0); + assertThat(result, is(instanceOf(WarningInferenceResults.class))); + assertThat(((WarningInferenceResults)result).getWarning(), + equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId))); + } catch (ElasticsearchException ex) { + fail("Should not have thrown. Ex: " + ex.getMessage()); + } + } + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { return TrainedModelConfig.builder() .setCreatedBy("ml_test")