[ML][Inference] adds new default_field_map field to trained models (#53294) (#53419)

Adds a new `default_field_map` field to trained model config objects.

This allows the model creator to supply a field map when it knows that some mapping is required for inference to work directly against the training data.

The use case internally is having analytics jobs supply a field mapping for multi-field fields. This allows us to use the model "out of the box" on data where we trained on `foo.keyword` but the `_source` only references `foo`.
This commit is contained in:
Benjamin Trent 2020-03-11 13:49:39 -04:00 committed by GitHub
parent 9ada508347
commit 89668c5ea0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 226 additions and 41 deletions

View File

@ -53,6 +53,7 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
@ -76,6 +77,7 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
PARSER.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP);
}
public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException {
@ -95,6 +97,7 @@ public class TrainedModelConfig implements ToXContentObject {
private final Long estimatedHeapMemory;
private final Long estimatedOperations;
private final String licenseLevel;
private final Map<String, String> defaultFieldMap;
TrainedModelConfig(String modelId,
String createdBy,
@ -108,7 +111,8 @@ public class TrainedModelConfig implements ToXContentObject {
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations,
String licenseLevel) {
String licenseLevel,
Map<String, String> defaultFieldMap) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
@ -122,6 +126,7 @@ public class TrainedModelConfig implements ToXContentObject {
this.estimatedHeapMemory = estimatedHeapMemory;
this.estimatedOperations = estimatedOperations;
this.licenseLevel = licenseLevel;
this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap);
}
public String getModelId() {
@ -180,6 +185,10 @@ public class TrainedModelConfig implements ToXContentObject {
return licenseLevel;
}
/**
 * Returns the default field map of this trained model config.
 *
 * @return the string-to-string map stored in the {@code defaultFieldMap} field;
 *         may be {@code null}
 */
public Map<String, String> getDefaultFieldMap() {
return defaultFieldMap;
}
/**
 * Creates a new {@link Builder} using its no-argument constructor.
 *
 * @return a fresh builder instance
 */
public static Builder builder() {
return new Builder();
}
@ -226,6 +235,9 @@ public class TrainedModelConfig implements ToXContentObject {
if (licenseLevel != null) {
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel);
}
if (defaultFieldMap != null) {
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
}
builder.endObject();
return builder;
}
@ -252,6 +264,7 @@ public class TrainedModelConfig implements ToXContentObject {
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(defaultFieldMap, that.defaultFieldMap) &&
Objects.equals(metadata, that.metadata);
}
@ -269,7 +282,8 @@ public class TrainedModelConfig implements ToXContentObject {
estimatedOperations,
metadata,
licenseLevel,
input);
input,
defaultFieldMap);
}
@ -288,6 +302,7 @@ public class TrainedModelConfig implements ToXContentObject {
private Long estimatedHeapMemory;
private Long estimatedOperations;
private String licenseLevel;
private Map<String, String> defaultFieldMap;
public Builder setModelId(String modelId) {
this.modelId = modelId;
@ -367,6 +382,11 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
/**
 * Sets the default field map for the config being built.
 *
 * The supplied map is stored by reference; this setter does not copy it.
 *
 * @param defaultFieldMap string-to-string field map to store; {@code null} is accepted and stored as-is
 * @return this builder, to allow call chaining
 */
public Builder setDefaultFieldMap(Map<String, String> defaultFieldMap) {
this.defaultFieldMap = defaultFieldMap;
return this;
}
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
@ -381,7 +401,8 @@ public class TrainedModelConfig implements ToXContentObject {
input,
estimatedHeapMemory,
estimatedOperations,
licenseLevel);
licenseLevel,
defaultFieldMap);
}
}

View File

@ -30,6 +30,7 @@ import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -52,7 +53,11 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomFrom("platinum", "basic"));
randomBoolean() ? null : randomFrom("platinum", "basic"),
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));
}
@Override

View File

@ -14,7 +14,7 @@ ingested in the pipeline.
| Name | Required | Default | Description
| `model_id` | yes | - | (String) The ID of the model to load and infer against.
| `target_field` | no | `ml.inference.<processor_tag>` | (String) Field added to incoming documents to contain results objects.
| `field_mappings` | yes | - | (Object) Maps the document field names to the known field names of the model.
| `field_mappings` | yes | - | (Object) Maps the document field names to the known field names of the model. This mapping takes precedence over any default mappings provided in the model configuration.
| `inference_config` | yes | - | (Object) Contains the inference type and its options. There are two types: <<inference-processor-regression-opt,`regression`>> and <<inference-processor-classification-opt,`classification`>>.
include::common-options.asciidoc[]
|======

