[ML][Inference] fix support for nested fields (#50258) (#50335)

This fixes support for nested fields

We now support fully nested, fully collapsed, or a mix of both on inference docs.

ES mappings allow the `_source` to be any combination of nested objects + dot delimited fields.
So, we should do our best to find the correct path through the Map to the desired field.
This commit is contained in:
Benjamin Trent 2019-12-18 15:47:06 -05:00 committed by GitHub
parent 06a24f09cf
commit 4396a1f78b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 582 additions and 33 deletions

View File

@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.io.IOException;
import java.util.Collections;
@ -103,7 +104,7 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
@Override
public void process(Map<String, Object> fields) {
Object value = fields.get(field);
Object value = MapHelper.dig(field, fields);
if (value == null) {
return;
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.io.IOException;
import java.util.Collections;
@ -86,7 +87,7 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
@Override
public void process(Map<String, Object> fields) {
Object value = fields.get(field);
Object value = MapHelper.dig(field, fields);
if (value == null) {
return;
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.io.IOException;
import java.util.Collections;
@ -114,7 +115,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
@Override
public void process(Map<String, Object> fields) {
Object value = fields.get(field);
Object value = MapHelper.dig(field, fields);
if (value == null) {
return;
}

View File

@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfi
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.io.IOException;
import java.util.ArrayDeque;
@ -129,7 +130,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
List<Double> features = featureNames.stream().map(f -> InferenceHelpers.toDouble(fields.get(f))).collect(Collectors.toList());
List<Double> features = featureNames.stream()
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
.collect(Collectors.toList());
return infer(features, config);
}

View File

@ -0,0 +1,133 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.utils;
import org.elasticsearch.common.Nullable;

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.Map;
import java.util.Stack;
public final class MapHelper {

    private MapHelper() {}

    /**
     * This eagerly digs (depth first search, longer keys first) through the map by tokenizing the provided path on '.'.
     *
     * It is possible for ES _source docs to have "mixed" path formats. So, we should search all potential paths
     * given the current knowledge of the map.
     *
     * Examples:
     *
     * The following maps would return {@code 2} given the path "a.b.c.d"
     *
     * {
     *   "a.b.c.d" : 2
     * }
     * {
     *   "a" : {"b": {"c": {"d" : 2}}}
     * }
     * {
     *   "a" : {"b.c": {"d" : 2}}
     * }
     * {
     *   "a" : {"b": {"c": {"d" : 2}}},
     *   "a.b" : {"c": {"d" : 5}} // we choose the first one found, we go down longer keys first
     * }
     * {
     *   "a" : {"b": {"c": {"NOT_d" : 2, "d": 2}}}
     * }
     *
     * Conceptual "worst case": 5 potential paths are explored for "a.b.c.d" until 2 is finally returned
     * {
     *   "a.b.c": {"not_d": 2},
     *   "a.b": {"c": {"not_d": 2}},
     *   "a": {"b.c": {"not_d": 2}},
     *   "a": {"b" : {"c.not_d": 2}},
     *   "a" : {"b": {"c": {"not_d" : 2}}},
     *   "a" : {"b": {"c": {"d" : 2}}},
     * }
     *
     * We don't exhaustively create all potential paths.
     * If we did, this would result in 2^n-1 total possible paths, where {@code n = path.split("\\.").length}.
     *
     * Instead we lazily create potential paths once we know that they are possibilities.
     *
     * @param path Dot delimited path containing the field desired
     * @param map The {@link Map} map to dig
     * @return The found object, or {@code null} if nothing is found along any candidate path
     * @throws IllegalArgumentException if the path contains an empty token (e.g. "a..b" or ".a")
     */
    public static Object dig(String path, Map<String, Object> map) {
        // Short circuit before any tokenization: the fully collapsed key exists verbatim at the top level.
        if (map.containsKey(path)) {
            return map.get(path);
        }
        String[] fields = path.split("\\.");
        if (Arrays.stream(fields).anyMatch(String::isEmpty)) {
            throw new IllegalArgumentException("Empty path detected. Invalid field name");
        }
        // ArrayDeque is preferred over the legacy, synchronized java.util.Stack.
        Deque<PotentialPath> pathStack = new ArrayDeque<>();
        pathStack.push(new PotentialPath(map, 0));
        return explore(fields, pathStack);
    }

    /**
     * Depth first search over the lazily discovered candidate sub-maps.
     * Candidates are pushed while scanning merged keys of increasing length, so the
     * LIFO pop order explores the longest matching key first.
     */
    @SuppressWarnings("unchecked")
    private static Object explore(String[] path, Deque<PotentialPath> pathStack) {
        while (pathStack.isEmpty() == false) {
            PotentialPath potentialPath = pathStack.pop();
            int startPos = potentialPath.pathPosition;
            int endPos = startPos + 1;
            Map<String, Object> map = potentialPath.map;
            String candidateKey = null;
            while (endPos <= path.length) {
                candidateKey = mergePath(path, startPos, endPos);
                Object next = map.get(candidateKey);
                if (endPos == path.length && next != null) { // exit early, we reached the full path and found something
                    return next;
                }
                if (next instanceof Map<?, ?>) { // we found another map, continue exploring down this path
                    pathStack.push(new PotentialPath((Map<String, Object>) next, endPos));
                }
                endPos++;
            }
            // candidateKey now spans the whole remaining path; if the key exists (even mapped
            // to null) we honor it and stop searching, matching the original early-exit.
            if (candidateKey != null && map.containsKey(candidateKey)) { // exit early
                return map.get(candidateKey);
            }
        }
        return null; // exhausted every candidate path
    }

    /**
     * Joins the path tokens in [start, end) with '.' (e.g. ["a","b","c"], 0, 2 -> "a.b").
     */
    private static String mergePath(String[] path, int start, int end) {
        if (start + 1 == end) { // early exit, no need to create a builder for a single token
            return path[start];
        }
        StringBuilder sb = new StringBuilder();
        for (int i = start; i < end - 1; i++) {
            sb.append(path[i]).append('.');
        }
        return sb.append(path[end - 1]).toString();
    }

    /**
     * A pointer into the search: the sub-map still to explore and the path token index
     * at which its exploration starts.
     */
    private static class PotentialPath {

        // Map to start exploring from
        private final Map<String, Object> map;

        // Index into the tokenized path where exploration of this map begins
        private final int pathPosition;

        private PotentialPath(Map<String, Object> map, int pathPosition) {
            this.map = map;
            this.pathPosition = pathPosition;
        }
    }
}

View File

@ -65,4 +65,22 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
testProcess(encoding, fieldValues, matchers);
}
/**
 * Frequency encoding must resolve the dot delimited field name
 * "categorical.child" against a nested document structure.
 */
public void testProcessWithNestedField() {
    Map<String, Double> frequencies = new HashMap<>();
    for (Object value : Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5)) {
        frequencies.put(value.toString(), randomDoubleBetween(0.0, 1.0, false));
    }
    FrequencyEncoding encoding = new FrequencyEncoding("categorical.child", "encoded", frequencies);
    // Build { "categorical": { "child": "farequote" } } without double-brace init.
    Map<String, Object> child = new HashMap<>();
    child.put("child", "farequote");
    Map<String, Object> document = new HashMap<>();
    document.put("categorical", child);
    encoding.process(document);
    assertThat(document.get("encoded"), equalTo(frequencies.get("farequote")));
}
}

View File

@ -67,4 +67,19 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
testProcess(encoding, fieldValues, matchers);
}
/**
 * One-hot encoding must resolve the dot delimited field name
 * "categorical.child" against a nested document structure.
 */
public void testProcessWithNestedField() {
    Map<String, String> hotColumns = new HashMap<>();
    for (Object value : Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5)) {
        hotColumns.put(value.toString(), "Column_" + value.toString());
    }
    OneHotEncoding encoding = new OneHotEncoding("categorical.child", hotColumns);
    // Build { "categorical": { "child": "farequote" } } without double-brace init.
    Map<String, Object> child = new HashMap<>();
    child.put("child", "farequote");
    Map<String, Object> document = new HashMap<>();
    document.put("categorical", child);
    encoding.process(document);
    assertThat(document.get("Column_farequote"), equalTo(1));
}
}

