[ML][Inference] adjust so target_field always has inference result and optionally allow new top classes field in the classification config (#49923) (#49982)

This commit is contained in:
Benjamin Trent 2019-12-09 08:29:45 -05:00 committed by GitHub
parent 62e128f02d
commit 049d854360
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 102 additions and 50 deletions

View File

@ -12,6 +12,8 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -90,13 +92,15 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
}
@Override
public void writeResult(IngestDocument document, String resultField) {
public void writeResult(IngestDocument document, String resultField, InferenceConfig config) {
assert config instanceof ClassificationConfig;
ClassificationConfig classificationConfig = (ClassificationConfig)config;
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(resultField, "resultField");
if (topClasses.isEmpty()) {
document.setFieldValue(resultField, valueAsString());
} else {
document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
document.setFieldValue(resultField, valueAsString());
if (topClasses.isEmpty() == false) {
document.setFieldValue(classificationConfig.getTopClassesResultsField(),
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
}

View File

@ -7,10 +7,11 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
public interface InferenceResults extends NamedXContentObject, NamedWriteable {
void writeResult(IngestDocument document, String resultField);
void writeResult(IngestDocument document, String resultField, InferenceConfig config);
}

View File

@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import java.io.IOException;
import java.util.Objects;
@ -49,7 +50,7 @@ public class RawInferenceResults extends SingleValueInferenceResults {
}
@Override
public void writeResult(IngestDocument document, String resultField) {
public void writeResult(IngestDocument document, String resultField, InferenceConfig config) {
throw new UnsupportedOperationException("[raw] does not support writing inference results");
}

View File

@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -50,7 +51,7 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
}
@Override
public void writeResult(IngestDocument document, String resultField) {
public void writeResult(IngestDocument document, String resultField, InferenceConfig config) {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(resultField, "resultField");
document.setFieldValue(resultField, value());

View File

@ -21,37 +21,52 @@ public class ClassificationConfig implements InferenceConfig {
public static final String NAME = "classification";
public static final String DEFAULT_TOP_CLASSES_RESULT_FIELD = "top_classes";
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TOP_CLASSES_RESULT_FIELD = new ParseField("top_classes_result_field");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0);
public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0, DEFAULT_TOP_CLASSES_RESULT_FIELD);
private final int numTopClasses;
private final String topClassesResultsField;
public static ClassificationConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULT_FIELD.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new ClassificationConfig(numTopClasses);
return new ClassificationConfig(numTopClasses, topClassesResultsField);
}
public ClassificationConfig(Integer numTopClasses) {
this(numTopClasses, null);
}
public ClassificationConfig(Integer numTopClasses, String topClassesResultsField) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULT_FIELD : topClassesResultsField;
}
public ClassificationConfig(StreamInput in) throws IOException {
this.numTopClasses = in.readInt();
this.topClassesResultsField = in.readString();
}
public int getNumTopClasses() {
return numTopClasses;
}
public String getTopClassesResultsField() {
return topClassesResultsField;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(numTopClasses);
out.writeString(topClassesResultsField);
}
@Override
@ -59,12 +74,12 @@ public class ClassificationConfig implements InferenceConfig {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassificationConfig that = (ClassificationConfig) o;
return Objects.equals(numTopClasses, that.numTopClasses);
return Objects.equals(numTopClasses, that.numTopClasses) && Objects.equals(topClassesResultsField, that.topClassesResultsField);
}
@Override
public int hashCode() {
return Objects.hash(numTopClasses);
return Objects.hash(numTopClasses, topClassesResultsField);
}
@Override
@ -73,6 +88,7 @@ public class ClassificationConfig implements InferenceConfig {
if (numTopClasses != 0) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
builder.field(TOP_CLASSES_RESULT_FIELD.getPreferredName(), topClassesResultsField);
builder.endObject();
return builder;
}

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import java.util.Arrays;
import java.util.Collections;
@ -37,7 +38,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
public void testWriteResultsWithClassificationLabel() {
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, "foo", Collections.emptyList());
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
result.writeResult(document, "result_field", ClassificationConfig.EMPTY_PARAMS);
assertThat(document.getFieldValue("result_field", String.class), equalTo("foo"));
}
@ -45,7 +46,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
public void testWriteResultsWithoutClassificationLabel() {
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, Collections.emptyList());
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
result.writeResult(document, "result_field", ClassificationConfig.EMPTY_PARAMS);
assertThat(document.getFieldValue("result_field", String.class), equalTo("1.0"));
}
@ -60,15 +61,17 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
"foo",
entries);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
result.writeResult(document, "result_field", new ClassificationConfig(3, "bar"));
List<?> list = document.getFieldValue("result_field", List.class);
List<?> list = document.getFieldValue("bar", List.class);
assertThat(list.size(), equalTo(3));
for(int i = 0; i < 3; i++) {
Map<String, Object> map = (Map<String, Object>)list.get(i);
assertThat(map, equalTo(entries.get(i).asValueMap()));
}
assertThat(document.getFieldValue("result_field", String.class), equalTo("foo"));
}
@Override

View File

