From b888f363889f53eb92a540316e21a5afdcf620e4 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 12 Nov 2020 11:15:57 -0500 Subject: [PATCH] [ML] fix custom feature processor extraction bugs around boolean fields and custom one_hot feature output order (#64937) (#65009) This commit fixes two problems: - When extracting a doc value, we now allow boolean scalars to be used as input - The output order of processed feature names is now deterministic. Previously, the output order of custom one hot fields was non-deterministic and thus could cause weird bugs. --- .../ml/inference/preprocessing/NGram.java | 10 ++ .../preprocessing/OneHotEncoding.java | 7 +- .../inference/preprocessing/NGramTests.java | 5 +- .../preprocessing/PreProcessingTests.java | 18 +++ .../extractor/DataFrameDataExtractor.java | 2 +- .../xpack/ml/extractor/ProcessedField.java | 9 +- .../ml/extractor/ProcessedFieldTests.java | 104 +++++++++++++++--- 7 files changed, 130 insertions(+), 25 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java index 5ec1d16e2a1..229dd092e3c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java @@ -134,6 +134,13 @@ public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProc if (length > MAX_LENGTH) { throw ExceptionsHelper.badRequestException("[{}] must be not be greater than [{}]", LENGTH.getPreferredName(), MAX_LENGTH); } + if (Arrays.stream(this.nGrams).anyMatch(i -> i > length)) { + throw ExceptionsHelper.badRequestException( + "[{}] and [{}] are invalid; all ngrams must be shorter than or equal to length [{}]", + NGRAMS.getPreferredName(), + LENGTH.getPreferredName(), + length); + } this.custom = custom; } @@ -293,6 +300,9 @@ 
public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProc for (int nGram : nGrams) { totalNgrams += (length - (nGram - 1)); } + if (totalNgrams <= 0) { + return Collections.emptyList(); + } List ngramOutputs = new ArrayList<>(totalNgrams); for (int nGram : nGrams) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index ade6a659f88..3103291d8c2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.TreeMap; import java.util.function.Function; import java.util.stream.Collectors; @@ -68,13 +69,13 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars public OneHotEncoding(String field, Map hotMap, Boolean custom) { this.field = ExceptionsHelper.requireNonNull(field, FIELD); - this.hotMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP)); - this.custom = custom == null ? 
false : custom; + this.hotMap = Collections.unmodifiableMap(new TreeMap<>(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP))); + this.custom = custom != null && custom; } public OneHotEncoding(StreamInput in) throws IOException { this.field = in.readString(); - this.hotMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)); + this.hotMap = Collections.unmodifiableMap(new TreeMap<>(in.readMap(StreamInput::readString, StreamInput::readString))); if (in.getVersion().onOrAfter(Version.V_7_10_0)) { this.custom = in.readBoolean(); } else { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGramTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGramTests.java index 9fe21a4fa5c..fb141fdc8f7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGramTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGramTests.java @@ -42,11 +42,12 @@ public class NGramTests extends PreProcessingTests { } public static NGram createRandom(Boolean isCustom) { + int possibleLength = randomIntBetween(1, 10); return new NGram( randomAlphaOfLength(10), - IntStream.generate(() -> randomIntBetween(1, 5)).limit(5).boxed().collect(Collectors.toList()), + IntStream.generate(() -> randomIntBetween(1, Math.min(possibleLength, 5))).limit(5).boxed().collect(Collectors.toList()), randomBoolean() ? null : randomIntBetween(0, 10), - randomBoolean() ? null : randomIntBetween(1, 10), + randomBoolean() ? null : possibleLength, isCustom, randomBoolean() ? 
null : randomAlphaOfLength(10)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessingTests.java index c4e8b879bcd..269caa26207 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessingTests.java @@ -9,10 +9,12 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.hamcrest.Matcher; import org.junit.Before; +import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.function.Predicate; +import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester; import static org.hamcrest.Matchers.equalTo; public abstract class PreProcessingTests extends AbstractSerializingTestCase { @@ -41,6 +43,22 @@ public abstract class PreProcessingTests extends Abstrac ); } + public void testInputOutputFieldOrderConsistency() throws IOException { + xContentTester(this::createParser, this::createXContextTestInstance, getToXContentParams(), this::doParseInstance) + .numberOfTestRuns(NUMBER_OF_TEST_RUNS) + .supportsUnknownFields(supportsUnknownFields()) + .shuffleFieldsExceptions(getShuffleFieldsExceptions()) + .randomFieldsExcludeFilter(getRandomFieldsExcludeFilter()) + .assertEqualsConsumer(this::assertFieldConsistency) + .assertToXContentEquivalence(false) + .test(); + } + + private void assertFieldConsistency(T lft, T rgt) { + assertThat(lft.inputFields(), equalTo(rgt.inputFields())); + assertThat(lft.outputFields(), equalTo(rgt.outputFields())); + } + public void testWithMissingField() { Map fields = randomFieldValues(); PreProcessor preProcessor = this.createTestInstance(); diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index defd071bfc9..336df670969 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -352,7 +352,7 @@ public class DataFrameDataExtractor { return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis); } - private static boolean isValidValue(Object value) { + public static boolean isValidValue(Object value) { // We should allow a number, string or a boolean. // It is possible for a field to be categorical and have a `keyword` mapping, but be any of these // three types, in the same index. diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ProcessedField.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ProcessedField.java index 50f13f94086..3102a3e14b5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ProcessedField.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ProcessedField.java @@ -16,6 +16,8 @@ import java.util.Objects; import java.util.Set; import java.util.function.Function; +import static org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor.isValidValue; + public class ProcessedField { private final PreProcessor preProcessor; @@ -36,8 +38,9 @@ public class ProcessedField { } public Object[] value(SearchHit hit, Function fieldExtractor) { - Map inputs = new HashMap<>(preProcessor.inputFields().size(), 1.0f); - for (String field : preProcessor.inputFields()) { + List inputFields = getInputFieldNames(); + Map inputs = new HashMap<>(inputFields.size(), 1.0f); + for (String field : inputFields) { 
ExtractedField extractedField = fieldExtractor.apply(field); if (extractedField == null) { return new Object[0]; @@ -47,7 +50,7 @@ public class ProcessedField { continue; } final Object value = values[0]; - if (values.length == 1 && (value instanceof String || value instanceof Number)) { + if (values.length == 1 && (isValidValue(value))) { inputs.put(field, value); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ProcessedFieldTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ProcessedFieldTests.java index 48604833f08..98ce397626e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ProcessedFieldTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ProcessedFieldTests.java @@ -5,16 +5,20 @@ */ package org.elasticsearch.xpack.ml.extractor; +import org.elasticsearch.common.collect.MapBuilder; import org.elasticsearch.search.SearchHit; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import java.util.Arrays; import java.util.Collections; -import java.util.function.Function; -import java.util.stream.Collectors; +import java.util.LinkedHashMap; +import java.util.Map; import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.emptyArray; @@ -30,7 +34,7 @@ public class ProcessedFieldTests extends ESTestCase { public void testOneHotGetters() { String inputField = "foo"; - ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); + ProcessedField 
processedField = new ProcessedField(makeOneHotPreProcessor(inputField, "bar", "baz")); assertThat(processedField.getInputFieldNames(), hasItems(inputField)); assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column")); assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer"))); @@ -39,28 +43,92 @@ public class ProcessedFieldTests extends ESTestCase { } public void testMissingExtractor() { - String inputField = "foo"; - ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); + ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "baz")); assertThat(processedField.value(makeHit(), (s) -> null), emptyArray()); } public void testMissingInputValues() { - String inputField = "foo"; ExtractedField extractedField = makeExtractedField(new Object[0]); - ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); + ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "baz")); assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue()))); } - public void testProcessedField() { - ProcessedField processedField = new ProcessedField(makePreProcessor("foo", "bar", "baz")); - assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" })), arrayContaining(1, 0)); - assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "baz" })), arrayContaining(0, 1)); + public void testProcessedFieldFrequencyEncoding() { + testProcessedField( + new FrequencyEncoding(randomAlphaOfLength(10), + randomAlphaOfLength(10), + MapBuilder.newMapBuilder().put("bar", 1.0).put("1", 0.5).put("false", 0.0).map(), + randomBoolean()), + new Object[]{"bar", 1, false}, + new Object[][]{ + new Object[]{1.0}, + new Object[]{0.5}, + new Object[]{0.0}, + }); 
} - private static PreProcessor makePreProcessor(String inputField, String... expectedExtractedValues) { - return new OneHotEncoding(inputField, - Arrays.stream(expectedExtractedValues).collect(Collectors.toMap(Function.identity(), (s) -> s + "_column")), - true); + public void testProcessedFieldTargetMeanEncoding() { + testProcessedField( + new TargetMeanEncoding(randomAlphaOfLength(10), + randomAlphaOfLength(10), + MapBuilder.newMapBuilder().put("bar", 1.0).put("1", 0.5).put("false", 0.0).map(), + 0.8, + randomBoolean()), + new Object[]{"bar", 1, false, "unknown"}, + new Object[][]{ + new Object[]{1.0}, + new Object[]{0.5}, + new Object[]{0.0}, + new Object[]{0.8}, + }); + } + + public void testProcessedFieldNGramEncoding() { + testProcessedField( + new NGram(randomAlphaOfLength(10), + randomAlphaOfLength(10), + new int[]{1}, + 0, + 3, + randomBoolean()), + new Object[]{"bar", 1, false}, + new Object[][]{ + new Object[]{"b", "a", "r"}, + new Object[]{"1", null, null}, + new Object[]{"f", "a", "l"} + }); + } + + public void testProcessedFieldOneHot() { + testProcessedField( + makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "1", "false"), + new Object[]{"bar", 1, false}, + new Object[][]{ + new Object[]{0, 1, 0}, + new Object[]{1, 0, 0}, + new Object[]{0, 0, 1}, + }); + } + + public void testProcessedField(PreProcessor preProcessor, Object[] inputs, Object[][] expectedOutputs) { + ProcessedField processedField = new ProcessedField(preProcessor); + assert inputs.length == expectedOutputs.length; + for (int i = 0; i < inputs.length; i++) { + Object input = inputs[i]; + Object[] result = processedField.value(makeHit(input), (s) -> makeExtractedField(new Object[] { input })); + assertThat( + "Input [" + input + "] Expected " + Arrays.toString(expectedOutputs[i]) + " but received " + Arrays.toString(result), + result, + equalTo(expectedOutputs[i])); + } + } + + private static PreProcessor makeOneHotPreProcessor(String inputField, String... 
expectedExtractedValues) { + Map map = new LinkedHashMap<>(); + for (String v : expectedExtractedValues) { + map.put(v, v + "_column"); + } + return new OneHotEncoding(inputField, map,true); } private static ExtractedField makeExtractedField(Object[] value) { @@ -70,7 +138,11 @@ public class ProcessedFieldTests extends ESTestCase { } private static SearchHit makeHit() { - return new SearchHitBuilder(42).addField("a_keyword", "bar").build(); + return makeHit("bar"); + } + + private static SearchHit makeHit(Object value) { + return new SearchHitBuilder(42).addField("a_keyword", value).build(); } }