View File

@ -68,4 +68,24 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
testProcess(encoding, fieldValues, matchers);
}
/**
 * Target-mean encoding must resolve the dot delimited field name
 * "categorical.child" against a nested document structure.
 */
public void testProcessWithNestedField() {
    Map<String, Double> means = new HashMap<>();
    // Same random-call order as before: six randomDoubleBetween calls, then one randomDouble.
    for (Object value : Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5)) {
        means.put(value.toString(), randomDoubleBetween(0.0, 1.0, false));
    }
    TargetMeanEncoding encoding = new TargetMeanEncoding("categorical.child", "encoded", means, randomDouble());
    // Build { "categorical": { "child": "farequote" } } without double-brace init.
    Map<String, Object> child = new HashMap<>();
    child.put("child", "farequote");
    Map<String, Object> document = new HashMap<>();
    document.put("categorical", child);
    encoding.process(document);
    assertThat(document.get("encoded"), equalTo(means.get("farequote")));
}
}

View File

@ -445,6 +445,63 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
}
/**
 * Ensemble inference should resolve the dot delimited feature names
 * ("foo.baz", "bar.biz") against nested maps in the feature document.
 */
public void testInferNestedFields() {
List<String> featureNames = Arrays.asList("foo.baz", "bar.biz");
// tree1: splits on feature 0 at 0.5; right branch is leaf 0.3, left branch
// splits again on feature 1 at 0.8 into leaves 0.1 / 0.2.
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.3))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.1))
.addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
// tree2: single split on feature 0 at 0.5 into leaves 1.5 / 0.9.
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.5))
.addNode(TreeNode.builder(2).setLeafValue(0.9))
.build();
// Equal-weight sum of the two trees' outputs.
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2))
.setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5}))
.build();
// foo.baz = 0.4 (< 0.5) -> tree1 leaf 0.3, tree2 leaf 1.5; 0.5*0.3 + 0.5*1.5 = 0.9
Map<String, Object> featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
put("baz", 0.4);
}});
put("bar", new HashMap<String, Object>(){{
put("biz", 0.0);
}});
}};
assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
// foo.baz = 2.0 (>= 0.5), bar.biz = 0.7 (< 0.8) -> tree1 leaf 0.1, tree2 leaf 0.9;
// 0.5*0.1 + 0.5*0.9 = 0.5
featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
put("baz", 2.0);
}});
put("bar", new HashMap<String, Object>(){{
put("biz", 0.7);
}});
}};
assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
}
public void testOperationsEstimations() {
Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2);
Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);

