From 4275a715c906d829b9cd0dea9622b03c5ceaa5e6 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 12 Aug 2020 08:34:18 -0400 Subject: [PATCH] [ML] adjusting inference processor to support foreach usage (#60915) (#61022) `foreach` processors store information within the `_ingest` metadata object. This commit adds the contents of the `_ingest` metadata (if it is not empty). And will append new inference results if the result field already exists. This allows a `foreach` to execute and multiple inference results being written to the same result field. closes https://github.com/elastic/elasticsearch/issues/60867 --- .../ClassificationInferenceResults.java | 9 ---- .../inference/results/InferenceResults.java | 15 +++++- .../results/RawInferenceResults.java | 7 +-- .../results/RegressionInferenceResults.java | 9 ---- .../results/WarningInferenceResults.java | 9 ---- .../ClassificationInferenceResultsTests.java | 20 ++++++-- .../RegressionInferenceResultsTests.java | 11 ++++- .../results/WarningInferenceResultsTests.java | 9 +++- .../ml/integration/InferenceIngestIT.java | 46 +++++++++++++++++++ .../inference/ingest/InferenceProcessor.java | 21 +++++---- .../loadingservice/ModelLoadingService.java | 3 +- ...sportGetTrainedModelsStatsActionTests.java | 3 +- .../InferenceProcessorFactoryTests.java | 23 +++++----- .../ingest/InferenceProcessorTests.java | 15 ++++++ .../loadingservice/LocalModelTests.java | 7 +-- .../ModelLoadingServiceTests.java | 3 +- 16 files changed, 143 insertions(+), 67 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index 9688fa0a3a1..0f846203f86 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -9,11 +9,9 @@ import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Collections; @@ -160,13 +158,6 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults return predictionFieldType.transformPredictedValue(value(), valueAsString()); } - @Override - public void writeResult(IngestDocument document, String parentResultField) { - ExceptionsHelper.requireNonNull(document, "document"); - ExceptionsHelper.requireNonNull(parentResultField, "parentResultField"); - document.setFieldValue(parentResultField, asMap()); - } - public Double getPredictionProbability() { return predictionProbability; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java index 6b83f44cffb..193446a23e1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java @@ -8,12 +8,25 @@ package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentFragment; import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.util.Map; public interface InferenceResults extends NamedWriteable, ToXContentFragment { + String MODEL_ID_RESULTS_FIELD = "model_id"; - void writeResult(IngestDocument document, String parentResultField); + static void writeResult(InferenceResults results, IngestDocument ingestDocument, String resultField, String modelId) { + ExceptionsHelper.requireNonNull(results, "results"); + ExceptionsHelper.requireNonNull(ingestDocument, "ingestDocument"); + ExceptionsHelper.requireNonNull(resultField, "resultField"); + Map resultMap = results.asMap(); + resultMap.put(MODEL_ID_RESULTS_FIELD, modelId); + if (ingestDocument.hasField(resultField)) { + ingestDocument.appendFieldValue(resultField, resultMap); + } else { + ingestDocument.setFieldValue(resultField, resultMap); + } + } Map asMap(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java index 1f907e8a2f5..f905f84d909 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.ingest.IngestDocument; import java.io.IOException; import java.util.Arrays; @@ -53,15 +52,11 @@ public class RawInferenceResults implements InferenceResults { return Objects.hash(Arrays.hashCode(value), featureImportance); } - @Override - public void writeResult(IngestDocument document, String parentResultField) { - throw new UnsupportedOperationException("[raw] does not support writing inference results"); - } - @Override public Map asMap() { throw new UnsupportedOperationException("[raw] does not support map conversion"); } + @Override public Object predictedValue() { return null; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java index f082633a88f..498fd2828bc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java @@ -8,10 +8,8 @@ package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Collections; @@ -83,13 +81,6 @@ public class RegressionInferenceResults extends SingleValueInferenceResults { return super.value(); } - @Override - public void writeResult(IngestDocument document, String parentResultField) { - ExceptionsHelper.requireNonNull(document, "document"); - ExceptionsHelper.requireNonNull(parentResultField, "parentResultField"); - document.setFieldValue(parentResultField, asMap()); - } - @Override public Map asMap() { Map map = new LinkedHashMap<>(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java index 9bff569bc2e..6c48eb929a7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java @@ -9,8 +9,6 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.ingest.IngestDocument; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.LinkedHashMap; @@ -54,13 +52,6 @@ public class WarningInferenceResults implements InferenceResults { return Objects.hash(warning); } - @Override - public void writeResult(IngestDocument document, String parentResultField) { - ExceptionsHelper.requireNonNull(document, "document"); - ExceptionsHelper.requireNonNull(parentResultField, "resultField"); - document.setFieldValue(parentResultField, asMap()); - } - @Override public Map asMap() { Map asMap = new LinkedHashMap<>(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java index 5f245183c33..efeb2cdb256 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -22,6 +22,7 @@ import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -64,7 +65,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing 1.0, 1.0); IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); - result.writeResult(document, "result_field"); + writeResult(result, document, "result_field", "test"); assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo")); } @@ -78,9 +79,20 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing 1.0, 1.0); IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); - result.writeResult(document, "result_field"); + writeResult(result, document, "result_field", "test"); assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("1.0")); + + result = new ClassificationInferenceResults(2.0, + null, + Collections.emptyList(), + Collections.emptyList(), + ClassificationConfig.EMPTY_PARAMS, + 1.0, + 1.0); + writeResult(result, document, "result_field", "test"); + assertThat(document.getFieldValue("result_field.0.predicted_value", String.class), equalTo("1.0")); + assertThat(document.getFieldValue("result_field.1.predicted_value", String.class), equalTo("2.0")); } @SuppressWarnings("unchecked") @@ -97,7 +109,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing 0.7, 0.7); IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); - result.writeResult(document, "result_field"); + writeResult(result, document, "result_field", "test"); List list = document.getFieldValue("result_field.bar", List.class); assertThat(list.size(), equalTo(3)); @@ -126,7 +138,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing 1.0, 1.0); IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); - result.writeResult(document, "result_field"); + writeResult(result, document, "result_field", "test"); assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo")); @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java index 783e44728c1..91899b688ae 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -19,6 +19,7 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -37,9 +38,15 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest public void testWriteResults() { RegressionInferenceResults result = new RegressionInferenceResults(0.3, RegressionConfig.EMPTY_PARAMS); IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); - result.writeResult(document, "result_field"); + writeResult(result, document, "result_field", "test"); assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3)); + + result = new RegressionInferenceResults(0.5, RegressionConfig.EMPTY_PARAMS); + writeResult(result, document, "result_field", "test"); + + assertThat(document.getFieldValue("result_field.0.predicted_value", Double.class), equalTo(0.3)); + assertThat(document.getFieldValue("result_field.1.predicted_value", Double.class), equalTo(0.5)); } public void testWriteResultsWithImportance() { @@ -50,7 +57,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest new RegressionConfig("predicted_value", 3), importanceList); IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); - result.writeResult(document, "result_field"); + writeResult(result, document, "result_field", "test"); assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3)); @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java index 6a53bddcee9..f034099962f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java @@ -16,6 +16,7 @@ import java.io.IOException; import java.util.HashMap; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; import static org.hamcrest.Matchers.equalTo; public class WarningInferenceResultsTests extends AbstractSerializingTestCase { @@ -36,9 +37,15 @@ public class WarningInferenceResultsTests extends AbstractSerializingTestCase(), new HashMap<>()); - result.writeResult(document, "result_field"); + writeResult(result, document, "result_field", "test"); assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo")); + + result = new WarningInferenceResults("bar"); + writeResult(result, document, "result_field", "test"); + + assertThat(document.getFieldValue("result_field.0.warning", String.class), equalTo("foo")); + assertThat(document.getFieldValue("result_field.1.warning", String.class), equalTo("bar")); } @Override diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index e2b4b555336..d2ed1ce5a7b 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -347,6 +347,52 @@ public class InferenceIngestIT extends ESRestTestCase { assertThat(EntityUtils.toString(response.getEntity()), containsString("\"predicted_value\":\"en\"")); } + public void testSimulateLangIdentForeach() throws IOException { + String source = "{" + + " \"pipeline\": {\n" + + " \"description\": \"detect text lang\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"foreach\": {\n" + + " \"field\": \"greetings\",\n" + + " \"processor\": {\n" + + " \"inference\": {\n" + + " \"model_id\": \"lang_ident_model_1\",\n" + + " \"inference_config\": {\n" + + " \"classification\": {\n" + + " \"num_top_classes\": 5\n" + + " }\n" + + " },\n" + + " \"field_map\": {\n" + + " \"_ingest._value.text\": \"text\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " },\n" + + " \"docs\": [\n" + + " {\n" + + " \"_source\": {\n" + + " \"greetings\": [\n" + + " {\n" + + " \"text\": \" a backup credit card by visiting your billing preferences page or visit the adwords help\"\n" + + " },\n" + + " {\n" + + " \"text\": \" 개별적으로 리포트 액세스 권한을 부여할 수 있습니다 액세스 권한 부여사용자에게 프로필 리포트에 \"\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + Response response = client().performRequest(simulateRequest(source)); + String stringResponse = EntityUtils.toString(response.getEntity()); + assertThat(stringResponse, containsString("\"predicted_value\":\"en\"")); + assertThat(stringResponse, containsString("\"predicted_value\":\"ko\"")); + } + private static Request simulateRequest(String jsonEntity) { Request request = new Request("POST", "_ingest/pipeline/_simulate"); request.setJsonEntity(jsonEntity); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index c17a0668a79..bceb30462f3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -28,6 +28,7 @@ import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.ingest.Processor; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; @@ -49,9 +50,11 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.Consumer; +import static org.elasticsearch.ingest.IngestDocument.INGEST_KEY; import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.MODEL_ID_RESULTS_FIELD; public class InferenceProcessor extends AbstractProcessor { @@ -63,7 +66,6 @@ public class InferenceProcessor extends AbstractProcessor { Setting.Property.NodeScope); public static final String TYPE = "inference"; - public static final String MODEL_ID = "model_id"; public static final String INFERENCE_CONFIG = "inference_config"; public static final String TARGET_FIELD = "target_field"; public static final String FIELD_MAPPINGS = "field_mappings"; @@ -92,7 +94,7 @@ public class InferenceProcessor extends AbstractProcessor { this.client = ExceptionsHelper.requireNonNull(client, "client"); this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD); this.auditor = ExceptionsHelper.requireNonNull(auditor, "auditor"); - this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID_RESULTS_FIELD); this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); this.fieldMap = ExceptionsHelper.requireNonNull(fieldMap, FIELD_MAP); } @@ -132,6 +134,10 @@ public class InferenceProcessor extends AbstractProcessor { InternalInferModelAction.Request buildRequest(IngestDocument ingestDocument) { Map fields = new HashMap<>(ingestDocument.getSourceAndMetadata()); + // Add ingestMetadata as previous processors might have added metadata from which we are predicting (see: foreach processor) + if (ingestDocument.getIngestMetadata().isEmpty() == false) { + fields.put(INGEST_KEY, ingestDocument.getIngestMetadata()); + } LocalModel.mapFieldsIfNecessary(fields, fieldMap); return new InternalInferModelAction.Request(modelId, fields, inferenceConfig, previouslyLicensed); } @@ -150,8 +156,7 @@ public class InferenceProcessor extends AbstractProcessor { throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR); } assert response.getInferenceResults().size() == 1; - response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); - ingestDocument.setFieldValue(targetField + "." + MODEL_ID, modelId); + InferenceResults.writeResult(response.getInferenceResults().get(0), ingestDocument, targetField, modelId); } @Override @@ -278,7 +283,7 @@ public class InferenceProcessor extends AbstractProcessor { maxIngestProcessors); } - String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID); + String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID_RESULTS_FIELD); String defaultTargetField = tag == null ? DEFAULT_TARGET_FIELD : DEFAULT_TARGET_FIELD + "." + tag; // If multiple inference processors are in the same pipeline, it is wise to tag them // The tag will keep default value entries from stepping on each other @@ -341,12 +346,10 @@ public class InferenceProcessor extends AbstractProcessor { if (configMap.containsKey(ClassificationConfig.NAME.getPreferredName())) { checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS); - ClassificationConfigUpdate config = ClassificationConfigUpdate.fromMap(valueMap); - return config; + return ClassificationConfigUpdate.fromMap(valueMap); } else if (configMap.containsKey(RegressionConfig.NAME.getPreferredName())) { checkSupportedVersion(RegressionConfig.EMPTY_PARAMS); - RegressionConfigUpdate config = RegressionConfigUpdate.fromMap(valueMap); - return config; + return RegressionConfigUpdate.fromMap(valueMap); } else { throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}", configMap.keySet(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 89b99ab9830..794580938ac 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; @@ -561,7 +562,7 @@ public class ModelLoadingService implements ClusterStateListener { if (processor instanceof Map) { Object processorConfig = ((Map) processor).get(InferenceProcessor.TYPE); if (processorConfig instanceof Map) { - Object modelId = ((Map) processorConfig).get(InferenceProcessor.MODEL_ID); + Object modelId = ((Map) processorConfig).get(InferenceResults.MODEL_ID_RESULTS_FIELD); if (modelId != null) { assert modelId instanceof String; allReferencedModelKeys.add(modelId.toString()); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java index b090efdbffc..faf07b1ccec 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java @@ -34,6 +34,7 @@ import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.junit.Before; @@ -272,7 +273,7 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase { Collections.singletonList( Collections.singletonMap(InferenceProcessor.TYPE, new HashMap() {{ - put(InferenceProcessor.MODEL_ID, modelId); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, modelId); put("inference_config", Collections.singletonMap("regression", Collections.emptyMap())); put("field_map", Collections.emptyMap()); put("target_field", randomAlphaOfLength(10)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 916b8b13fea..33e25f3e2e7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -31,6 +31,7 @@ import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.junit.Before; @@ -160,7 +161,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { Map config = new HashMap() {{ put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap())); }}; @@ -172,7 +173,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { Map config2 = new HashMap() {{ put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom")); }}; @@ -183,7 +184,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { Map config3 = new HashMap() {{ put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap()); }}; @@ -201,7 +202,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { Map regression = new HashMap() {{ put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap())); @@ -214,7 +215,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { Map classification = new HashMap() {{ put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME.getPreferredName(), Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1))); @@ -233,7 +234,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1")); Map minimalConfig = new HashMap() {{ - put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); }}; @@ -249,7 +250,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { Map regression = new HashMap() {{ put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap())); @@ -260,7 +261,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { Map classification = new HashMap() {{ put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME.getPreferredName(), Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1))); @@ -269,7 +270,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, classification); Map mininmal = new HashMap() {{ - put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); }}; @@ -283,7 +284,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { Map regression = new HashMap() {{ put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); put(InferenceProcessor.TARGET_FIELD, "ml"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning"))); @@ -350,7 +351,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase { private static Map inferenceProcessorForModel(String modelId) { return Collections.singletonMap(InferenceProcessor.TYPE, new HashMap() {{ - put(InferenceProcessor.MODEL_ID, modelId); + put(InferenceResults.MODEL_ID_RESULTS_FIELD, modelId); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap())); put(InferenceProcessor.TARGET_FIELD, "new_field"); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 52a6a5414f9..5fa4ea21d45 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -271,6 +271,13 @@ public class InferenceProcessorTests extends ESTestCase { IngestDocument document = new IngestDocument(source, ingestMetadata); assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(source)); + + ingestMetadata = Collections.singletonMap("_value", 3); + document = new IngestDocument(source, ingestMetadata); + + Map expected = new HashMap<>(source); + expected.put("_ingest", ingestMetadata); + assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expected)); } public void testGenerateWithMapping() { @@ -281,6 +288,7 @@ public class InferenceProcessorTests extends ESTestCase { put("value1", "new_value1"); put("value2", "new_value2"); put("categorical", "new_categorical"); + put("_ingest._value", "metafield"); }}; InferenceProcessor processor = new InferenceProcessor(client, @@ -307,6 +315,13 @@ public class InferenceProcessorTests extends ESTestCase { put("un_touched", "bar"); }}; assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap)); + + ingestMetadata = Collections.singletonMap("_value", "baz"); + document = new IngestDocument(source, ingestMetadata); + expectedMap = new HashMap<>(expectedMap); + expectedMap.put("metafield", "baz"); + expectedMap.put("_ingest", ingestMetadata); + assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap)); } public void testGenerateWithMappingNestedFields() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 3f81b8fb027..19f99edf77e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -44,6 +44,7 @@ import java.util.List; import java.util.Map; import java.util.TreeMap; +import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModelTests.serializeFromTrainedModel; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.closeTo; @@ -167,7 +168,7 @@ public class LocalModelTests extends ESTestCase { new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.STRING)); IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); - result.writeResult(document, "result_field"); + writeResult(result, document, "result_field", modelId); assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("not_to_be")); List list = document.getFieldValue("result_field.top_classes", List.class); assertThat(list.size(), equalTo(2)); @@ -177,7 +178,7 @@ public class LocalModelTests extends ESTestCase { result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.NUMBER)); document = new IngestDocument(new HashMap<>(), new HashMap<>()); - result.writeResult(document, "result_field"); + writeResult(result, document, "result_field", modelId); assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.0)); list = document.getFieldValue("result_field.top_classes", List.class); assertThat(list.size(), equalTo(2)); @@ -187,7 +188,7 @@ public class LocalModelTests extends ESTestCase { result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.BOOLEAN)); document = new IngestDocument(new HashMap<>(), new HashMap<>()); - result.writeResult(document, "result_field"); + writeResult(result, document, "result_field", modelId); assertThat(document.getFieldValue("result_field.predicted_value", Boolean.class), equalTo(false)); list = document.getFieldValue("result_field.top_classes", List.class); assertThat(list.size(), equalTo(2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 0a7e2475ebd..8d8f54120cf 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; @@ -630,7 +631,7 @@ public class ModelLoadingServiceTests extends ESTestCase { try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", Collections.singletonList( Collections.singletonMap(InferenceProcessor.TYPE, - Collections.singletonMap(InferenceProcessor.MODEL_ID, + Collections.singletonMap(InferenceResults.MODEL_ID_RESULTS_FIELD, modelId)))))) { return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); }