View File

@ -1505,6 +1505,17 @@ The estimated number of operations to use the trained model.
`license_level`:::
(string)
The license level of the trained model.
`default_field_map`:::
(object)
A string-to-string object that contains the default field map to use
when inferring against the model. For example, data frame analytics
may train the model on a specific multi-field `foo.keyword`.
The analytics job would then supply a default field map entry for
`"foo" : "foo.keyword"`.
Any field map described in the inference configuration takes precedence.
end::trained-model-configs[]
tag::training-percent[]

View File

@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -60,6 +61,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
@ -90,6 +92,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
DEFINITION);
parser.declareString(TrainedModelConfig.Builder::setLazyDefinition, COMPRESSED_DEFINITION);
parser.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
parser.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP);
return parser;
}
@ -108,6 +111,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private final long estimatedHeapMemory;
private final long estimatedOperations;
private final License.OperationMode licenseLevel;
private final Map<String, String> defaultFieldMap;
private final LazyModelDefinition definition;
@ -122,7 +126,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations,
String licenseLevel) {
String licenseLevel,
Map<String, String> defaultFieldMap) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
@ -142,6 +147,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
}
this.estimatedOperations = estimatedOperations;
this.licenseLevel = License.OperationMode.parse(ExceptionsHelper.requireNonNull(licenseLevel, LICENSE_LEVEL));
this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap);
}
public TrainedModelConfig(StreamInput in) throws IOException {
@ -157,6 +163,13 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
estimatedHeapMemory = in.readVLong();
estimatedOperations = in.readVLong();
licenseLevel = License.OperationMode.parse(in.readString());
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.defaultFieldMap = in.readBoolean() ?
Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)) :
null;
} else {
this.defaultFieldMap = null;
}
}
public String getModelId() {
@ -187,6 +200,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return metadata;
}
/**
 * Returns the default field map of this trained model config.
 *
 * @return the string-to-string map stored in the {@code defaultFieldMap} field;
 *         may be {@code null}
 */
public Map<String, String> getDefaultFieldMap() {
return defaultFieldMap;
}
@Nullable
public String getCompressedDefinition() throws IOException {
if (definition == null) {
@ -249,6 +266,14 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
out.writeVLong(estimatedHeapMemory);
out.writeVLong(estimatedOperations);
out.writeString(licenseLevel.description());
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
if (defaultFieldMap != null) {
out.writeBoolean(true);
out.writeMap(defaultFieldMap, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
}
}
@Override
@ -283,6 +308,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
new ByteSizeValue(estimatedHeapMemory));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) {
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
}
builder.endObject();
return builder;
}
@ -308,6 +336,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(licenseLevel, that.licenseLevel) &&
Objects.equals(defaultFieldMap, that.defaultFieldMap) &&
Objects.equals(metadata, that.metadata);
}
@ -324,7 +353,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
estimatedHeapMemory,
estimatedOperations,
input,
licenseLevel);
licenseLevel,
defaultFieldMap);
}
public static class Builder {
@ -341,6 +371,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private Long estimatedOperations;
private LazyModelDefinition definition;
private String licenseLevel;
private Map<String, String> defaultFieldMap;
public Builder() {}
@ -357,6 +388,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
this.estimatedOperations = config.estimatedOperations;
this.estimatedHeapMemory = config.estimatedHeapMemory;
this.licenseLevel = config.licenseLevel.description();
this.defaultFieldMap = config.defaultFieldMap == null ? null : new HashMap<>(config.defaultFieldMap);
}
public Builder setModelId(String modelId) {
@ -475,6 +507,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
/**
 * Sets the default field map for the config being built.
 *
 * The supplied map is stored by reference; this setter does not copy it.
 *
 * @param defaultFieldMap string-to-string field map to store; {@code null} is accepted and stored as-is
 * @return this builder, to allow call chaining
 */
public Builder setDefaultFieldMap(Map<String, String> defaultFieldMap) {
this.defaultFieldMap = defaultFieldMap;
return this;
}
/**
 * Convenience overload that delegates to {@code validate(false)}.
 *
 * @return whatever the delegated {@code validate(false)} call returns
 */
public Builder validate() {
return validate(false);
}
@ -567,7 +604,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
input,
estimatedHeapMemory == null ? 0 : estimatedHeapMemory,
estimatedOperations == null ? 0 : estimatedOperations,
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel);
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel,
defaultFieldMap);
}
}

View File

@ -64,6 +64,9 @@
},
"total_definition_length": {
"type": "long"
},
"default_field_map": {
"enabled": false
}
}
}

View File

@ -34,9 +34,11 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
@ -137,7 +139,11 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
TrainedModelInputTests.createRandomInput(),
randomNonNegativeLong(),
randomNonNegativeLong(),
"platinum");
"platinum",
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
assertThat(reference.utf8ToString(), containsString("\"compressed_definition\""));
@ -172,7 +178,11 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
TrainedModelInputTests.createRandomInput(),
randomNonNegativeLong(),
randomNonNegativeLong(),
"platinum");
"platinum",
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();

View File

@ -191,6 +191,40 @@ public class InferenceIngestIT extends ESRestTestCase {
assertThat(responseString, containsString("Could not find trained model [test_classification_missing]"));
}
/**
 * Simulates an inference pipeline whose processor declares an empty {@code field_mappings}
 * object, against a document that carries {@code col_1_alias} instead of {@code col1}.
 *
 * The assertions check that a prediction and feature importance values for both
 * {@code col1} and {@code col2} appear in the simulate response.
 */
public void testSimulateWithDefaultMappedField() throws IOException {
    // Pipeline definition plus a single document; note "field_mappings" is intentionally empty
    // and the document only provides "col_1_alias", not "col1".
    String simulateBody = "{\n" +
        " \"pipeline\": {\n" +
        " \"processors\": [\n" +
        " {\n" +
        " \"inference\": {\n" +
        " \"target_field\": \"ml.classification\",\n" +
        " \"inference_config\": {\"classification\": " +
        " {\"num_top_classes\":2, " +
        " \"top_classes_results_field\": \"result_class_prob\"," +
        " \"num_top_feature_importance_values\": 2" +
        " }},\n" +
        " \"model_id\": \"test_classification\",\n" +
        " \"field_mappings\": {}\n" +
        " }\n" +
        " }\n" +
        " ]\n" +
        " },\n" +
        " \"docs\": [\n" +
        " {\"_source\": {\n" +
        " \"col_1_alias\": \"female\",\n" +
        " \"col2\": \"M\",\n" +
        " \"col3\": \"none\",\n" +
        " \"col4\": 10\n" +
        " }}]\n" +
        "}";

    String responseBody = EntityUtils.toString(client().performRequest(simulateRequest(simulateBody)).getEntity());

    // Checked in the same order as before; the first missing fragment fails the test.
    String[] expectedFragments = {
        "\"predicted_value\":\"second\"",
        "\"col2\":0.944",
        "\"col1\":0.19999"
    };
    for (String fragment : expectedFragments) {
        assertThat(responseBody, containsString(fragment));
    }
}
public void testSimulateLangIdent() throws IOException {
String source = "{\n" +
" \"pipeline\": {\n" +
@ -525,6 +559,7 @@ public class InferenceIngestIT extends ESRestTestCase {
"{\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
" \"default_field_map\": {\"col_1_alias\": \"col1\"},\n" +
" \"definition\": " + CLASSIFICATION_DEFINITION +
"}";

View File

@ -234,6 +234,10 @@ public class DataFrameDataExtractor {
return context.extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList());
}
/**
 * Returns the list obtained from {@code context.extractedFields.getAllFields()}, i.e. the
 * {@link ExtractedField} objects themselves rather than only their names.
 *
 * @return the extracted fields as returned by the underlying {@code getAllFields()} call
 */
public List<ExtractedField> getAllExtractedFields() {
return context.extractedFields.getAllFields();
}
public DataSummary collectDataSummary() {
SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);

View File

@ -434,7 +434,7 @@ public class AnalyticsProcessManager {
new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService);
return new AnalyticsResultProcessor(
config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, resultsPersisterService,
dataExtractor.get().getFieldNames());
dataExtractor.get().getAllExtractedFields());
}
}
}

View File