@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import java.util.HashMap;
@ -24,7 +25,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
public void testWriteResults() {
RegressionInferenceResults result = new RegressionInferenceResults(0.3);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
result.writeResult(document, "result_field", new RegressionConfig());
assertThat(document.getFieldValue("result_field", Double.class), equalTo(0.3));
}

View File

@ -10,22 +10,27 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
public class ClassificationConfigTests extends AbstractWireSerializingTestCase<ClassificationConfig> {
public static ClassificationConfig randomClassificationConfig() {
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10));
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10),
randomBoolean() ? null : randomAlphaOfLength(10));
}
public void testFromMap() {
ClassificationConfig expected = new ClassificationConfig(0);
ClassificationConfig expected = ClassificationConfig.EMPTY_PARAMS;
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));
expected = new ClassificationConfig(3);
assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)),
equalTo(expected));
expected = new ClassificationConfig(3, "foo");
Map<String, Object> configMap = new HashMap<>();
configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
configMap.put(ClassificationConfig.TOP_CLASSES_RESULT_FIELD.getPreferredName(), "foo");
assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected));
}
public void testFromMapWithUnknownField() {

View File

@ -148,20 +148,8 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"result_class\",\n" +
" \"inference_config\": {\"classification\":{}},\n" +
" \"model_id\": \"test_classification\",\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"result_class_prob\",\n" +
" \"inference_config\": {\"classification\": {\"num_top_classes\":2}},\n" +
" \"inference_config\": {\"classification\": " +
" {\"num_top_classes\":2, \"top_classes_result_field\": \"result_class_prob\"}},\n" +
" \"model_id\": \"test_classification\",\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +

View File

@ -35,7 +35,6 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@ -154,7 +153,7 @@ public class InferenceProcessor extends AbstractProcessor {
if (response.getInferenceResults().isEmpty()) {
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
}
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField, inferenceConfig);
if (includeModelMetadata) {
ingestDocument.setFieldValue(modelInfoField + "." + MODEL_ID, modelId);
}
@ -227,8 +226,7 @@ public class InferenceProcessor extends AbstractProcessor {
}
@Override
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config)
throws Exception {
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config) {
if (this.maxIngestProcessors <= currentInferenceProcessors) {
throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " +
@ -267,7 +265,7 @@ public class InferenceProcessor extends AbstractProcessor {
this.maxIngestProcessors = maxIngestProcessors;
}
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) throws IOException {
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
if (inferenceConfig.size() != 1) {
@ -284,7 +282,7 @@ public class InferenceProcessor extends AbstractProcessor {
Map<String, Object> valueMap = (Map<String, Object>)value;
if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
checkSupportedVersion(new ClassificationConfig(0));
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
return ClassificationConfig.fromMap(valueMap);
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
checkSupportedVersion(new RegressionConfig());

View File

@ -51,7 +51,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
targetField,
"classification_model",
new ClassificationConfig(0),
ClassificationConfig.EMPTY_PARAMS,
Collections.emptyMap(),
"ml.my_processor",
true);
@ -78,7 +78,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
targetField,
"classification_model",
new ClassificationConfig(2),
new ClassificationConfig(2, null),
Collections.emptyMap(),
"ml.my_processor",
true);
@ -96,10 +96,44 @@ public class InferenceProcessorTests extends ESTestCase {
true);
inferenceProcessor.mutateDocument(response, document);
assertThat((List<Map<?,?>>)document.getFieldValue(targetField, List.class),
assertThat((List<Map<?,?>>)document.getFieldValue(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULT_FIELD, List.class),
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
assertThat(document.getFieldValue("ml", Map.class),
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
}
@SuppressWarnings("unchecked")
public void testMutateDocumentClassificationTopNClassesWithSpecificField() {
String targetField = "classification_value_probabilities";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
auditor,
"my_processor",
targetField,
"classification_model",
new ClassificationConfig(2, "my_top_classes"),
Collections.emptyMap(),
"ml.my_processor",
true);
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes)),
true);
inferenceProcessor.mutateDocument(response, document);
assertThat((List<Map<?,?>>)document.getFieldValue("my_top_classes", List.class),
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
assertThat(document.getFieldValue("ml", Map.class),
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
}
public void testMutateDocumentRegression() {
@ -194,7 +228,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
"my_field",
modelId,
new ClassificationConfig(topNClasses),
new ClassificationConfig(topNClasses, null),
Collections.emptyMap(),
"ml.my_processor",
false);
@ -225,7 +259,7 @@ public class InferenceProcessorTests extends ESTestCase {
"my_processor",
"my_field",
modelId,
new ClassificationConfig(topNClasses),
new ClassificationConfig(topNClasses, null),
fieldMapping,
"ml.my_processor",
false);

View File

@ -131,7 +131,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
// Test classification
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(0), true);
request = new InternalInferModelAction.Request(modelId2, toInfer, ClassificationConfig.EMPTY_PARAMS, true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults()
.stream()
@ -140,7 +140,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
contains("not_to_be", "to_be"));
// Get top classes
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2), true);
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2, null), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
ClassificationInferenceResults classificationInferenceResults =
@ -159,7 +159,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
// Test that top classes restrict the number returned
request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1), true);
request = new InternalInferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1, null), true);
response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);