From 4396a1f78b8c9cc7ff054182b371cb1b1d224b36 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 18 Dec 2019 15:47:06 -0500 Subject: [PATCH] [ML][Inference] fix support for nested fields (#50258) (#50335) This fixes support for nested fields We now support fully nested, fully collapsed, or a mix of both on inference docs. ES mappings allow the `_source` to be any combination of nested objects + dot delimited fields. So, we should do our best to find the best path down the Map for the desired field. --- .../preprocessing/FrequencyEncoding.java | 3 +- .../preprocessing/OneHotEncoding.java | 3 +- .../preprocessing/TargetMeanEncoding.java | 3 +- .../ml/inference/trainedmodel/tree/Tree.java | 5 +- .../xpack/core/ml/utils/MapHelper.java | 133 ++++++++++++ .../preprocessing/FrequencyEncodingTests.java | 18 ++ .../preprocessing/OneHotEncodingTests.java | 15 ++ .../TargetMeanEncodingTests.java | 20 ++ .../trainedmodel/ensemble/EnsembleTests.java | 57 ++++++ .../trainedmodel/tree/TreeTests.java | 52 +++++ .../xpack/core/ml/utils/MapHelperTests.java | 193 ++++++++++++++++++ .../inference/ingest/InferenceProcessor.java | 3 +- .../inference/loadingservice/LocalModel.java | 4 +- .../ingest/InferenceProcessorTests.java | 44 +++- .../loadingservice/LocalModelTests.java | 14 +- .../integration/ModelInferenceActionIT.java | 48 +++-- 16 files changed, 582 insertions(+), 33 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MapHelper.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/MapHelperTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java index ed693460edc..e9606d53ae2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.MapHelper; import java.io.IOException; import java.util.Collections; @@ -103,7 +104,7 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP @Override public void process(Map fields) { - Object value = fields.get(field); + Object value = MapHelper.dig(field, fields); if (value == null) { return; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index a4924a277c0..9bb2537b61e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.MapHelper; import java.io.IOException; import java.util.Collections; @@ -86,7 +87,7 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars @Override public void process(Map fields) { - Object value = fields.get(field); + Object value = MapHelper.dig(field, fields); if (value == null) { return; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index 8276fc2c8fe..19c3cadbbef 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.MapHelper; import java.io.IOException; import java.util.Collections; @@ -114,7 +115,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly @Override public void process(Map fields) { - Object value = fields.get(field); + Object value = MapHelper.dig(field, fields); if (value == null) { return; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index b137c8f28d5..831838e0f7d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfi import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.MapHelper; import java.io.IOException; import java.util.ArrayDeque; @@ -129,7 +130,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); } - List features = featureNames.stream().map(f -> InferenceHelpers.toDouble(fields.get(f))).collect(Collectors.toList()); + List features = featureNames.stream() + .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields))) + .collect(Collectors.toList()); return infer(features, config); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MapHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MapHelper.java new file mode 100644 index 00000000000..dcd74af1581 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MapHelper.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.utils; + +import org.elasticsearch.common.Nullable; + +import java.util.Arrays; +import java.util.Map; +import java.util.Stack; + +public final class MapHelper { + + private MapHelper() {} + + /** + * This eagerly digs (depth first search, longer keys first) through the map by tokenizing the provided path on '.'. + * + * It is possible for ES _source docs to have "mixed" path formats. So, we should search all potential paths + * given the current knowledge of the map. + * + * Examples: + * + * The following maps would return `2` given the path "a.b.c.d" + * + * { + * "a.b.c.d" : 2 + * } + * { + * "a" :{"b": {"c": {"d" : 2}}} + * } + * { + * "a" :{"b.c": {"d" : 2}}} + * } + * { + * "a" :{"b": {"c": {"d" : 2}}}, + * "a.b" :{"c": {"d" : 5}} // we choose the first one found, we go down longer keys first + * } + * { + * "a" :{"b": {"c": {"NOT_d" : 2, "d": 2}}} + * } + * + * Conceptual "Worse case" 5 potential paths explored for "a.b.c.d" until 2 is finally returned + * { + * "a.b.c": {"not_d": 2}, + * "a.b": {"c": {"not_d": 2}}, + * "a": {"b.c": {"not_d": 2}}, + * "a": {"b" :{ "c.not_d": 2}}, + * "a" :{"b": {"c": {"not_d" : 2}}}, + * "a" :{"b": {"c": {"d" : 2}}}, + * } + * + * We don't exhaustively create all potential paths. + * If we did, this would result in 2^n-1 total possible paths, where {@code n = path.split("\\.").length}. + * + * Instead we lazily create potential paths once we know that they are possibilities. + * + * @param path Dot delimited path containing the field desired + * @param map The {@link Map} map to dig + * @return The found object. Returns {@code null} if not found + */ + @Nullable + public static Object dig(String path, Map map) { + // short cut before search + if (map.keySet().contains(path)) { + return map.get(path); + } + String[] fields = path.split("\\."); + if (Arrays.stream(fields).anyMatch(String::isEmpty)) { + throw new IllegalArgumentException("Empty path detected. Invalid field name"); + } + Stack pathStack = new Stack<>(); + pathStack.push(new PotentialPath(map, 0)); + return explore(fields, pathStack); + } + + @SuppressWarnings("unchecked") + private static Object explore(String[] path, Stack pathStack) { + while (pathStack.empty() == false) { + PotentialPath potentialPath = pathStack.pop(); + int endPos = potentialPath.pathPosition + 1; + int startPos = potentialPath.pathPosition; + Map map = potentialPath.map; + String candidateKey = null; + while(endPos <= path.length) { + candidateKey = mergePath(path, startPos, endPos); + Object next = map.get(candidateKey); + if (endPos == path.length && next != null) { // exit early, we reached the full path and found something + return next; + } + if (next instanceof Map) { // we found another map, continue exploring down this path + pathStack.push(new PotentialPath((Map)next, endPos)); + } + endPos++; + } + if (candidateKey != null && map.containsKey(candidateKey)) { //exit early + return map.get(candidateKey); + } + } + + return null; + } + + private static String mergePath(String[] path, int start, int end) { + if (start + 1 == end) { // early exit, no need to create sb + return path[start]; + } + + StringBuilder sb = new StringBuilder(); + for (int i = start; i < end - 1; i++) { + sb.append(path[i]); + sb.append("."); + } + sb.append(path[end - 1]); + return sb.toString(); + } + + private static class PotentialPath { + + // Pointer to where to start exploring + private final Map map; + // Where in the requested path are we + private final int pathPosition; + + private PotentialPath(Map map, int pathPosition) { + this.map = map; + this.pathPosition = pathPosition; + } + + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java index 4c0497fa409..590a1197974 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java @@ -65,4 +65,22 @@ public class FrequencyEncodingTests extends PreProcessingTests values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5); + Map valueMap = values.stream().collect(Collectors.toMap(Object::toString, + v -> randomDoubleBetween(0.0, 1.0, false))); + String encodedFeatureName = "encoded"; + FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap); + + Map fieldValues = new HashMap() {{ + put("categorical", new HashMap(){{ + put("child", "farequote"); + }}); + }}; + + encoding.process(fieldValues); + assertThat(fieldValues.get("encoded"), equalTo(valueMap.get("farequote"))); + } + } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java index 8b35b77b5a6..18651d1f0bf 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java @@ -67,4 +67,19 @@ public class OneHotEncodingTests extends PreProcessingTests { testProcess(encoding, fieldValues, matchers); } + public void testProcessWithNestedField() { + String field = "categorical.child"; + List values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5); + Map valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString())); + OneHotEncoding encoding = new OneHotEncoding(field, valueMap); + Map fieldValues = new HashMap() {{ + put("categorical", new HashMap(){{ + put("child", "farequote"); + }}); + }}; + + encoding.process(fieldValues); + assertThat(fieldValues.get("Column_farequote"), equalTo(1)); + } + } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java index e2aaf1e1256..babb3a33017 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java @@ -68,4 +68,24 @@ public class TargetMeanEncodingTests extends PreProcessingTests values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5); + Map valueMap = values.stream().collect(Collectors.toMap(Object::toString, + v -> randomDoubleBetween(0.0, 1.0, false))); + String encodedFeatureName = "encoded"; + Double defaultvalue = randomDouble(); + TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue); + + Map fieldValues = new HashMap() {{ + put("categorical", new HashMap(){{ + put("child", "farequote"); + }}); + }}; + + encoding.process(fieldValues); + + assertThat(fieldValues.get("encoded"), equalTo(valueMap.get("farequote"))); + } + } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 959bbc7b207..260001569fa 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -445,6 +445,63 @@ public class EnsembleTests extends AbstractSerializingTestCase { closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); } + public void testInferNestedFields() { + List featureNames = Arrays.asList("foo.baz", "bar.biz"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.3)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.1)) + .addNode(TreeNode.builder(4).setLeafValue(0.2)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + Ensemble ensemble = Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2)) + .setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5})) + .build(); + + Map featureMap = new HashMap() {{ + put("foo", new HashMap(){{ + put("baz", 0.4); + }}); + put("bar", new HashMap(){{ + put("biz", 0.0); + }}); + }}; + assertThat(0.9, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + + featureMap = new HashMap() {{ + put("foo", new HashMap(){{ + put("baz", 2.0); + }}); + put("bar", new HashMap(){{ + put("biz", 0.7); + }}); + }}; + assertThat(0.5, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + } + public void testOperationsEstimations() { Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2); Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index d732d142a20..9c8e390b300 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -169,6 +169,58 @@ public class TreeTests extends AbstractSerializingTestCase { closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); } + public void testInferNestedFields() { + // Build a tree with 2 nodes and 3 leaves using 2 features + // The leaves have unique values 0.1, 0.2, 0.3 + Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION); + TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5); + builder.addLeaf(rootNode.getRightChild(), 0.3); + TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8); + builder.addLeaf(leftChildNode.getLeftChild(), 0.1); + builder.addLeaf(leftChildNode.getRightChild(), 0.2); + + List featureNames = Arrays.asList("foo.baz", "bar.biz"); + Tree tree = builder.setFeatureNames(featureNames).build(); + + // This feature vector should hit the right child of the root node + Map featureMap = new HashMap() {{ + put("foo", new HashMap(){{ + put("baz", 0.6); + }}); + put("bar", new HashMap(){{ + put("biz", 0.0); + }}); + }}; + assertThat(0.3, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + + // This should hit the left child of the left child of the root node + // i.e. it takes the path left, left + featureMap = new HashMap() {{ + put("foo", new HashMap(){{ + put("baz", 0.3); + }}); + put("bar", new HashMap(){{ + put("biz", 0.7); + }}); + }}; + assertThat(0.1, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + + // This should hit the right child of the left child of the root node + // i.e. it takes the path left, right + featureMap = new HashMap() {{ + put("foo", new HashMap(){{ + put("baz", 0.3); + }}); + put("bar", new HashMap(){{ + put("biz", 0.9); + }}); + }}; + assertThat(0.2, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001)); + } + public void testTreeClassificationProbability() { // Build a tree with 2 nodes and 3 leaves using 2 features // The leaves have unique values 0.1, 0.2, 0.3 diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/MapHelperTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/MapHelperTests.java new file mode 100644 index 00000000000..53562aa4d14 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/MapHelperTests.java @@ -0,0 +1,193 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.utils; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; + +public class MapHelperTests extends ESTestCase { + + public void testAbsolutePathStringAsKey() { + String path = "a.b.c.d"; + Map map = Collections.singletonMap(path, 2); + assertThat(MapHelper.dig(path, map), equalTo(2)); + assertThat(MapHelper.dig(path, Collections.emptyMap()), is(nullValue())); + } + + public void testSimplePath() { + String path = "a.b.c.d"; + Map map = Collections.singletonMap("a", + Collections.singletonMap("b", + Collections.singletonMap("c", + Collections.singletonMap("d", 2)))); + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = Collections.singletonMap("a", + Collections.singletonMap("b", + Collections.singletonMap("e", // Not part of path + Collections.singletonMap("d", 2)))); + assertThat(MapHelper.dig(path, map), is(nullValue())); + } + + public void testSimplePathReturningMap() { + String path = "a.b.c"; + Map map = Collections.singletonMap("a", + Collections.singletonMap("b", + Collections.singletonMap("c", + Collections.singletonMap("d", 2)))); + assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2))); + } + + public void testSimpleMixedPath() { + String path = "a.b.c.d"; + Map map = Collections.singletonMap("a", + Collections.singletonMap("b.c", + Collections.singletonMap("d", 2))); + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = Collections.singletonMap("a.b", + Collections.singletonMap("c", + Collections.singletonMap("d", 2))); + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = Collections.singletonMap("a.b.c", + Collections.singletonMap("d", 2)); + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = Collections.singletonMap("a", + Collections.singletonMap("b", + Collections.singletonMap("c.d", 2))); + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = Collections.singletonMap("a", + Collections.singletonMap("b.c.d", 2)); + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = Collections.singletonMap("a.b", + Collections.singletonMap("c.d", 2)); + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = Collections.singletonMap("a", + Collections.singletonMap("b.foo", + Collections.singletonMap("d", 2))); + assertThat(MapHelper.dig(path, map), is(nullValue())); + + map = Collections.singletonMap("a", + Collections.singletonMap("b.c", + Collections.singletonMap("foo", 2))); + assertThat(MapHelper.dig(path, map), is(nullValue())); + + map = Collections.singletonMap("x", + Collections.singletonMap("b.c", + Collections.singletonMap("d", 2))); + assertThat(MapHelper.dig(path, map), is(nullValue())); + } + + public void testSimpleMixedPathReturningMap() { + String path = "a.b.c"; + Map map = Collections.singletonMap("a", + Collections.singletonMap("b.c", + Collections.singletonMap("d", 2))); + assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2))); + + map = Collections.singletonMap("a", + Collections.singletonMap("b.foo", + Collections.singletonMap("d", 2))); + assertThat(MapHelper.dig(path, map), is(nullValue())); + + map = Collections.singletonMap("a", + Collections.singletonMap("b.not_c", + Collections.singletonMap("foo", 2))); + assertThat(MapHelper.dig(path, map), is(nullValue())); + + map = Collections.singletonMap("x", + Collections.singletonMap("b.c", + Collections.singletonMap("d", 2))); + assertThat(MapHelper.dig(path, map), is(nullValue())); + } + + public void testMultiplePotentialPaths() { + String path = "a.b.c.d"; + Map map = new LinkedHashMap() {{ + put("a", Collections.singletonMap("b", + Collections.singletonMap("c", + Collections.singletonMap("not_d", 5)))); + put("a.b", Collections.singletonMap("c", Collections.singletonMap("d", 2))); + }}; + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = new LinkedHashMap() {{ + put("a", Collections.singletonMap("b", + Collections.singletonMap("c", + Collections.singletonMap("d", 2)))); + put("a.b", Collections.singletonMap("c", Collections.singletonMap("not_d", 5))); + }}; + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = new LinkedHashMap() {{ + put("a", Collections.singletonMap("b", + new HashMap() {{ + put("c", Collections.singletonMap("not_d", 5)); + put("c.d", 2); + }})); + }}; + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = new LinkedHashMap() {{ + put("a", Collections.singletonMap("b", + new HashMap() {{ + put("c", Collections.singletonMap("d", 2)); + put("c.not_d", 5); + }})); + }}; + assertThat(MapHelper.dig(path, map), equalTo(2)); + + map = new LinkedHashMap() {{ + put("a", Collections.singletonMap("b", + Collections.singletonMap("c", + Collections.singletonMap("not_d", 5)))); + put("a.b", Collections.singletonMap("c", Collections.singletonMap("not_d", 2))); + }}; + + assertThat(MapHelper.dig(path, map), is(nullValue())); + } + + public void testMultiplePotentialPathsReturningMap() { + String path = "a.b.c"; + Map map = new LinkedHashMap() {{ + put("a", Collections.singletonMap("b", + Collections.singletonMap("c", + Collections.singletonMap("d", 2)))); + put("a.b", Collections.singletonMap("not_c", Collections.singletonMap("d", 2))); + }}; + assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2))); + + map = new LinkedHashMap() {{ + put("a", Collections.singletonMap("b", + Collections.singletonMap("not_c", + Collections.singletonMap("d", 2)))); + put("a.b", Collections.singletonMap("c", Collections.singletonMap("d", 2))); + }}; + assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2))); + + map = new LinkedHashMap() {{ + put("a", Collections.singletonMap("b", + Collections.singletonMap("not_c", + Collections.singletonMap("d", 2)))); + put("a.b", Collections.singletonMap("not_c", Collections.singletonMap("d", 2))); + }}; + assertThat(MapHelper.dig(path, map), is(nullValue())); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 19c0054b522..18ac57d1ae8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.MapHelper; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.util.Arrays; @@ -128,7 +129,7 @@ public class InferenceProcessor extends AbstractProcessor { Map fields = new HashMap<>(ingestDocument.getSourceAndMetadata()); if (fieldMapping != null) { fieldMapping.forEach((src, dest) -> { - Object srcValue = fields.remove(src); + Object srcValue = MapHelper.dig(src, fields); if (srcValue != null) { fields.put(dest, srcValue); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 4e62c69336b..8a233d534a7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; @@ -16,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.utils.MapHelper; import java.util.HashSet; import java.util.Map; @@ -61,7 +61,7 @@ public class LocalModel implements Model { @Override public void infer(Map fields, InferenceConfig config, ActionListener listener) { try { - if (Sets.haveEmptyIntersection(fieldNames, fields.keySet())) { + if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) { listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId))); return; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 81e5c79135c..5d95d1c79c5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -181,7 +181,7 @@ public class InferenceProcessorTests extends ESTestCase { String modelId = "model"; Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); - Map fieldMapping = new HashMap(3) {{ + Map fieldMapping = new HashMap(5) {{ put("value1", "new_value1"); put("value2", "new_value2"); put("categorical", "new_categorical"); @@ -195,7 +195,7 @@ public class InferenceProcessorTests extends ESTestCase { new ClassificationConfig(topNClasses, null, null), fieldMapping); - Map source = new HashMap(3){{ + Map source = new HashMap(5){{ put("value1", 1); put("categorical", "foo"); put("un_touched", "bar"); @@ -203,8 +203,46 @@ public class InferenceProcessorTests extends ESTestCase { Map ingestMetadata = new HashMap<>(); IngestDocument document = new IngestDocument(source, ingestMetadata); - Map expectedMap = new HashMap(2) {{ + Map expectedMap = new HashMap(7) {{ put("new_value1", 1); + put("value1", 1); + put("categorical", "foo"); + put("new_categorical", "foo"); + put("un_touched", "bar"); + }}; + assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap)); + } + + public void testGenerateWithMappingNestedFields() { + String modelId = "model"; + Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); + + Map fieldMapping = new HashMap(5) {{ + put("value1.foo", "new_value1"); + put("value2", "new_value2"); + put("categorical.bar", "new_categorical"); + }}; + + InferenceProcessor processor = new InferenceProcessor(client, + auditor, + "my_processor", + "my_field", + modelId, + new ClassificationConfig(topNClasses, null, null), + fieldMapping); + + Map source = new HashMap(5){{ + put("value1", Collections.singletonMap("foo", 1)); + put("categorical.bar", "foo"); + put("un_touched", "bar"); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + Map expectedMap = new HashMap(7) {{ + put("new_value1", 1); + put("value1", Collections.singletonMap("foo", 1)); + put("categorical.bar", "foo"); put("new_categorical", "foo"); put("un_touched", "bar"); }}; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index b8cb652878f..3ce5de2aced 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -41,7 +41,7 @@ public class LocalModelTests extends ESTestCase { public void testClassificationInfer() throws Exception { String modelId = "classification_model"; - List inputFields = Arrays.asList("foo", "bar", "categorical"); + List inputFields = Arrays.asList("field.foo", "field.bar", "categorical"); TrainedModelDefinition definition = new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(false)) @@ -49,8 +49,8 @@ public class LocalModelTests extends ESTestCase { Model model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields)); Map fields = new HashMap() {{ - put("foo", 1.0); - put("bar", 0.5); + put("field.foo", 1.0); + put("field.bar", 0.5); put("categorical", "dog"); }}; @@ -93,8 +93,8 @@ public class LocalModelTests extends ESTestCase { Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields)); Map fields = new HashMap() {{ - put("foo", 1.0); - put("bar", 0.5); + put("field.foo", 1.0); + put("field.bar", 0.5); put("categorical", "dog"); }}; @@ -147,7 +147,7 @@ public class LocalModelTests extends ESTestCase { } public static TrainedModel buildClassification(boolean includeLabels) { - List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + List featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog"); Tree tree1 = Tree.builder() .setFeatureNames(featureNames) .setRoot(TreeNode.builder(0) @@ -193,7 +193,7 @@ public class LocalModelTests extends ESTestCase { } public static TrainedModel buildRegression() { - List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + List featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog"); Tree tree1 = Tree.builder() .setFeatureNames(featureNames) .setRoot(TreeNode.builder(0) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 0d4aefc26b5..b71f69aea76 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -66,9 +66,9 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { oneHotEncoding.put("cat", "animal_cat"); oneHotEncoding.put("dog", "animal_dog"); TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2) - .setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical"))) + .setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical"))) .setParsedDefinition(new TrainedModelDefinition.Builder() - .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding))) .setTrainedModel(buildClassification(true))) .setVersion(Version.CURRENT) .setLicenseLevel(License.OperationMode.PLATINUM.description()) @@ -77,9 +77,9 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { .setEstimatedHeapMemory(0) .build(); TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1) - .setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical"))) + .setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical"))) .setParsedDefinition(new TrainedModelDefinition.Builder() - .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding))) .setTrainedModel(buildRegression())) .setVersion(Version.CURRENT) .setEstimatedOperations(0) @@ -99,26 +99,42 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { List> toInfer = new ArrayList<>(); toInfer.add(new HashMap() {{ - put("foo", 1.0); - put("bar", 0.5); - put("categorical", "dog"); + put("field", new HashMap(){{ + put("foo", 1.0); + put("bar", 0.5); + }}); + put("other", new HashMap(){{ + put("categorical", "dog"); + }}); }}); toInfer.add(new HashMap() {{ - put("foo", 0.9); - put("bar", 1.5); - put("categorical", "cat"); + put("field", new HashMap(){{ + put("foo", 0.9); + put("bar", 1.5); + }}); + put("other", new HashMap(){{ + put("categorical", "cat"); + }}); }}); List> toInfer2 = new ArrayList<>(); toInfer2.add(new HashMap() {{ - put("foo", 0.0); - put("bar", 0.01); - put("categorical", "dog"); + put("field", new HashMap(){{ + put("foo", 0.0); + put("bar", 0.01); + }}); + put("other", new HashMap(){{ + put("categorical", "dog"); + }}); }}); toInfer2.add(new HashMap() {{ - put("foo", 1.0); - put("bar", 0.0); - put("categorical", "cat"); + put("field", new HashMap(){{ + put("foo", 1.0); + put("bar", 0.0); + }}); + put("other", new HashMap(){{ + put("categorical", "cat"); + }}); }}); // Test regression