@ -33,6 +33,8 @@ import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.MultiField;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
@ -42,10 +44,12 @@ import java.time.Instant;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import static java.util.stream.Collectors.toList;
@ -71,7 +75,7 @@ public class AnalyticsResultProcessor {
private final TrainedModelProvider trainedModelProvider;
private final DataFrameAnalyticsAuditor auditor;
private final ResultsPersisterService resultsPersisterService;
private final List<String> fieldNames;
private final List<ExtractedField> fieldNames;
private final CountDownLatch completionLatch = new CountDownLatch(1);
private volatile String failure;
private volatile boolean isCancelled;
@ -79,7 +83,7 @@ public class AnalyticsResultProcessor {
public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
StatsHolder statsHolder, TrainedModelProvider trainedModelProvider,
DataFrameAnalyticsAuditor auditor, ResultsPersisterService resultsPersisterService,
List<String> fieldNames) {
List<ExtractedField> fieldNames) {
this.analytics = Objects.requireNonNull(analytics);
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
this.statsHolder = Objects.requireNonNull(statsHolder);
@ -192,8 +196,12 @@ public class AnalyticsResultProcessor {
TrainedModelDefinition definition = inferenceModel.build();
String dependentVariable = getDependentVariable();
List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
.map(ExtractedField::getName)
.filter(f -> f.equals(dependentVariable) == false)
.collect(toList());
Map<String, String> defaultFieldMapping = fieldNames.stream()
.filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false))
.collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName));
return TrainedModelConfig.builder()
.setModelId(modelId)
.setCreatedBy(XPackUser.NAME)
@ -209,6 +217,7 @@ public class AnalyticsResultProcessor {
.setParsedDefinition(inferenceModel)
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setDefaultFieldMap(defaultFieldMapping)
.build();
}

View File