View File

@ -169,6 +169,58 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
}
/**
 * Tree inference should resolve the dot delimited feature names
 * ("foo.baz", "bar.biz") against nested maps in the feature document.
 */
public void testInferNestedFields() {
// Build a tree with 2 nodes and 3 leaves using 2 features
// The leaves have unique values 0.1, 0.2, 0.3
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
// Root splits on feature 0 at threshold 0.5 (left when below).
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
builder.addLeaf(rootNode.getRightChild(), 0.3);
// Left child splits on feature 1 at threshold 0.8.
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
builder.addLeaf(leftChildNode.getRightChild(), 0.2);
List<String> featureNames = Arrays.asList("foo.baz", "bar.biz");
Tree tree = builder.setFeatureNames(featureNames).build();
// This feature vector should hit the right child of the root node
Map<String, Object> featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
put("baz", 0.6);
}});
put("bar", new HashMap<String, Object>(){{
put("biz", 0.0);
}});
}};
assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
put("baz", 0.3);
}});
put("bar", new HashMap<String, Object>(){{
put("biz", 0.7);
}});
}};
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
featureMap = new HashMap<String, Object>() {{
put("foo", new HashMap<String, Object>(){{
put("baz", 0.3);
}});
put("bar", new HashMap<String, Object>(){{
put("biz", 0.9);
}});
}};
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
}
public void testTreeClassificationProbability() {
// Build a tree with 2 nodes and 3 leaves using 2 features
// The leaves have unique values 0.1, 0.2, 0.3

View File

@ -0,0 +1,193 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.utils;
import org.elasticsearch.test.ESTestCase;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
/**
 * Tests for {@link MapHelper#dig(String, Map)} covering fully collapsed,
 * fully nested, and mixed path formats, plus tie-breaking between
 * multiple candidate paths.
 *
 * NOTE: several tests rely on LinkedHashMap insertion order to make the
 * depth-first, longest-key-first search order deterministic.
 */
public class MapHelperTests extends ESTestCase {
// The fully collapsed path is present verbatim as a single key.
public void testAbsolutePathStringAsKey() {
String path = "a.b.c.d";
Map<String, Object> map = Collections.singletonMap(path, 2);
assertThat(MapHelper.dig(path, map), equalTo(2));
assertThat(MapHelper.dig(path, Collections.emptyMap()), is(nullValue()));
}
// Fully nested maps, one key per path token.
public void testSimplePath() {
String path = "a.b.c.d";
Map<String, Object> map = Collections.singletonMap("a",
Collections.singletonMap("b",
Collections.singletonMap("c",
Collections.singletonMap("d", 2))));
assertThat(MapHelper.dig(path, map), equalTo(2));
map = Collections.singletonMap("a",
Collections.singletonMap("b",
Collections.singletonMap("e", // Not part of path
Collections.singletonMap("d", 2))));
assertThat(MapHelper.dig(path, map), is(nullValue()));
}
// Digging may stop at an intermediate level and return a sub-map.
public void testSimplePathReturningMap() {
String path = "a.b.c";
Map<String, Object> map = Collections.singletonMap("a",
Collections.singletonMap("b",
Collections.singletonMap("c",
Collections.singletonMap("d", 2))));
assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));
}
// Every combination of nested keys and dot-delimited ("collapsed") keys.
public void testSimpleMixedPath() {
String path = "a.b.c.d";
Map<String, Object> map = Collections.singletonMap("a",
Collections.singletonMap("b.c",
Collections.singletonMap("d", 2)));
assertThat(MapHelper.dig(path, map), equalTo(2));
map = Collections.singletonMap("a.b",
Collections.singletonMap("c",
Collections.singletonMap("d", 2)));
assertThat(MapHelper.dig(path, map), equalTo(2));
map = Collections.singletonMap("a.b.c",
Collections.singletonMap("d", 2));
assertThat(MapHelper.dig(path, map), equalTo(2));
map = Collections.singletonMap("a",
Collections.singletonMap("b",
Collections.singletonMap("c.d", 2)));
assertThat(MapHelper.dig(path, map), equalTo(2));
map = Collections.singletonMap("a",
Collections.singletonMap("b.c.d", 2));
assertThat(MapHelper.dig(path, map), equalTo(2));
map = Collections.singletonMap("a.b",
Collections.singletonMap("c.d", 2));
assertThat(MapHelper.dig(path, map), equalTo(2));
// Keys that diverge from the requested path must yield null.
map = Collections.singletonMap("a",
Collections.singletonMap("b.foo",
Collections.singletonMap("d", 2)));
assertThat(MapHelper.dig(path, map), is(nullValue()));
map = Collections.singletonMap("a",
Collections.singletonMap("b.c",
Collections.singletonMap("foo", 2)));
assertThat(MapHelper.dig(path, map), is(nullValue()));
map = Collections.singletonMap("x",
Collections.singletonMap("b.c",
Collections.singletonMap("d", 2)));
assertThat(MapHelper.dig(path, map), is(nullValue()));
}
// Mixed-format paths that resolve to a sub-map rather than a leaf value.
public void testSimpleMixedPathReturningMap() {
String path = "a.b.c";
Map<String, Object> map = Collections.singletonMap("a",
Collections.singletonMap("b.c",
Collections.singletonMap("d", 2)));
assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));
map = Collections.singletonMap("a",
Collections.singletonMap("b.foo",
Collections.singletonMap("d", 2)));
assertThat(MapHelper.dig(path, map), is(nullValue()));
map = Collections.singletonMap("a",
Collections.singletonMap("b.not_c",
Collections.singletonMap("foo", 2)));
assertThat(MapHelper.dig(path, map), is(nullValue()));
map = Collections.singletonMap("x",
Collections.singletonMap("b.c",
Collections.singletonMap("d", 2)));
assertThat(MapHelper.dig(path, map), is(nullValue()));
}
// When several candidate paths exist, the search must find the one that
// actually completes the requested path, regardless of insertion order.
public void testMultiplePotentialPaths() {
String path = "a.b.c.d";
Map<String, Object> map = new LinkedHashMap<String, Object>() {{
put("a", Collections.singletonMap("b",
Collections.singletonMap("c",
Collections.singletonMap("not_d", 5))));
put("a.b", Collections.singletonMap("c", Collections.singletonMap("d", 2)));
}};
assertThat(MapHelper.dig(path, map), equalTo(2));
map = new LinkedHashMap<String, Object>() {{
put("a", Collections.singletonMap("b",
Collections.singletonMap("c",
Collections.singletonMap("d", 2))));
put("a.b", Collections.singletonMap("c", Collections.singletonMap("not_d", 5)));
}};
assertThat(MapHelper.dig(path, map), equalTo(2));
map = new LinkedHashMap<String, Object>() {{
put("a", Collections.singletonMap("b",
new HashMap<String, Object>() {{
put("c", Collections.singletonMap("not_d", 5));
put("c.d", 2);
}}));
}};
assertThat(MapHelper.dig(path, map), equalTo(2));
map = new LinkedHashMap<String, Object>() {{
put("a", Collections.singletonMap("b",
new HashMap<String, Object>() {{
put("c", Collections.singletonMap("d", 2));
put("c.not_d", 5);
}}));
}};
assertThat(MapHelper.dig(path, map), equalTo(2));
// No candidate path completes "a.b.c.d" -> null.
map = new LinkedHashMap<String, Object>() {{
put("a", Collections.singletonMap("b",
Collections.singletonMap("c",
Collections.singletonMap("not_d", 5))));
put("a.b", Collections.singletonMap("c", Collections.singletonMap("not_d", 2)));
}};
assertThat(MapHelper.dig(path, map), is(nullValue()));
}
// As above, but the successful path terminates in a sub-map.
public void testMultiplePotentialPathsReturningMap() {
String path = "a.b.c";
Map<String, Object> map = new LinkedHashMap<String, Object>() {{
put("a", Collections.singletonMap("b",
Collections.singletonMap("c",
Collections.singletonMap("d", 2))));
put("a.b", Collections.singletonMap("not_c", Collections.singletonMap("d", 2)));
}};
assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));
map = new LinkedHashMap<String, Object>() {{
put("a", Collections.singletonMap("b",
Collections.singletonMap("not_c",
Collections.singletonMap("d", 2))));
put("a.b", Collections.singletonMap("c", Collections.singletonMap("d", 2)));
}};
assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));
map = new LinkedHashMap<String, Object>() {{
put("a", Collections.singletonMap("b",
Collections.singletonMap("not_c",
Collections.singletonMap("d", 2))));
put("a.b", Collections.singletonMap("not_c", Collections.singletonMap("d", 2)));
}};
assertThat(MapHelper.dig(path, map), is(nullValue()));
}
}

