This adds a new field for the inference processor. `warning_field` is a place for us to write warnings provided from the inference call. When there are warnings we are not going to write an inference result. The goal of this is to indicate that the data provided was too poor or too different for the model to make an accurate prediction. The user could optionally include the `warning_field`. When it is not provided, it is assumed no warnings were desired to be written. The first of these warnings is when ALL of the input fields are missing. If none of the trained fields are present, we don't bother inferencing against the model and instead provide a warning stating that the fields were missing. Also, this adds checks to not allow duplicated fields during processor creation.
This commit is contained in:
parent
41736dd6c3
commit
4805d8ac7d
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
public class WarningInferenceResults implements InferenceResults {
|
||||
|
||||
public static final String NAME = "warning";
|
||||
public static final ParseField WARNING = new ParseField("warning");
|
||||
|
||||
private final String warning;
|
||||
|
||||
public WarningInferenceResults(String warning) {
|
||||
this.warning = warning;
|
||||
}
|
||||
|
||||
public WarningInferenceResults(StreamInput in) throws IOException {
|
||||
this.warning = in.readString();
|
||||
}
|
||||
|
||||
public String getWarning() {
|
||||
return warning;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(warning);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (object == this) { return true; }
|
||||
if (object == null || getClass() != object.getClass()) { return false; }
|
||||
WarningInferenceResults that = (WarningInferenceResults) object;
|
||||
return Objects.equals(warning, that.warning);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(warning);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeResult(IngestDocument document, String parentResultField) {
|
||||
ExceptionsHelper.requireNonNull(document, "document");
|
||||
ExceptionsHelper.requireNonNull(parentResultField, "resultField");
|
||||
document.setFieldValue(parentResultField + "." + "warning", warning);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
}
|
|
@ -92,6 +92,7 @@ public final class Messages {
|
|||
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
|
||||
public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
|
||||
"Getting model definition is not supported when getting more than one model";
|
||||
public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
|
||||
|
||||
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
|
||||
public static final String JOB_AUDIT_CREATED = "Job created";
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.results;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class WarningInferenceResultsTests extends AbstractWireSerializingTestCase<WarningInferenceResults> {
|
||||
|
||||
public static WarningInferenceResults createRandomResults() {
|
||||
return new WarningInferenceResults(randomAlphaOfLength(10));
|
||||
}
|
||||
|
||||
public void testWriteResults() {
|
||||
WarningInferenceResults result = new WarningInferenceResults("foo");
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
result.writeResult(document, "result_field");
|
||||
|
||||
assertThat(document.getFieldValue("result_field.warning", String.class), equalTo("foo"));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected WarningInferenceResults createTestInstance() {
|
||||
return createRandomResults();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<WarningInferenceResults> instanceReader() {
|
||||
return WarningInferenceResults::new;
|
||||
}
|
||||
}
|
|
@ -16,6 +16,8 @@ import org.elasticsearch.common.xcontent.XContentHelper;
|
|||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
|
@ -34,6 +36,8 @@ import java.util.Objects;
|
|||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
public class AnalyticsResultProcessor {
|
||||
|
||||
private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);
|
||||
|
@ -163,6 +167,10 @@ public class AnalyticsResultProcessor {
|
|||
Instant createTime = Instant.now();
|
||||
String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
|
||||
TrainedModelDefinition definition = inferenceModel.build();
|
||||
String dependentVariable = getDependentVariable();
|
||||
List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
|
||||
.filter(f -> f.equals(dependentVariable) == false)
|
||||
.collect(toList());
|
||||
return TrainedModelConfig.builder()
|
||||
.setModelId(modelId)
|
||||
.setCreatedBy("data-frame-analytics")
|
||||
|
@ -175,11 +183,21 @@ public class AnalyticsResultProcessor {
|
|||
.setEstimatedHeapMemory(definition.ramBytesUsed())
|
||||
.setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
|
||||
.setParsedDefinition(inferenceModel)
|
||||
.setInput(new TrainedModelInput(fieldNames))
|
||||
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
|
||||
.setLicenseLevel(License.OperationMode.PLATINUM.description())
|
||||
.build();
|
||||
}
|
||||
|
||||
private String getDependentVariable() {
|
||||
if (analytics.getAnalysis() instanceof Classification) {
|
||||
return ((Classification)analytics.getAnalysis()).getDependentVariable();
|
||||
}
|
||||
if (analytics.getAnalysis() instanceof Regression) {
|
||||
return ((Regression)analytics.getAnalysis()).getDependentVariable();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
ActionListener<Boolean> storeListener = ActionListener.wrap(
|
||||
|
|
|
@ -28,6 +28,8 @@ import org.elasticsearch.ingest.PipelineConfiguration;
|
|||
import org.elasticsearch.ingest.Processor;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
|
@ -37,7 +39,9 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
|||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.Consumer;
|
||||
|
@ -146,7 +150,12 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
if (response.getInferenceResults().isEmpty()) {
|
||||
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
|
||||
InferenceResults inferenceResults = response.getInferenceResults().get(0);
|
||||
if (inferenceResults instanceof WarningInferenceResults) {
|
||||
inferenceResults.writeResult(ingestDocument, this.targetField);
|
||||
} else {
|
||||
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
|
||||
}
|
||||
ingestDocument.setFieldValue(targetField + "." + MODEL_ID, modelId);
|
||||
}
|
||||
|
||||
|
@ -164,6 +173,10 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
|
||||
private static final Logger logger = LogManager.getLogger(Factory.class);
|
||||
|
||||
private static final Set<String> RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList(
|
||||
WarningInferenceResults.WARNING.getPreferredName(),
|
||||
MODEL_ID));
|
||||
|
||||
private final Client client;
|
||||
private final IngestService ingestService;
|
||||
private final InferenceAuditor auditor;
|
||||
|
@ -235,6 +248,7 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField);
|
||||
Map<String, String> fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
|
||||
InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));
|
||||
|
||||
return new InferenceProcessor(client,
|
||||
auditor,
|
||||
tag,
|
||||
|
@ -252,7 +266,6 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
|
||||
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
|
||||
ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
|
||||
|
||||
if (inferenceConfig.size() != 1) {
|
||||
throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
|
||||
INFERENCE_CONFIG);
|
||||
|
@ -268,10 +281,14 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
|
||||
if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
|
||||
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
|
||||
return ClassificationConfig.fromMap(valueMap);
|
||||
ClassificationConfig config = ClassificationConfig.fromMap(valueMap);
|
||||
checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
|
||||
return config;
|
||||
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
|
||||
checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
|
||||
return RegressionConfig.fromMap(valueMap);
|
||||
RegressionConfig config = RegressionConfig.fromMap(valueMap);
|
||||
checkFieldUniqueness(config.getResultsField());
|
||||
return config;
|
||||
} else {
|
||||
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
|
||||
inferenceConfig.keySet(),
|
||||
|
@ -279,6 +296,23 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
private static void checkFieldUniqueness(String... fieldNames) {
|
||||
Set<String> duplicatedFieldNames = new HashSet<>();
|
||||
Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
|
||||
for(String fieldName : fieldNames) {
|
||||
if (currentFieldNames.contains(fieldName)) {
|
||||
duplicatedFieldNames.add(fieldName);
|
||||
} else {
|
||||
currentFieldNames.add(fieldName);
|
||||
}
|
||||
}
|
||||
if (duplicatedFieldNames.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException("Cannot create processor as configured." +
|
||||
" More than one field is configured as {}",
|
||||
duplicatedFieldNames);
|
||||
}
|
||||
}
|
||||
|
||||
void checkSupportedVersion(InferenceConfig config) {
|
||||
if (config.getMinimalSupportedVersion().after(minNodeVersion)) {
|
||||
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION,
|
||||
|
@ -287,6 +321,5 @@ public class InferenceProcessor extends AbstractProcessor {
|
|||
minNodeVersion));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,23 +6,33 @@
|
|||
package org.elasticsearch.xpack.ml.inference.loadingservice;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
|
||||
|
||||
public class LocalModel implements Model {
|
||||
|
||||
private final TrainedModelDefinition trainedModelDefinition;
|
||||
private final String modelId;
|
||||
private final Set<String> fieldNames;
|
||||
|
||||
public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) {
|
||||
public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition, TrainedModelInput input) {
|
||||
this.trainedModelDefinition = trainedModelDefinition;
|
||||
this.modelId = modelId;
|
||||
this.fieldNames = new HashSet<>(input.getFieldNames());
|
||||
}
|
||||
|
||||
long ramBytesUsed() {
|
||||
|
@ -51,6 +61,11 @@ public class LocalModel implements Model {
|
|||
@Override
|
||||
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
|
||||
try {
|
||||
if (Sets.haveEmptyIntersection(fieldNames, fields.keySet())) {
|
||||
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
|
||||
return;
|
||||
}
|
||||
|
||||
listener.onResponse(trainedModelDefinition.infer(fields, config));
|
||||
} catch (Exception e) {
|
||||
listener.onFailure(e);
|
||||
|
|
|
@ -141,7 +141,8 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
trainedModelConfig ->
|
||||
modelActionListener.onResponse(new LocalModel(
|
||||
trainedModelConfig.getModelId(),
|
||||
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition())),
|
||||
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
|
||||
trainedModelConfig.getInput())),
|
||||
modelActionListener::onFailure
|
||||
));
|
||||
} else {
|
||||
|
@ -198,7 +199,8 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
Queue<ActionListener<Model>> listeners;
|
||||
LocalModel loadedModel = new LocalModel(
|
||||
trainedModelConfig.getModelId(),
|
||||
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition());
|
||||
trainedModelConfig.ensureParsedDefinition(namedXContentRegistry).getModelDefinition(),
|
||||
trainedModelConfig.getInput());
|
||||
synchronized (loadingListeners) {
|
||||
listeners = loadingListeners.remove(modelId);
|
||||
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such
|
||||
|
|
|
@ -171,7 +171,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
|||
assertThat(storedModel.getTags(), contains(JOB_ID));
|
||||
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
|
||||
assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build()));
|
||||
assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames));
|
||||
assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar", "baz")));
|
||||
assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
|
||||
assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
|
||||
Map<String, Object> metadata = storedModel.getMetadata();
|
||||
|
|
|
@ -240,6 +240,29 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testCreateProcessorWithDuplicateFields() {
|
||||
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
|
||||
clusterService,
|
||||
Settings.EMPTY,
|
||||
ingestService);
|
||||
|
||||
Map<String, Object> regression = new HashMap<String, Object>() {{
|
||||
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
|
||||
put(InferenceProcessor.MODEL_ID, "my_model");
|
||||
put(InferenceProcessor.TARGET_FIELD, "ml");
|
||||
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME,
|
||||
Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")));
|
||||
}};
|
||||
|
||||
try {
|
||||
processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression);
|
||||
fail("should not have succeeded creating with duplicate fields");
|
||||
} catch (Exception ex) {
|
||||
assertThat(ex.getMessage(), equalTo("Cannot create processor as configured. " +
|
||||
"More than one field is configured as [warning]"));
|
||||
}
|
||||
}
|
||||
|
||||
private static ClusterState buildClusterState(MetaData metaData) {
|
||||
return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build();
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
|
||||
|
@ -253,4 +254,26 @@ public class InferenceProcessorTests extends ESTestCase {
|
|||
verify(auditor, times(1)).warning(eq("regression_model"), any(String.class));
|
||||
}
|
||||
|
||||
public void testMutateDocumentWithWarningResult() {
|
||||
String targetField = "regression_value";
|
||||
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
|
||||
auditor,
|
||||
"my_processor",
|
||||
"ml",
|
||||
"regression_model",
|
||||
RegressionConfig.EMPTY_PARAMS,
|
||||
Collections.emptyMap());
|
||||
|
||||
Map<String, Object> source = new HashMap<>();
|
||||
Map<String, Object> ingestMetadata = new HashMap<>();
|
||||
IngestDocument document = new IngestDocument(source, ingestMetadata);
|
||||
|
||||
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
|
||||
Collections.singletonList(new WarningInferenceResults("something broke")), true);
|
||||
inferenceProcessor.mutateDocument(response, document);
|
||||
|
||||
assertThat(document.hasField(targetField), is(false));
|
||||
assertThat(document.hasField("ml.warning"), is(true));
|
||||
assertThat(document.hasField("ml.my_processor"), is(false));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,8 +8,10 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
|
|||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
|
@ -22,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
|
|||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
|
@ -38,12 +41,13 @@ public class LocalModelTests extends ESTestCase {
|
|||
|
||||
public void testClassificationInfer() throws Exception {
|
||||
String modelId = "classification_model";
|
||||
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
|
||||
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
.setTrainedModel(buildClassification(false))
|
||||
.build();
|
||||
|
||||
Model model = new LocalModel(modelId, definition);
|
||||
Model model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields));
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
put("foo", 1.0);
|
||||
put("bar", 0.5);
|
||||
|
@ -64,7 +68,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
.setTrainedModel(buildClassification(true))
|
||||
.build();
|
||||
model = new LocalModel(modelId, definition);
|
||||
model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields));
|
||||
result = getSingleValue(model, fields, new ClassificationConfig(0));
|
||||
assertThat(result.value(), equalTo(0.0));
|
||||
assertThat(result.valueAsString(), equalTo("not_to_be"));
|
||||
|
@ -81,11 +85,12 @@ public class LocalModelTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testRegression() throws Exception {
|
||||
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
|
||||
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
.setTrainedModel(buildRegression())
|
||||
.build();
|
||||
Model model = new LocalModel("regression_model", trainedModelDefinition);
|
||||
Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields));
|
||||
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
put("foo", 1.0);
|
||||
|
@ -103,16 +108,39 @@ public class LocalModelTests extends ESTestCase {
|
|||
equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]"));
|
||||
}
|
||||
|
||||
public void testAllFieldsMissing() throws Exception {
|
||||
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
|
||||
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
.setTrainedModel(buildRegression())
|
||||
.build();
|
||||
Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields));
|
||||
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
put("something", 1.0);
|
||||
put("other", 0.5);
|
||||
put("baz", "dog");
|
||||
}};
|
||||
|
||||
WarningInferenceResults results = (WarningInferenceResults)getInferenceResult(model, fields, RegressionConfig.EMPTY_PARAMS);
|
||||
assertThat(results.getWarning(),
|
||||
equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, "regression_model")));
|
||||
}
|
||||
|
||||
private static SingleValueInferenceResults getSingleValue(Model model,
|
||||
Map<String, Object> fields,
|
||||
InferenceConfig config) throws Exception {
|
||||
return (SingleValueInferenceResults)getInferenceResult(model, fields, config);
|
||||
}
|
||||
|
||||
private static InferenceResults getInferenceResult(Model model, Map<String, Object> fields, InferenceConfig config) throws Exception {
|
||||
PlainActionFuture<InferenceResults> future = new PlainActionFuture<>();
|
||||
model.infer(fields, config, future);
|
||||
return (SingleValueInferenceResults)future.get();
|
||||
return future.get();
|
||||
}
|
||||
|
||||
private static Map<String, String> oneHotMap() {
|
||||
Map<String, String> oneHotEncoding = new HashMap<>();
|
||||
Map<String, String> oneHotEncoding = new HashMap<String, String>();
|
||||
oneHotEncoding.put("cat", "animal_cat");
|
||||
oneHotEncoding.put("dog", "animal_dog");
|
||||
return oneHotEncoding;
|
||||
|
|
|
@ -35,6 +35,7 @@ import org.elasticsearch.threadpool.TestThreadPool;
|
|||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
|
@ -45,6 +46,7 @@ import org.mockito.Mockito;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.net.InetAddress;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -308,6 +310,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
when(definition.ramBytesUsed()).thenReturn(size);
|
||||
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
|
||||
when(trainedModelConfig.getModelDefinition()).thenReturn(definition);
|
||||
when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz")));
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
||||
|
|
|
@ -17,7 +17,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
|||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
|
@ -39,6 +41,7 @@ import java.util.stream.Collectors;
|
|||
|
||||
import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildClassification;
|
||||
import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildRegression;
|
||||
import static org.hamcrest.CoreMatchers.instanceOf;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
@ -63,7 +66,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
oneHotEncoding.put("cat", "animal_cat");
|
||||
oneHotEncoding.put("dog", "animal_dog");
|
||||
TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
|
||||
.setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical")))
|
||||
.setParsedDefinition(new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
|
||||
.setTrainedModel(buildClassification(true)))
|
||||
|
@ -74,7 +77,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
.setEstimatedHeapMemory(0)
|
||||
.build();
|
||||
TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
|
||||
.setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical")))
|
||||
.setParsedDefinition(new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
|
||||
.setTrainedModel(buildRegression()))
|
||||
|
@ -184,6 +187,51 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testInferMissingFields() throws Exception {
|
||||
String modelId = "test-load-models-regression-missing-fields";
|
||||
Map<String, String> oneHotEncoding = new HashMap<>();
|
||||
oneHotEncoding.put("cat", "animal_cat");
|
||||
oneHotEncoding.put("dog", "animal_dog");
|
||||
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
|
||||
.setParsedDefinition(new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
|
||||
.setTrainedModel(buildRegression()))
|
||||
.setVersion(Version.CURRENT)
|
||||
.setEstimatedOperations(0)
|
||||
.setEstimatedHeapMemory(0)
|
||||
.setCreateTime(Instant.now())
|
||||
.build();
|
||||
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
|
||||
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
||||
|
||||
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
|
||||
assertThat(putConfigHolder.get(), is(true));
|
||||
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||
|
||||
|
||||
List<Map<String, Object>> toInferMissingField = new ArrayList<>();
|
||||
toInferMissingField.add(new HashMap<String, Object>() {{
|
||||
put("foo", 1.0);
|
||||
put("bar", 0.5);
|
||||
}});
|
||||
|
||||
InternalInferModelAction.Request request = new InternalInferModelAction.Request(
|
||||
modelId,
|
||||
toInferMissingField,
|
||||
RegressionConfig.EMPTY_PARAMS,
|
||||
true);
|
||||
try {
|
||||
InferenceResults result =
|
||||
client().execute(InternalInferModelAction.INSTANCE, request).actionGet().getInferenceResults().get(0);
|
||||
assertThat(result, is(instanceOf(WarningInferenceResults.class)));
|
||||
assertThat(((WarningInferenceResults)result).getWarning(),
|
||||
equalTo(Messages.getMessage(Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
|
||||
} catch (ElasticsearchException ex) {
|
||||
fail("Should not have thrown. Ex: " + ex.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
|
||||
return TrainedModelConfig.builder()
|
||||
.setCreatedBy("ml_test")
|
||||
|
|
Loading…
Reference in New Issue