@ -17,7 +17,7 @@ public class MultiField implements ExtractedField {
private final ExtractedField field;
private final String parent;
MultiField(String parent, ExtractedField field) {
public MultiField(String parent, ExtractedField field) {
this(field.getName(), field.getSearchField(), parent, field);
}

View File

@ -34,7 +34,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.util.Arrays;
@ -126,14 +126,7 @@ public class InferenceProcessor extends AbstractProcessor {
InternalInferModelAction.Request buildRequest(IngestDocument ingestDocument) {
Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
if (fieldMapping != null) {
fieldMapping.forEach((src, dest) -> {
Object srcValue = MapHelper.dig(src, fields);
if (srcValue != null) {
fields.put(dest, srcValue);
}
});
}
Model.mapFieldsIfNecessary(fields, fieldMapping);
return new InternalInferModelAction.Request(modelId, fields, inferenceConfig, previouslyLicensed);
}

View File

@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
@ -28,11 +29,16 @@ public class LocalModel implements Model {
private final TrainedModelDefinition trainedModelDefinition;
private final String modelId;
private final Set<String> fieldNames;
private final Map<String, String> defaultFieldMap;
public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition, TrainedModelInput input) {
public LocalModel(String modelId,
TrainedModelDefinition trainedModelDefinition,
TrainedModelInput input,
Map<String, String> defaultFieldMap) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
this.fieldNames = new HashSet<>(input.getFieldNames());
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
}
long ramBytesUsed() {
@ -61,6 +67,7 @@ public class LocalModel implements Model {
@Override
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
try {
Model.mapFieldsIfNecessary(fields, defaultFieldMap);
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
return;

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.util.Map;
@ -18,4 +19,27 @@ public interface Model {
void infer(Map<String, Object> fields, InferenceConfig inferenceConfig, ActionListener<InferenceResults> listener);
String getModelId();
/**
 * Translates field names according to the passed {@code fieldMapping} parameter.
 *
 * {@code fields} is mutated in place. For each mapping entry, the original name is looked up
 * with {@link MapHelper#dig}; when a non-null value is found it is added under the expected
 * name via {@code putIfAbsent}, so a value already stored under the expected name is not
 * overwritten.
 *
 * Entries under the original names are never removed.
 *
 * @param fields Fields to map against; modified by this call
 * @param fieldMapping Field originalName to expectedName string mapping; when {@code null} nothing is done
 */
static void mapFieldsIfNecessary(Map<String, Object> fields, Map<String, String> fieldMapping) {
    if (fieldMapping == null) {
        return;
    }
    for (Map.Entry<String, String> mapping : fieldMapping.entrySet()) {
        Object value = MapHelper.dig(mapping.getKey(), fields);
        if (value != null) {
            fields.putIfAbsent(mapping.getValue(), value);
        }
    }
}
}

View File

@ -142,7 +142,8 @@ public class ModelLoadingService implements ClusterStateListener {
modelActionListener.onResponse(new LocalModel(
trainedModelConfig.getModelId(),
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
trainedModelConfig.getInput())),
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap())),
modelActionListener::onFailure
));
} else {
@ -200,7 +201,8 @@ public class ModelLoadingService implements ClusterStateListener {
LocalModel loadedModel = new LocalModel(
trainedModelConfig.getModelId(),
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
trainedModelConfig.getInput());
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap());
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such

View File

@ -169,7 +169,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
inOrder.verify(process).isProcessAlive();
inOrder.verify(task).getStatsHolder();
inOrder.verify(dataExtractor).getFieldNames();
inOrder.verify(dataExtractor).getAllExtractedFields();
inOrder.verify(executorServiceForProcess, times(2)).execute(any()); // 'processData' and 'processResults' threads
verifyNoMoreInteractions(dataExtractor, executorServiceForProcess, process, task);
}
@ -227,7 +227,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
inOrder.verify(process).isProcessAlive();
inOrder.verify(task).getStatsHolder();
inOrder.verify(dataExtractor).getFieldNames();
inOrder.verify(dataExtractor).getAllExtractedFields();
// stop
inOrder.verify(dataExtractor).cancel();
inOrder.verify(process).kill();

View File

@ -24,7 +24,10 @@ import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.MultiField;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
@ -33,6 +36,7 @@ import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mockito;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -158,10 +162,13 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
return null;
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
List<String> expectedFieldNames = Arrays.asList("foo", "bar", "baz");
List<ExtractedField> extractedFieldList = new ArrayList<>(3);
extractedFieldList.add(new DocValueField("foo", Collections.emptySet()));
extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet())));
extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(expectedFieldNames);
AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);
resultProcessor.process(process);
resultProcessor.awaitForCompletion();
@ -177,7 +184,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
assertThat(storedModel.getTags(), contains(JOB_ID));
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build()));
assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar", "baz")));
assertThat(storedModel.getDefaultFieldMap(), equalTo(Collections.singletonMap("bar", "bar.keyword")));
assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar.keyword", "baz")));
assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
Map<String, Object> metadata = storedModel.getMetadata();
@ -235,8 +243,13 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
return createResultProcessor(Collections.emptyList());
}
private AnalyticsResultProcessor createResultProcessor(List<String> fieldNames) {
return new AnalyticsResultProcessor(
analyticsConfig, dataFrameRowsJoiner, statsHolder, trainedModelProvider, auditor, resultsPersisterService, fieldNames);
private AnalyticsResultProcessor createResultProcessor(List<ExtractedField> fieldNames) {
return new AnalyticsResultProcessor(analyticsConfig,
dataFrameRowsJoiner,
statsHolder,
trainedModelProvider,
auditor,
resultsPersisterService,
fieldNames);
}
}

View File

@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -41,13 +42,16 @@ public class LocalModelTests extends ESTestCase {
public void testClassificationInfer() throws Exception {
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(false))
.build();
Model model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields));
Model model = new LocalModel(modelId,
definition,
new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
put("field.bar", 0.5);
@ -68,7 +72,10 @@ public class LocalModelTests extends ESTestCase {
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(true))
.build();
model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields));
model = new LocalModel(modelId,
definition,
new TrainedModelInput(inputFields),
Collections.singletonMap("field.foo", "field.foo.keyword"));
result = getSingleValue(model, fields, new ClassificationConfig(0));
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), equalTo("not_to_be"));
@ -90,11 +97,14 @@ public class LocalModelTests extends ESTestCase {
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildRegression())
.build();
Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields));
Model model = new LocalModel("regression_model",
trainedModelDefinition,
new TrainedModelInput(inputFields),
Collections.singletonMap("bar", "bar.keyword"));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
put("field.bar", 0.5);
put("foo", 1.0);
put("bar.keyword", 0.5);
put("categorical", "dog");
}};
@ -114,7 +124,7 @@ public class LocalModelTests extends ESTestCase {
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildRegression())
.build();
Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields));
Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields), null);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("something", 1.0);