View File

@ -35,6 +35,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.util.Arrays;
@ -128,7 +129,7 @@ public class InferenceProcessor extends AbstractProcessor {
Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
if (fieldMapping != null) {
fieldMapping.forEach((src, dest) -> {
Object srcValue = fields.remove(src);
Object srcValue = MapHelper.dig(src, fields);
if (srcValue != null) {
fields.put(dest, srcValue);
}

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@ -16,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import java.util.HashSet;
import java.util.Map;
@ -61,7 +61,7 @@ public class LocalModel implements Model {
@Override
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
try {
if (Sets.haveEmptyIntersection(fieldNames, fields.keySet())) {
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
return;
}

View File

@ -181,7 +181,7 @@ public class InferenceProcessorTests extends ESTestCase {
String modelId = "model";
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);
Map<String, String> fieldMapping = new HashMap<String, String>(3) {{
Map<String, String> fieldMapping = new HashMap<String, String>(5) {{
put("value1", "new_value1");
put("value2", "new_value2");
put("categorical", "new_categorical");
@ -195,7 +195,7 @@ public class InferenceProcessorTests extends ESTestCase {
new ClassificationConfig(topNClasses, null, null),
fieldMapping);
Map<String, Object> source = new HashMap<String, Object>(3){{
Map<String, Object> source = new HashMap<String, Object>(5){{
put("value1", 1);
put("categorical", "foo");
put("un_touched", "bar");
@ -203,8 +203,46 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
Map<String, Object> expectedMap = new HashMap<String, Object>(2) {{
Map<String, Object> expectedMap = new HashMap<String, Object>(7) {{
put("new_value1", 1);
put("value1", 1);
put("categorical", "foo");
put("new_categorical", "foo");
put("un_touched", "bar");
}};
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
}
public void testGenerateWithMappingNestedFields() {
String modelId = "model";
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);
Map<String, String> fieldMapping = new HashMap<String, String>(5) {{
put("value1.foo", "new_value1");
put("value2", "new_value2");
put("categorical.bar", "new_categorical");
}};
InferenceProcessor processor = new InferenceProcessor(client,
auditor,
"my_processor",
"my_field",
modelId,
new ClassificationConfig(topNClasses, null, null),
fieldMapping);
Map<String, Object> source = new HashMap<String, Object>(5){{
put("value1", Collections.singletonMap("foo", 1));
put("categorical.bar", "foo");
put("un_touched", "bar");
}};
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
Map<String, Object> expectedMap = new HashMap<String, Object>(7) {{
put("new_value1", 1);
put("value1", Collections.singletonMap("foo", 1));
put("categorical.bar", "foo");
put("new_categorical", "foo");
put("un_touched", "bar");
}};

View File

@ -41,7 +41,7 @@ public class LocalModelTests extends ESTestCase {
public void testClassificationInfer() throws Exception {
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(false))
@ -49,8 +49,8 @@ public class LocalModelTests extends ESTestCase {
Model model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("foo", 1.0);
put("bar", 0.5);
put("field.foo", 1.0);
put("field.bar", 0.5);
put("categorical", "dog");
}};
@ -93,8 +93,8 @@ public class LocalModelTests extends ESTestCase {
Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("foo", 1.0);
put("bar", 0.5);
put("field.foo", 1.0);
put("field.bar", 0.5);
put("categorical", "dog");
}};
@ -147,7 +147,7 @@ public class LocalModelTests extends ESTestCase {
}
public static TrainedModel buildClassification(boolean includeLabels) {
List<String> featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog");
List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
@ -193,7 +193,7 @@ public class LocalModelTests extends ESTestCase {
}
public static TrainedModel buildRegression() {
List<String> featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog");
List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)

View File

@ -66,9 +66,9 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
oneHotEncoding.put("cat", "animal_cat");
oneHotEncoding.put("dog", "animal_dog");
TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2)
.setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical")))
.setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
.setParsedDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
.setTrainedModel(buildClassification(true)))
.setVersion(Version.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
@ -77,9 +77,9 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
.setEstimatedHeapMemory(0)
.build();
TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1)
.setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical")))
.setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
.setParsedDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
.setTrainedModel(buildRegression()))
.setVersion(Version.CURRENT)
.setEstimatedOperations(0)
@ -99,26 +99,42 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
List<Map<String, Object>> toInfer = new ArrayList<>();
toInfer.add(new HashMap<String, Object>() {{
put("foo", 1.0);
put("bar", 0.5);
put("categorical", "dog");
put("field", new HashMap<String, Object>(){{
put("foo", 1.0);
put("bar", 0.5);
}});
put("other", new HashMap<String, String>(){{
put("categorical", "dog");
}});
}});
toInfer.add(new HashMap<String, Object>() {{
put("foo", 0.9);
put("bar", 1.5);
put("categorical", "cat");
put("field", new HashMap<String, Object>(){{
put("foo", 0.9);
put("bar", 1.5);
}});
put("other", new HashMap<String, String>(){{
put("categorical", "cat");
}});
}});
List<Map<String, Object>> toInfer2 = new ArrayList<>();
toInfer2.add(new HashMap<String, Object>() {{
put("foo", 0.0);
put("bar", 0.01);
put("categorical", "dog");
put("field", new HashMap<String, Object>(){{
put("foo", 0.0);
put("bar", 0.01);
}});
put("other", new HashMap<String, String>(){{
put("categorical", "dog");
}});
}});
toInfer2.add(new HashMap<String, Object>() {{
put("foo", 1.0);
put("bar", 0.0);
put("categorical", "cat");
put("field", new HashMap<String, Object>(){{
put("foo", 1.0);
put("bar", 0.0);
}});
put("other", new HashMap<String, String>(){{
put("categorical", "cat");
}});
}});
// Test regression