`foreach` processors store information within the `_ingest` metadata object. This commit adds the contents of the `_ingest` metadata (if it is not empty) to the fields used for inference, and appends new inference results when the result field already exists. This allows a `foreach` processor to execute multiple inferences with their results written to the same result field. Closes https://github.com/elastic/elasticsearch/issues/60867
This commit is contained in:
parent
c81dc2b8b7
commit
4275a715c9
|
@ -9,11 +9,9 @@ import org.elasticsearch.Version;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
@ -160,13 +158,6 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
|
|||
return predictionFieldType.transformPredictedValue(value(), valueAsString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String parentResultField) {
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
|
||||
document.setFieldValue(parentResultField, asMap());
|
||||
}
|
||||
|
||||
public Double getPredictionProbability() {
|
||||
return predictionProbability;
|
||||
}
|
||||
|
|
|
@ -8,12 +8,25 @@ package org.elasticsearch.xpack.core.ml.inference.results;
|
|||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentFragment;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public interface InferenceResults extends NamedWriteable, ToXContentFragment {
|
||||
String MODEL_ID_RESULTS_FIELD = "model_id";
|
||||
|
||||
void writeResult(IngestDocument document, String parentResultField);
|
||||
static void writeResult(InferenceResults results, IngestDocument ingestDocument, String resultField, String modelId) {
|
||||
ExceptionsHelper.requireNonNull(results, "results");
|
||||
ExceptionsHelper.requireNonNull(ingestDocument, "ingestDocument");
|
||||
ExceptionsHelper.requireNonNull(resultField, "resultField");
|
||||
Map<String, Object> resultMap = results.asMap();
|
||||
resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
|
||||
if (ingestDocument.hasField(resultField)) {
|
||||
ingestDocument.appendFieldValue(resultField, resultMap);
|
||||
} else {
|
||||
ingestDocument.setFieldValue(resultField, resultMap);
|
||||
}
|
||||
}
|
||||
|
||||
Map<String, Object> asMap();
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.ml.inference.results;
|
|||
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
@ -53,15 +52,11 @@ public class RawInferenceResults implements InferenceResults {
|
|||
return Objects.hash(Arrays.hashCode(value), featureImportance);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String parentResultField) {
|
||||
throw new UnsupportedOperationException("[raw] does not support writing inference results");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> asMap() {
|
||||
throw new UnsupportedOperationException("[raw] does not support map conversion");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object predictedValue() {
|
||||
return null;
|
||||
|
|
|
@ -8,10 +8,8 @@ package org.elasticsearch.xpack.core.ml.inference.results;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
@ -83,13 +81,6 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
|
|||
return super.value();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String parentResultField) {
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
|
||||
document.setFieldValue(parentResultField, asMap());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> asMap() {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
|
|
|
@ -9,8 +9,6 @@ import org.elasticsearch.common.ParseField;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.LinkedHashMap;
|
||||
|
@ -54,13 +52,6 @@ public class WarningInferenceResults implements InferenceResults {
|
|||
return Objects.hash(warning);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String parentResultField) {
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
ExceptionsHelper.requireNonNull(parentResultField, "resultField");
|
||||
document.setFieldValue(parentResultField, asMap());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> asMap() {
|
||||
Map<String, Object> asMap = new LinkedHashMap<>();
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.function.Supplier;
|
|||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
|
||||
|
@ -64,7 +65,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
1.0,
|
||||
1.0);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
writeResult(result, document, "result_field", "test");
|
||||
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo"));
|
||||
}
|
||||
|
@ -78,9 +79,20 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
1.0,
|
||||
1.0);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
writeResult(result, document, "result_field", "test");
|
||||
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("1.0"));
|
||||
|
||||
result = new ClassificationInferenceResults(2.0,
|
||||
null,
|
||||
Collections.emptyList(),
|
||||
Collections.emptyList(),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
1.0,
|
||||
1.0);
|
||||
writeResult(result, document, "result_field", "test");
|
||||
assertThat(document.getFieldValue("result_field.0.predicted_value", String.class), equalTo("1.0"));
|
||||
assertThat(document.getFieldValue("result_field.1.predicted_value", String.class), equalTo("2.0"));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
@ -97,7 +109,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
0.7,
|
||||
0.7);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
writeResult(result, document, "result_field", "test");
|
||||
|
||||
List<?> list = document.getFieldValue("result_field.bar", List.class);
|
||||
assertThat(list.size(), equalTo(3));
|
||||
|
@ -126,7 +138,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
|
|||
1.0,
|
||||
1.0);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
writeResult(result, document, "result_field", "test");
|
||||
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo"));
|
||||
@SuppressWarnings("unchecked")
|
||||
|
|
|
@ -19,6 +19,7 @@ import java.util.Map;
|
|||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
|
||||
|
@ -37,9 +38,15 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
|
|||
public void testWriteResults() {
|
||||
RegressionInferenceResults result = new RegressionInferenceResults(0.3, RegressionConfig.EMPTY_PARAMS);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
writeResult(result, document, "result_field", "test");
|
||||
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
|
||||
|
||||
result = new RegressionInferenceResults(0.5, RegressionConfig.EMPTY_PARAMS);
|
||||
writeResult(result, document, "result_field", "test");
|
||||
|
||||
assertThat(document.getFieldValue("result_field.0.predicted_value", Double.class), equalTo(0.3));
|
||||
assertThat(document.getFieldValue("result_field.1.predicted_value", Double.class), equalTo(0.5));
|
||||
}
|
||||
|
||||
public void testWriteResultsWithImportance() {
|
||||
|
@ -50,7 +57,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
|
|||
new RegressionConfig("predicted_value", 3),
|
||||
importanceList);
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
writeResult(result, document, "result_field", "test");
|
||||
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
|
||||
@SuppressWarnings("unchecked")
|
||||
|
|
|
@ -16,6 +16,7 @@ import java.io.IOException;
|
|||
import java.util.HashMap;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class WarningInferenceResultsTests extends AbstractSerializingTestCase<WarningInferenceResults> {
|
||||
|
@ -36,9 +37,15 @@ public class WarningInferenceResultsTests extends AbstractSerializingTestCase<Wa
|
|||
public void testWriteResults() {
|
||||
WarningInferenceResults result = new WarningInferenceResults("foo");
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
writeResult(result, document, "result_field", "test");
|
||||
|
||||
assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo"));
|
||||
|
||||
result = new WarningInferenceResults("bar");
|
||||
writeResult(result, document, "result_field", "test");
|
||||
|
||||
assertThat(document.getFieldValue("result_field.0.warning", String.class), equalTo("foo"));
|
||||
assertThat(document.getFieldValue("result_field.1.warning", String.class), equalTo("bar"));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -347,6 +347,52 @@ public class InferenceIngestIT extends ESRestTestCase {
|
|||
assertThat(EntityUtils.toString(response.getEntity()), containsString("\"predicted_value\":\"en\""));
|
||||
}
|
||||
|
||||
public void testSimulateLangIdentForeach() throws IOException {
|
||||
String source = "{" +
|
||||
" \"pipeline\": {\n" +
|
||||
" \"description\": \"detect text lang\",\n" +
|
||||
" \"processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"foreach\": {\n" +
|
||||
" \"field\": \"greetings\",\n" +
|
||||
" \"processor\": {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"model_id\": \"lang_ident_model_1\",\n" +
|
||||
" \"inference_config\": {\n" +
|
||||
" \"classification\": {\n" +
|
||||
" \"num_top_classes\": 5\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" \"field_map\": {\n" +
|
||||
" \"_ingest._value.text\": \"text\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
" },\n" +
|
||||
" \"docs\": [\n" +
|
||||
" {\n" +
|
||||
" \"_source\": {\n" +
|
||||
" \"greetings\": [\n" +
|
||||
" {\n" +
|
||||
" \"text\": \" a backup credit card by visiting your billing preferences page or visit the adwords help\"\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"text\": \" 개별적으로 리포트 액세스 권한을 부여할 수 있습니다 액세스 권한 부여사용자에게 프로필 리포트에 \"\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
"}";
|
||||
Response response = client().performRequest(simulateRequest(source));
|
||||
String stringResponse = EntityUtils.toString(response.getEntity());
|
||||
assertThat(stringResponse, containsString("\"predicted_value\":\"en\""));
|
||||
assertThat(stringResponse, containsString("\"predicted_value\":\"ko\""));
|
||||
}
|
||||
|
||||
private static Request simulateRequest(String jsonEntity) {
|
||||
Request request = new Request("POST", "_ingest/pipeline/_simulate");
|
||||
request.setJsonEntity(jsonEntity);
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.elasticsearch.ingest.PipelineConfiguration;
|
|||
import org.elasticsearch.ingest.Processor;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
|
||||
|
@ -49,9 +50,11 @@ import java.util.concurrent.atomic.AtomicBoolean;
|
|||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static org.elasticsearch.ingest.IngestDocument.INGEST_KEY;
|
||||
import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY;
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.MODEL_ID_RESULTS_FIELD;
|
||||
|
||||
public class InferenceProcessor extends AbstractProcessor {
|
||||
|
||||
|
@ -63,7 +66,6 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
Setting.Property.NodeScope);
|
||||
|
||||
public static final String TYPE = "inference";
|
||||
public static final String MODEL_ID = "model_id";
|
||||
public static final String INFERENCE_CONFIG = "inference_config";
|
||||
public static final String TARGET_FIELD = "target_field";
|
||||
public static final String FIELD_MAPPINGS = "field_mappings";
|
||||
|
@ -92,7 +94,7 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
this.client = ExceptionsHelper.requireNonNull(client, "client");
|
||||
this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
|
||||
this.auditor = ExceptionsHelper.requireNonNull(auditor, "auditor");
|
||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID_RESULTS_FIELD);
|
||||
this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
|
||||
this.fieldMap = ExceptionsHelper.requireNonNull(fieldMap, FIELD_MAP);
|
||||
}
|
||||
|
@ -132,6 +134,10 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
|
||||
InternalInferModelAction.Request buildRequest(IngestDocument ingestDocument) {
|
||||
Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
|
||||
// Add ingestMetadata as previous processors might have added metadata from which we are predicting (see: foreach processor)
|
||||
if (ingestDocument.getIngestMetadata().isEmpty() == false) {
|
||||
fields.put(INGEST_KEY, ingestDocument.getIngestMetadata());
|
||||
}
|
||||
LocalModel.mapFieldsIfNecessary(fields, fieldMap);
|
||||
return new InternalInferModelAction.Request(modelId, fields, inferenceConfig, previouslyLicensed);
|
||||
}
|
||||
|
@ -150,8 +156,7 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
assert response.getInferenceResults().size() == 1;
|
||||
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
|
||||
ingestDocument.setFieldValue(targetField + "." + MODEL_ID, modelId);
|
||||
InferenceResults.writeResult(response.getInferenceResults().get(0), ingestDocument, targetField, modelId);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -278,7 +283,7 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
maxIngestProcessors);
|
||||
}
|
||||
|
||||
String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID);
|
||||
String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID_RESULTS_FIELD);
|
||||
String defaultTargetField = tag == null ? DEFAULT_TARGET_FIELD : DEFAULT_TARGET_FIELD + "." + tag;
|
||||
// If multiple inference processors are in the same pipeline, it is wise to tag them
|
||||
// The tag will keep default value entries from stepping on each other
|
||||
|
@ -341,12 +346,10 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
|
||||
if (configMap.containsKey(ClassificationConfig.NAME.getPreferredName())) {
|
||||
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
|
||||
ClassificationConfigUpdate config = ClassificationConfigUpdate.fromMap(valueMap);
|
||||
return config;
|
||||
return ClassificationConfigUpdate.fromMap(valueMap);
|
||||
} else if (configMap.containsKey(RegressionConfig.NAME.getPreferredName())) {
|
||||
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
|
||||
RegressionConfigUpdate config = RegressionConfigUpdate.fromMap(valueMap);
|
||||
return config;
|
||||
return RegressionConfigUpdate.fromMap(valueMap);
|
||||
} else {
|
||||
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
|
||||
configMap.keySet(),
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.common.util.set.Sets;
|
|||
import org.elasticsearch.ingest.IngestMetadata;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
|
@ -561,7 +562,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
if (processor instanceof Map<?, ?>) {
|
||||
Object processorConfig = ((Map<?, ?>) processor).get(InferenceProcessor.TYPE);
|
||||
if (processorConfig instanceof Map<?, ?>) {
|
||||
Object modelId = ((Map<?, ?>) processorConfig).get(InferenceProcessor.MODEL_ID);
|
||||
Object modelId = ((Map<?, ?>) processorConfig).get(InferenceResults.MODEL_ID_RESULTS_FIELD);
|
||||
if (modelId != null) {
|
||||
assert modelId instanceof String;
|
||||
allReferencedModelKeys.add(modelId.toString());
|
||||
|
|
|
@ -34,6 +34,7 @@ import org.elasticsearch.license.XPackLicenseState;
|
|||
import org.elasticsearch.plugins.IngestPlugin;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -272,7 +273,7 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
|
|||
Collections.singletonList(
|
||||
Collections.singletonMap(InferenceProcessor.TYPE,
|
||||
new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.MODEL_ID, modelId);
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, modelId);
|
||||
put("inference_config", Collections.singletonMap("regression", Collections.emptyMap()));
|
||||
put("field_map", Collections.emptyMap());
|
||||
put("target_field", randomAlphaOfLength(10));
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.elasticsearch.ingest.IngestMetadata;
|
|||
import org.elasticsearch.ingest.PipelineConfiguration;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.junit.Before;
|
||||
|
@ -160,7 +161,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
|
||||
Map<String, Object> config = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap()));
|
||||
}};
|
||||
|
@ -172,7 +173,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
|
||||
Map<String, Object> config2 = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom"));
|
||||
}};
|
||||
|
@ -183,7 +184,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
|
||||
Map<String, Object> config3 = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap());
|
||||
}};
|
||||
|
@ -201,7 +202,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
|
||||
Map<String, Object> regression = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG,
|
||||
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
|
||||
|
@ -214,7 +215,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
|
||||
Map<String, Object> classification = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME.getPreferredName(),
|
||||
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
|
||||
|
@ -233,7 +234,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1"));
|
||||
|
||||
Map<String, Object> minimalConfig = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
}};
|
||||
|
||||
|
@ -249,7 +250,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
|
||||
Map<String, Object> regression = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG,
|
||||
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
|
||||
|
@ -260,7 +261,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
|
||||
Map<String, Object> classification = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME.getPreferredName(),
|
||||
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
|
||||
|
@ -269,7 +270,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, classification);
|
||||
|
||||
Map<String, Object> mininmal = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "result");
|
||||
}};
|
||||
|
||||
|
@ -283,7 +284,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
|
||||
Map<String, Object> regression = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "ml");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME.getPreferredName(),
|
||||
Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")));
|
||||
|
@ -350,7 +351,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
private static Map<String, Object> inferenceProcessorForModel(String modelId) {
|
||||
return Collections.singletonMap(InferenceProcessor.TYPE,
|
||||
new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.MODEL_ID, modelId);
|
||||
put(InferenceResults.MODEL_ID_RESULTS_FIELD, modelId);
|
||||
put(InferenceProcessor.INFERENCE_CONFIG,
|
||||
Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()));
|
||||
put(InferenceProcessor.TARGET_FIELD, "new_field");
|
||||
|
|
|
@ -271,6 +271,13 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
IngestDocument document = new IngestDocument(source, ingestMetadata);
|
||||
|
||||
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(source));
|
||||
|
||||
ingestMetadata = Collections.singletonMap("_value", 3);
|
||||
document = new IngestDocument(source, ingestMetadata);
|
||||
|
||||
Map<String, Object> expected = new HashMap<>(source);
|
||||
expected.put("_ingest", ingestMetadata);
|
||||
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expected));
|
||||
}
|
||||
|
||||
public void testGenerateWithMapping() {
|
||||
|
@ -281,6 +288,7 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
put("value1", "new_value1");
|
||||
put("value2", "new_value2");
|
||||
put("categorical", "new_categorical");
|
||||
put("_ingest._value", "metafield");
|
||||
}};
|
||||
|
||||
InferenceProcessor processor = new InferenceProcessor(client,
|
||||
|
@ -307,6 +315,13 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
put("un_touched", "bar");
|
||||
}};
|
||||
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
|
||||
|
||||
ingestMetadata = Collections.singletonMap("_value", "baz");
|
||||
document = new IngestDocument(source, ingestMetadata);
|
||||
expectedMap = new HashMap<>(expectedMap);
|
||||
expectedMap.put("metafield", "baz");
|
||||
expectedMap.put("_ingest", ingestMetadata);
|
||||
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
|
||||
}
|
||||
|
||||
public void testGenerateWithMappingNestedFields() {
|
||||
|
|
|
@ -44,6 +44,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModelTests.serializeFromTrainedModel;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
|
@ -167,7 +168,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.STRING));
|
||||
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
writeResult(result, document, "result_field", modelId);
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("not_to_be"));
|
||||
List<?> list = document.getFieldValue("result_field.top_classes", List.class);
|
||||
assertThat(list.size(), equalTo(2));
|
||||
|
@ -177,7 +178,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.NUMBER));
|
||||
|
||||
document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
writeResult(result, document, "result_field", modelId);
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.0));
|
||||
list = document.getFieldValue("result_field.top_classes", List.class);
|
||||
assertThat(list.size(), equalTo(2));
|
||||
|
@ -187,7 +188,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.BOOLEAN));
|
||||
|
||||
document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
writeResult(result, document, "result_field", modelId);
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", Boolean.class), equalTo(false));
|
||||
list = document.getFieldValue("result_field.top_classes", List.class);
|
||||
assertThat(list.size(), equalTo(2));
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.elasticsearch.threadpool.TestThreadPool;
|
|||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
|
||||
|
@ -630,7 +631,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors",
|
||||
Collections.singletonList(
|
||||
Collections.singletonMap(InferenceProcessor.TYPE,
|
||||
Collections.singletonMap(InferenceProcessor.MODEL_ID,
|
||||
Collections.singletonMap(InferenceResults.MODEL_ID_RESULTS_FIELD,
|
||||
modelId)))))) {
|
||||
return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue