This fixes support for nested fields. We now support fully nested, fully collapsed, or a mix of both in inference docs. ES mappings allow the `_source` to be any combination of nested objects and dot-delimited fields, so we should do our best to find the best path down the map to the desired field.
This commit is contained in:
parent 06a24f09cf
commit 4396a1f78b
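
To make the intended behavior concrete: the new MapHelper.dig utility (added below) resolves a dot-delimited path against any of these _source shapes. The following is a minimal standalone sketch mirroring cases from the Javadoc and tests in this commit; it is illustrative only, not part of the diff itself.

import java.util.Collections;
import java.util.Map;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;

// fully collapsed
Map<String, Object> collapsed = Collections.singletonMap("a.b.c.d", 2);
// fully nested
Map<String, Object> nested = Collections.singletonMap("a",
    Collections.singletonMap("b",
        Collections.singletonMap("c",
            Collections.singletonMap("d", 2))));
// a mix of both
Map<String, Object> mixed = Collections.singletonMap("a",
    Collections.singletonMap("b.c",
        Collections.singletonMap("d", 2)));

// All three resolve the same path; a miss returns null.
assert MapHelper.dig("a.b.c.d", collapsed).equals(2);
assert MapHelper.dig("a.b.c.d", nested).equals(2);
assert MapHelper.dig("a.b.c.d", mixed).equals(2);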
@@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.io.IOException;
import java.util.Collections;
@@ -103,7 +104,7 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP

    @Override
    public void process(Map<String, Object> fields) {
-        Object value = fields.get(field);
+        Object value = MapHelper.dig(field, fields);
        if (value == null) {
            return;
        }
@@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.io.IOException;
import java.util.Collections;
@@ -86,7 +87,7 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars

    @Override
    public void process(Map<String, Object> fields) {
-        Object value = fields.get(field);
+        Object value = MapHelper.dig(field, fields);
        if (value == null) {
            return;
        }
@@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.io.IOException;
import java.util.Collections;
@@ -114,7 +115,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly

    @Override
    public void process(Map<String, Object> fields) {
-        Object value = fields.get(field);
+        Object value = MapHelper.dig(field, fields);
        if (value == null) {
            return;
        }
@@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfi
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.io.IOException;
import java.util.ArrayDeque;
@@ -129,7 +130,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
                "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
        }

-        List<Double> features = featureNames.stream().map(f -> InferenceHelpers.toDouble(fields.get(f))).collect(Collectors.toList());
+        List<Double> features = featureNames.stream()
+            .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
+            .collect(Collectors.toList());
        return infer(features, config);
    }

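Before the new MapHelper file below, a brief illustration of what the Tree change above does: feature values are now resolved with MapHelper.dig instead of a flat fields.get, so a feature name like "foo.baz" matches both nested and collapsed documents. This is a hedged, standalone fragment (the document shape is assumed; InferenceHelpers.toDouble and MapHelper.dig are the same methods used in the hunk above), not part of the diff:

// Illustrative fragment; assumes the imports of Tree.java plus java.util.* and java.util.stream.Collectors
List<String> featureNames = Arrays.asList("foo.baz", "bar.biz");
Map<String, Object> doc = new HashMap<>();
doc.put("foo", Collections.singletonMap("baz", 0.4)); // nested form
doc.put("bar.biz", 0.0);                              // collapsed form resolves too
List<Double> features = featureNames.stream()
    .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, doc)))
    .collect(Collectors.toList()); // yields [0.4, 0.0]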
@@ -0,0 +1,133 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.utils;

import org.elasticsearch.common.Nullable;

import java.util.Arrays;
import java.util.Map;
import java.util.Stack;

public final class MapHelper {

    private MapHelper() {}

    /**
     * This eagerly digs (depth first search, longer keys first) through the map by tokenizing the provided path on '.'.
     *
     * It is possible for ES _source docs to have "mixed" path formats. So, we should search all potential paths
     * given the current knowledge of the map.
     *
     * Examples:
     *
     * The following maps would return `2` given the path "a.b.c.d":
     *
     * {
     *     "a.b.c.d" : 2
     * }
     * {
     *     "a" : {"b": {"c": {"d" : 2}}}
     * }
     * {
     *     "a" : {"b.c": {"d" : 2}}
     * }
     * {
     *     "a" : {"b": {"c": {"d" : 2}}},
     *     "a.b" : {"c": {"d" : 2}} // when multiple paths match, we choose the first one found; we go down longer keys first
     * }
     * {
     *     "a" : {"b": {"c": {"NOT_d" : 2, "d": 2}}}
     * }
     *
     * Conceptual "worst case": 5 potential paths are explored for "a.b.c.d" until 2 is finally returned
     * {
     *     "a.b.c": {"not_d": 2},
     *     "a.b": {"c": {"not_d": 2}},
     *     "a": {"b.c": {"not_d": 2}},
     *     "a": {"b" : {"c.not_d": 2}},
     *     "a" : {"b": {"c": {"not_d" : 2}}},
     *     "a" : {"b": {"c": {"d" : 2}}},
     * }
     *
     * We don't exhaustively create all potential paths.
     * If we did, this would result in 2^(n-1) total possible paths, where {@code n = path.split("\\.").length}
     * (each of the n - 1 dots can either split the path or not).
     *
     * Instead we lazily create potential paths once we know that they are possibilities.
     *
     * @param path Dot delimited path containing the field desired
     * @param map The {@link Map} to dig through
     * @return The found object. Returns {@code null} if not found
     */
    @Nullable
    public static Object dig(String path, Map<String, Object> map) {
        // short cut before search
        if (map.keySet().contains(path)) {
            return map.get(path);
        }
        String[] fields = path.split("\\.");
        if (Arrays.stream(fields).anyMatch(String::isEmpty)) {
            throw new IllegalArgumentException("Empty path detected. Invalid field name");
        }
        Stack<PotentialPath> pathStack = new Stack<>();
        pathStack.push(new PotentialPath(map, 0));
        return explore(fields, pathStack);
    }

    @SuppressWarnings("unchecked")
    private static Object explore(String[] path, Stack<PotentialPath> pathStack) {
        while (pathStack.empty() == false) {
            PotentialPath potentialPath = pathStack.pop();
            int endPos = potentialPath.pathPosition + 1;
            int startPos = potentialPath.pathPosition;
            Map<String, Object> map = potentialPath.map;
            String candidateKey = null;
            while (endPos <= path.length) {
                candidateKey = mergePath(path, startPos, endPos);
                Object next = map.get(candidateKey);
                if (endPos == path.length && next != null) { // exit early, we reached the full path and found something
                    return next;
                }
                if (next instanceof Map<?, ?>) { // we found another map, continue exploring down this path
                    pathStack.push(new PotentialPath((Map<String, Object>) next, endPos));
                }
                endPos++;
            }
            if (candidateKey != null && map.containsKey(candidateKey)) { // exit early
                return map.get(candidateKey);
            }
        }

        return null;
    }

    private static String mergePath(String[] path, int start, int end) {
        if (start + 1 == end) { // early exit, no need to create a StringBuilder
            return path[start];
        }

        StringBuilder sb = new StringBuilder();
        for (int i = start; i < end - 1; i++) {
            sb.append(path[i]);
            sb.append(".");
        }
        sb.append(path[end - 1]);
        return sb.toString();
    }

    private static class PotentialPath {

        // The map from which to continue exploring
        private final Map<String, Object> map;
        // Where in the requested path we are
        private final int pathPosition;

        private PotentialPath(Map<String, Object> map, int pathPosition) {
            this.map = map;
            this.pathPosition = pathPosition;
        }

    }
}
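A quick aside on the 2^(n-1) figure in the Javadoc above: each of the n - 1 dots in a path can either split two keys or stay inside one key, so a path of n tokens has 2^(n-1) candidate decompositions. The sketch below is illustrative only (it is not part of the commit) and enumerates those decompositions eagerly, which is exactly what dig avoids doing.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class PathPartitions {

    // Enumerate every way to group the tokens of a dotted path into consecutive map keys.
    static List<List<String>> partitions(String[] tokens, int start) {
        List<List<String>> result = new ArrayList<>();
        if (start == tokens.length) {
            result.add(new ArrayList<>()); // a single empty grouping completes the recursion
            return result;
        }
        for (int end = start + 1; end <= tokens.length; end++) {
            String key = String.join(".", Arrays.copyOfRange(tokens, start, end));
            for (List<String> rest : partitions(tokens, end)) {
                List<String> grouping = new ArrayList<>();
                grouping.add(key);
                grouping.addAll(rest);
                result.add(grouping);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Prints 8, i.e. 2^(4-1) groupings for "a.b.c.d", from [a, b, c, d] to [a.b.c.d]
        System.out.println(partitions("a.b.c.d".split("\\."), 0).size());
    }
}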
@@ -65,4 +65,22 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
        testProcess(encoding, fieldValues, matchers);
    }

+    public void testProcessWithNestedField() {
+        String field = "categorical.child";
+        List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
+        Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
+            v -> randomDoubleBetween(0.0, 1.0, false)));
+        String encodedFeatureName = "encoded";
+        FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap);
+
+        Map<String, Object> fieldValues = new HashMap<String, Object>() {{
+            put("categorical", new HashMap<String, Object>() {{
+                put("child", "farequote");
+            }});
+        }};
+
+        encoding.process(fieldValues);
+        assertThat(fieldValues.get("encoded"), equalTo(valueMap.get("farequote")));
+    }
+
}
@@ -67,4 +67,19 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
        testProcess(encoding, fieldValues, matchers);
    }

+    public void testProcessWithNestedField() {
+        String field = "categorical.child";
+        List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
+        Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString()));
+        OneHotEncoding encoding = new OneHotEncoding(field, valueMap);
+        Map<String, Object> fieldValues = new HashMap<String, Object>() {{
+            put("categorical", new HashMap<String, Object>() {{
+                put("child", "farequote");
+            }});
+        }};
+
+        encoding.process(fieldValues);
+        assertThat(fieldValues.get("Column_farequote"), equalTo(1));
+    }
+
}
@@ -68,4 +68,24 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
        testProcess(encoding, fieldValues, matchers);
    }

+    public void testProcessWithNestedField() {
+        String field = "categorical.child";
+        List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
+        Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
+            v -> randomDoubleBetween(0.0, 1.0, false)));
+        String encodedFeatureName = "encoded";
+        Double defaultvalue = randomDouble();
+        TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue);
+
+        Map<String, Object> fieldValues = new HashMap<String, Object>() {{
+            put("categorical", new HashMap<String, Object>() {{
+                put("child", "farequote");
+            }});
+        }};
+
+        encoding.process(fieldValues);
+
+        assertThat(fieldValues.get("encoded"), equalTo(valueMap.get("farequote")));
+    }
+
}
@@ -445,6 +445,63 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
    }

+    public void testInferNestedFields() {
+        List<String> featureNames = Arrays.asList("foo.baz", "bar.biz");
+        Tree tree1 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(0)
+                .setThreshold(0.5))
+            .addNode(TreeNode.builder(1).setLeafValue(0.3))
+            .addNode(TreeNode.builder(2)
+                .setThreshold(0.8)
+                .setSplitFeature(1)
+                .setLeftChild(3)
+                .setRightChild(4))
+            .addNode(TreeNode.builder(3).setLeafValue(0.1))
+            .addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
+        Tree tree2 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(0)
+                .setThreshold(0.5))
+            .addNode(TreeNode.builder(1).setLeafValue(1.5))
+            .addNode(TreeNode.builder(2).setLeafValue(0.9))
+            .build();
+        Ensemble ensemble = Ensemble.builder()
+            .setTargetType(TargetType.REGRESSION)
+            .setFeatureNames(featureNames)
+            .setTrainedModels(Arrays.asList(tree1, tree2))
+            .setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5}))
+            .build();
+
+        Map<String, Object> featureMap = new HashMap<String, Object>() {{
+            put("foo", new HashMap<String, Object>() {{
+                put("baz", 0.4);
+            }});
+            put("bar", new HashMap<String, Object>() {{
+                put("biz", 0.0);
+            }});
+        }};
+        assertThat(0.9,
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+
+        featureMap = new HashMap<String, Object>() {{
+            put("foo", new HashMap<String, Object>() {{
+                put("baz", 2.0);
+            }});
+            put("bar", new HashMap<String, Object>() {{
+                put("biz", 0.7);
+            }});
+        }};
+        assertThat(0.5,
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+    }
+
    public void testOperationsEstimations() {
        Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2);
        Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
@@ -169,6 +169,58 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
    }

+    public void testInferNestedFields() {
+        // Build a tree with 2 nodes and 3 leaves using 2 features
+        // The leaves have unique values 0.1, 0.2, 0.3
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
+        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
+        builder.addLeaf(rootNode.getRightChild(), 0.3);
+        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
+        builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
+        builder.addLeaf(leftChildNode.getRightChild(), 0.2);
+
+        List<String> featureNames = Arrays.asList("foo.baz", "bar.biz");
+        Tree tree = builder.setFeatureNames(featureNames).build();
+
+        // This feature vector should hit the right child of the root node
+        Map<String, Object> featureMap = new HashMap<String, Object>() {{
+            put("foo", new HashMap<String, Object>() {{
+                put("baz", 0.6);
+            }});
+            put("bar", new HashMap<String, Object>() {{
+                put("biz", 0.0);
+            }});
+        }};
+        assertThat(0.3,
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+
+        // This should hit the left child of the left child of the root node
+        // i.e. it takes the path left, left
+        featureMap = new HashMap<String, Object>() {{
+            put("foo", new HashMap<String, Object>() {{
+                put("baz", 0.3);
+            }});
+            put("bar", new HashMap<String, Object>() {{
+                put("biz", 0.7);
+            }});
+        }};
+        assertThat(0.1,
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+
+        // This should hit the right child of the left child of the root node
+        // i.e. it takes the path left, right
+        featureMap = new HashMap<String, Object>() {{
+            put("foo", new HashMap<String, Object>() {{
+                put("baz", 0.3);
+            }});
+            put("bar", new HashMap<String, Object>() {{
+                put("biz", 0.9);
+            }});
+        }};
+        assertThat(0.2,
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+    }
+
    public void testTreeClassificationProbability() {
        // Build a tree with 2 nodes and 3 leaves using 2 features
        // The leaves have unique values 0.1, 0.2, 0.3
@@ -0,0 +1,193 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.utils;

import org.elasticsearch.test.ESTestCase;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;

public class MapHelperTests extends ESTestCase {

    public void testAbsolutePathStringAsKey() {
        String path = "a.b.c.d";
        Map<String, Object> map = Collections.singletonMap(path, 2);
        assertThat(MapHelper.dig(path, map), equalTo(2));
        assertThat(MapHelper.dig(path, Collections.emptyMap()), is(nullValue()));
    }

    public void testSimplePath() {
        String path = "a.b.c.d";
        Map<String, Object> map = Collections.singletonMap("a",
            Collections.singletonMap("b",
                Collections.singletonMap("c",
                    Collections.singletonMap("d", 2))));
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = Collections.singletonMap("a",
            Collections.singletonMap("b",
                Collections.singletonMap("e", // Not part of path
                    Collections.singletonMap("d", 2))));
        assertThat(MapHelper.dig(path, map), is(nullValue()));
    }

    public void testSimplePathReturningMap() {
        String path = "a.b.c";
        Map<String, Object> map = Collections.singletonMap("a",
            Collections.singletonMap("b",
                Collections.singletonMap("c",
                    Collections.singletonMap("d", 2))));
        assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));
    }

    public void testSimpleMixedPath() {
        String path = "a.b.c.d";
        Map<String, Object> map = Collections.singletonMap("a",
            Collections.singletonMap("b.c",
                Collections.singletonMap("d", 2)));
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = Collections.singletonMap("a.b",
            Collections.singletonMap("c",
                Collections.singletonMap("d", 2)));
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = Collections.singletonMap("a.b.c",
            Collections.singletonMap("d", 2));
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = Collections.singletonMap("a",
            Collections.singletonMap("b",
                Collections.singletonMap("c.d", 2)));
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = Collections.singletonMap("a",
            Collections.singletonMap("b.c.d", 2));
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = Collections.singletonMap("a.b",
            Collections.singletonMap("c.d", 2));
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = Collections.singletonMap("a",
            Collections.singletonMap("b.foo",
                Collections.singletonMap("d", 2)));
        assertThat(MapHelper.dig(path, map), is(nullValue()));

        map = Collections.singletonMap("a",
            Collections.singletonMap("b.c",
                Collections.singletonMap("foo", 2)));
        assertThat(MapHelper.dig(path, map), is(nullValue()));

        map = Collections.singletonMap("x",
            Collections.singletonMap("b.c",
                Collections.singletonMap("d", 2)));
        assertThat(MapHelper.dig(path, map), is(nullValue()));
    }

    public void testSimpleMixedPathReturningMap() {
        String path = "a.b.c";
        Map<String, Object> map = Collections.singletonMap("a",
            Collections.singletonMap("b.c",
                Collections.singletonMap("d", 2)));
        assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));

        map = Collections.singletonMap("a",
            Collections.singletonMap("b.foo",
                Collections.singletonMap("d", 2)));
        assertThat(MapHelper.dig(path, map), is(nullValue()));

        map = Collections.singletonMap("a",
            Collections.singletonMap("b.not_c",
                Collections.singletonMap("foo", 2)));
        assertThat(MapHelper.dig(path, map), is(nullValue()));

        map = Collections.singletonMap("x",
            Collections.singletonMap("b.c",
                Collections.singletonMap("d", 2)));
        assertThat(MapHelper.dig(path, map), is(nullValue()));
    }

    public void testMultiplePotentialPaths() {
        String path = "a.b.c.d";
        Map<String, Object> map = new LinkedHashMap<String, Object>() {{
            put("a", Collections.singletonMap("b",
                Collections.singletonMap("c",
                    Collections.singletonMap("not_d", 5))));
            put("a.b", Collections.singletonMap("c", Collections.singletonMap("d", 2)));
        }};
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = new LinkedHashMap<String, Object>() {{
            put("a", Collections.singletonMap("b",
                Collections.singletonMap("c",
                    Collections.singletonMap("d", 2))));
            put("a.b", Collections.singletonMap("c", Collections.singletonMap("not_d", 5)));
        }};
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = new LinkedHashMap<String, Object>() {{
            put("a", Collections.singletonMap("b",
                new HashMap<String, Object>() {{
                    put("c", Collections.singletonMap("not_d", 5));
                    put("c.d", 2);
                }}));
        }};
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = new LinkedHashMap<String, Object>() {{
            put("a", Collections.singletonMap("b",
                new HashMap<String, Object>() {{
                    put("c", Collections.singletonMap("d", 2));
                    put("c.not_d", 5);
                }}));
        }};
        assertThat(MapHelper.dig(path, map), equalTo(2));

        map = new LinkedHashMap<String, Object>() {{
            put("a", Collections.singletonMap("b",
                Collections.singletonMap("c",
                    Collections.singletonMap("not_d", 5))));
            put("a.b", Collections.singletonMap("c", Collections.singletonMap("not_d", 2)));
        }};

        assertThat(MapHelper.dig(path, map), is(nullValue()));
    }

    public void testMultiplePotentialPathsReturningMap() {
        String path = "a.b.c";
        Map<String, Object> map = new LinkedHashMap<String, Object>() {{
            put("a", Collections.singletonMap("b",
                Collections.singletonMap("c",
                    Collections.singletonMap("d", 2))));
            put("a.b", Collections.singletonMap("not_c", Collections.singletonMap("d", 2)));
        }};
        assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));

        map = new LinkedHashMap<String, Object>() {{
            put("a", Collections.singletonMap("b",
                Collections.singletonMap("not_c",
                    Collections.singletonMap("d", 2))));
            put("a.b", Collections.singletonMap("c", Collections.singletonMap("d", 2)));
        }};
        assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));

        map = new LinkedHashMap<String, Object>() {{
            put("a", Collections.singletonMap("b",
                Collections.singletonMap("not_c",
                    Collections.singletonMap("d", 2))));
            put("a.b", Collections.singletonMap("not_c", Collections.singletonMap("d", 2)));
        }};
        assertThat(MapHelper.dig(path, map), is(nullValue()));
    }

}
@@ -35,6 +35,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

import java.util.Arrays;
@@ -128,7 +129,7 @@ public class InferenceProcessor extends AbstractProcessor {
        Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
        if (fieldMapping != null) {
            fieldMapping.forEach((src, dest) -> {
-                Object srcValue = fields.remove(src);
+                Object srcValue = MapHelper.dig(src, fields);
                if (srcValue != null) {
                    fields.put(dest, srcValue);
                }
@@ -6,7 +6,6 @@
package org.elasticsearch.xpack.ml.inference.loadingservice;

import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@@ -16,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;

import java.util.HashSet;
import java.util.Map;
@@ -61,7 +61,7 @@ public class LocalModel implements Model {
    @Override
    public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
        try {
-            if (Sets.haveEmptyIntersection(fieldNames, fields.keySet())) {
+            if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
                listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
                return;
            }

@@ -181,7 +181,7 @@ public class InferenceProcessorTests extends ESTestCase {
        String modelId = "model";
        Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);

-        Map<String, String> fieldMapping = new HashMap<String, String>(3) {{
+        Map<String, String> fieldMapping = new HashMap<String, String>(5) {{
            put("value1", "new_value1");
            put("value2", "new_value2");
            put("categorical", "new_categorical");
@@ -195,7 +195,7 @@ public class InferenceProcessorTests extends ESTestCase {
            new ClassificationConfig(topNClasses, null, null),
            fieldMapping);

-        Map<String, Object> source = new HashMap<String, Object>(3){{
+        Map<String, Object> source = new HashMap<String, Object>(5){{
            put("value1", 1);
            put("categorical", "foo");
            put("un_touched", "bar");
@@ -203,8 +203,46 @@ public class InferenceProcessorTests extends ESTestCase {
        Map<String, Object> ingestMetadata = new HashMap<>();
        IngestDocument document = new IngestDocument(source, ingestMetadata);

-        Map<String, Object> expectedMap = new HashMap<String, Object>(2) {{
+        Map<String, Object> expectedMap = new HashMap<String, Object>(7) {{
            put("new_value1", 1);
+            put("value1", 1);
+            put("categorical", "foo");
            put("new_categorical", "foo");
            put("un_touched", "bar");
        }};
        assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
    }

+    public void testGenerateWithMappingNestedFields() {
+        String modelId = "model";
+        Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);
+
+        Map<String, String> fieldMapping = new HashMap<String, String>(5) {{
+            put("value1.foo", "new_value1");
+            put("value2", "new_value2");
+            put("categorical.bar", "new_categorical");
+        }};
+
+        InferenceProcessor processor = new InferenceProcessor(client,
+            auditor,
+            "my_processor",
+            "my_field",
+            modelId,
+            new ClassificationConfig(topNClasses, null, null),
+            fieldMapping);
+
+        Map<String, Object> source = new HashMap<String, Object>(5){{
+            put("value1", Collections.singletonMap("foo", 1));
+            put("categorical.bar", "foo");
+            put("un_touched", "bar");
+        }};
+        Map<String, Object> ingestMetadata = new HashMap<>();
+        IngestDocument document = new IngestDocument(source, ingestMetadata);
+
+        Map<String, Object> expectedMap = new HashMap<String, Object>(7) {{
+            put("new_value1", 1);
+            put("value1", Collections.singletonMap("foo", 1));
+            put("categorical.bar", "foo");
+            put("new_categorical", "foo");
+            put("un_touched", "bar");
+        }};

@@ -41,7 +41,7 @@ public class LocalModelTests extends ESTestCase {

    public void testClassificationInfer() throws Exception {
        String modelId = "classification_model";
-        List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
+        List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
        TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
            .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
            .setTrainedModel(buildClassification(false))
@@ -49,8 +49,8 @@ public class LocalModelTests extends ESTestCase {

        Model model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields));
        Map<String, Object> fields = new HashMap<String, Object>() {{
-            put("foo", 1.0);
-            put("bar", 0.5);
+            put("field.foo", 1.0);
+            put("field.bar", 0.5);
            put("categorical", "dog");
        }};

@@ -93,8 +93,8 @@ public class LocalModelTests extends ESTestCase {
        Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields));

        Map<String, Object> fields = new HashMap<String, Object>() {{
-            put("foo", 1.0);
-            put("bar", 0.5);
+            put("field.foo", 1.0);
+            put("field.bar", 0.5);
            put("categorical", "dog");
        }};

@@ -147,7 +147,7 @@ public class LocalModelTests extends ESTestCase {
    }

    public static TrainedModel buildClassification(boolean includeLabels) {
-        List<String> featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog");
+        List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
@@ -193,7 +193,7 @@ public class LocalModelTests extends ESTestCase {
    }

    public static TrainedModel buildRegression() {
-        List<String> featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog");
+        List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)

@@ -66,9 +66,9 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
        oneHotEncoding.put("cat", "animal_cat");
        oneHotEncoding.put("dog", "animal_dog");
        TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2)
-            .setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical")))
+            .setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
            .setParsedDefinition(new TrainedModelDefinition.Builder()
-                .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
+                .setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
                .setTrainedModel(buildClassification(true)))
            .setVersion(Version.CURRENT)
            .setLicenseLevel(License.OperationMode.PLATINUM.description())
@@ -77,9 +77,9 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
            .setEstimatedHeapMemory(0)
            .build();
        TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1)
-            .setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical")))
+            .setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
            .setParsedDefinition(new TrainedModelDefinition.Builder()
-                .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
+                .setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
                .setTrainedModel(buildRegression()))
            .setVersion(Version.CURRENT)
            .setEstimatedOperations(0)
@@ -99,26 +99,42 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {

        List<Map<String, Object>> toInfer = new ArrayList<>();
        toInfer.add(new HashMap<String, Object>() {{
-            put("foo", 1.0);
-            put("bar", 0.5);
-            put("categorical", "dog");
+            put("field", new HashMap<String, Object>(){{
+                put("foo", 1.0);
+                put("bar", 0.5);
+            }});
+            put("other", new HashMap<String, String>(){{
+                put("categorical", "dog");
+            }});
        }});
        toInfer.add(new HashMap<String, Object>() {{
-            put("foo", 0.9);
-            put("bar", 1.5);
-            put("categorical", "cat");
+            put("field", new HashMap<String, Object>(){{
+                put("foo", 0.9);
+                put("bar", 1.5);
+            }});
+            put("other", new HashMap<String, String>(){{
+                put("categorical", "cat");
+            }});
        }});

        List<Map<String, Object>> toInfer2 = new ArrayList<>();
        toInfer2.add(new HashMap<String, Object>() {{
-            put("foo", 0.0);
-            put("bar", 0.01);
-            put("categorical", "dog");
+            put("field", new HashMap<String, Object>(){{
+                put("foo", 0.0);
+                put("bar", 0.01);
+            }});
+            put("other", new HashMap<String, String>(){{
+                put("categorical", "dog");
+            }});
        }});
        toInfer2.add(new HashMap<String, Object>() {{
-            put("foo", 1.0);
-            put("bar", 0.0);
-            put("categorical", "cat");
+            put("field", new HashMap<String, Object>(){{
+                put("foo", 1.0);
+                put("bar", 0.0);
+            }});
+            put("other", new HashMap<String, String>(){{
+                put("categorical", "cat");
+            }});
        }});

        // Test regression