[ML] fix custom feature processor extraction bugs around boolean fields and custom one_hot feature output order (#64937) (#65009)

This commit fixes two problems:

- When extracting a doc value, boolean scalars are now accepted as input (previously only string and number values were accepted).
- The output order of processed feature names is now deterministic. Previously, the output order of custom one-hot fields was non-deterministic, which could cause hard-to-diagnose bugs.
This commit is contained in:
Benjamin Trent 2020-11-12 11:15:57 -05:00 committed by GitHub
parent e40d7e02ea
commit b888f36388
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 130 additions and 25 deletions

View File

@ -134,6 +134,13 @@ public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProc
if (length > MAX_LENGTH) { if (length > MAX_LENGTH) {
throw ExceptionsHelper.badRequestException("[{}] must be not be greater than [{}]", LENGTH.getPreferredName(), MAX_LENGTH); throw ExceptionsHelper.badRequestException("[{}] must be not be greater than [{}]", LENGTH.getPreferredName(), MAX_LENGTH);
} }
if (Arrays.stream(this.nGrams).anyMatch(i -> i > length)) {
throw ExceptionsHelper.badRequestException(
"[{}] and [{}] are invalid; all ngrams must be shorter than or equal to length [{}]",
NGRAMS.getPreferredName(),
LENGTH.getPreferredName(),
length);
}
this.custom = custom; this.custom = custom;
} }
@ -293,6 +300,9 @@ public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProc
for (int nGram : nGrams) { for (int nGram : nGrams) {
totalNgrams += (length - (nGram - 1)); totalNgrams += (length - (nGram - 1));
} }
if (totalNgrams <= 0) {
return Collections.emptyList();
}
List<String> ngramOutputs = new ArrayList<>(totalNgrams); List<String> ngramOutputs = new ArrayList<>(totalNgrams);
for (int nGram : nGrams) { for (int nGram : nGrams) {

View File

@ -23,6 +23,7 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.TreeMap;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -68,13 +69,13 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
public OneHotEncoding(String field, Map<String, String> hotMap, Boolean custom) { public OneHotEncoding(String field, Map<String, String> hotMap, Boolean custom) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD); this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.hotMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP)); this.hotMap = Collections.unmodifiableMap(new TreeMap<>(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP)));
this.custom = custom == null ? false : custom; this.custom = custom != null && custom;
} }
public OneHotEncoding(StreamInput in) throws IOException { public OneHotEncoding(StreamInput in) throws IOException {
this.field = in.readString(); this.field = in.readString();
this.hotMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)); this.hotMap = Collections.unmodifiableMap(new TreeMap<>(in.readMap(StreamInput::readString, StreamInput::readString)));
if (in.getVersion().onOrAfter(Version.V_7_10_0)) { if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.custom = in.readBoolean(); this.custom = in.readBoolean();
} else { } else {

View File

@ -42,11 +42,12 @@ public class NGramTests extends PreProcessingTests<NGram> {
} }
public static NGram createRandom(Boolean isCustom) { public static NGram createRandom(Boolean isCustom) {
int possibleLength = randomIntBetween(1, 10);
return new NGram( return new NGram(
randomAlphaOfLength(10), randomAlphaOfLength(10),
IntStream.generate(() -> randomIntBetween(1, 5)).limit(5).boxed().collect(Collectors.toList()), IntStream.generate(() -> randomIntBetween(1, Math.min(possibleLength, 5))).limit(5).boxed().collect(Collectors.toList()),
randomBoolean() ? null : randomIntBetween(0, 10), randomBoolean() ? null : randomIntBetween(0, 10),
randomBoolean() ? null : randomIntBetween(1, 10), randomBoolean() ? null : possibleLength,
isCustom, isCustom,
randomBoolean() ? null : randomAlphaOfLength(10)); randomBoolean() ? null : randomAlphaOfLength(10));
} }

View File

@ -9,10 +9,12 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
import org.hamcrest.Matcher; import org.hamcrest.Matcher;
import org.junit.Before; import org.junit.Before;
import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Predicate; import java.util.function.Predicate;
import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public abstract class PreProcessingTests<T extends PreProcessor> extends AbstractSerializingTestCase<T> { public abstract class PreProcessingTests<T extends PreProcessor> extends AbstractSerializingTestCase<T> {
@ -41,6 +43,22 @@ public abstract class PreProcessingTests<T extends PreProcessor> extends Abstrac
); );
} }
/**
 * Runs the XContent tester for this pre-processor type, comparing the two
 * instances it produces only by their input and output field lists.
 * Whole-object equality is replaced by {@code assertFieldConsistency}, and
 * toXContent equivalence checking is switched off.
 *
 * NOTE(review): the two compared instances are presumably the generated test
 * instance and the one re-parsed from its XContent — confirm against
 * xContentTester.
 *
 * @throws IOException propagated from the XContent tester
 */
public void testInputOutputFieldOrderConsistency() throws IOException {
xContentTester(this::createParser, this::createXContextTestInstance, getToXContentParams(), this::doParseInstance)
.numberOfTestRuns(NUMBER_OF_TEST_RUNS)
.supportsUnknownFields(supportsUnknownFields())
.shuffleFieldsExceptions(getShuffleFieldsExceptions())
.randomFieldsExcludeFilter(getRandomFieldsExcludeFilter())
// Compare only inputFields()/outputFields() instead of full equality.
.assertEqualsConsumer(this::assertFieldConsistency)
.assertToXContentEquivalence(false)
.test();
}
/**
 * Asserts that two pre-processors report equal input fields and equal output fields.
 * NOTE(review): if inputFields()/outputFields() return Lists, equalTo also
 * compares element order — confirm the return types.
 *
 * @param lft first pre-processor to compare
 * @param rgt second pre-processor to compare
 */
private void assertFieldConsistency(T lft, T rgt) {
assertThat(lft.inputFields(), equalTo(rgt.inputFields()));
assertThat(lft.outputFields(), equalTo(rgt.outputFields()));
}
public void testWithMissingField() { public void testWithMissingField() {
Map<String, Object> fields = randomFieldValues(); Map<String, Object> fields = randomFieldValues();
PreProcessor preProcessor = this.createTestInstance(); PreProcessor preProcessor = this.createTestInstance();

View File

@ -352,7 +352,7 @@ public class DataFrameDataExtractor {
return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis); return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis);
} }
private static boolean isValidValue(Object value) { public static boolean isValidValue(Object value) {
// We should allow a number, string or a boolean. // We should allow a number, string or a boolean.
// It is possible for a field to be categorical and have a `keyword` mapping, but be any of these // It is possible for a field to be categorical and have a `keyword` mapping, but be any of these
// three types, in the same index. // three types, in the same index.

View File

@ -16,6 +16,8 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.function.Function; import java.util.function.Function;
import static org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor.isValidValue;
public class ProcessedField { public class ProcessedField {
private final PreProcessor preProcessor; private final PreProcessor preProcessor;
@ -36,8 +38,9 @@ public class ProcessedField {
} }
public Object[] value(SearchHit hit, Function<String, ExtractedField> fieldExtractor) { public Object[] value(SearchHit hit, Function<String, ExtractedField> fieldExtractor) {
Map<String, Object> inputs = new HashMap<>(preProcessor.inputFields().size(), 1.0f); List<String> inputFields = getInputFieldNames();
for (String field : preProcessor.inputFields()) { Map<String, Object> inputs = new HashMap<>(inputFields.size(), 1.0f);
for (String field : inputFields) {
ExtractedField extractedField = fieldExtractor.apply(field); ExtractedField extractedField = fieldExtractor.apply(field);
if (extractedField == null) { if (extractedField == null) {
return new Object[0]; return new Object[0];
@ -47,7 +50,7 @@ public class ProcessedField {
continue; continue;
} }
final Object value = values[0]; final Object value = values[0];
if (values.length == 1 && (value instanceof String || value instanceof Number)) { if (values.length == 1 && (isValidValue(value))) {
inputs.put(field, value); inputs.put(field, value);
} }
} }

View File

@ -5,16 +5,20 @@
*/ */
package org.elasticsearch.xpack.ml.extractor; package org.elasticsearch.xpack.ml.extractor;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.function.Function; import java.util.LinkedHashMap;
import java.util.stream.Collectors; import java.util.Map;
import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.emptyArray; import static org.hamcrest.Matchers.emptyArray;
@ -30,7 +34,7 @@ public class ProcessedFieldTests extends ESTestCase {
public void testOneHotGetters() { public void testOneHotGetters() {
String inputField = "foo"; String inputField = "foo";
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(inputField, "bar", "baz"));
assertThat(processedField.getInputFieldNames(), hasItems(inputField)); assertThat(processedField.getInputFieldNames(), hasItems(inputField));
assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column")); assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column"));
assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer"))); assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer")));
@ -39,28 +43,92 @@ public class ProcessedFieldTests extends ESTestCase {
} }
public void testMissingExtractor() { public void testMissingExtractor() {
String inputField = "foo"; ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "baz"));
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> null), emptyArray()); assertThat(processedField.value(makeHit(), (s) -> null), emptyArray());
} }
public void testMissingInputValues() { public void testMissingInputValues() {
String inputField = "foo";
ExtractedField extractedField = makeExtractedField(new Object[0]); ExtractedField extractedField = makeExtractedField(new Object[0]);
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz")); ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue()))); assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue())));
} }
public void testProcessedField() { public void testProcessedFieldFrequencyEncoding() {
ProcessedField processedField = new ProcessedField(makePreProcessor("foo", "bar", "baz")); testProcessedField(
assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" })), arrayContaining(1, 0)); new FrequencyEncoding(randomAlphaOfLength(10),
assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "baz" })), arrayContaining(0, 1)); randomAlphaOfLength(10),
MapBuilder.<String, Double>newMapBuilder().put("bar", 1.0).put("1", 0.5).put("false", 0.0).map(),
randomBoolean()),
new Object[]{"bar", 1, false},
new Object[][]{
new Object[]{1.0},
new Object[]{0.5},
new Object[]{0.0},
});
} }
private static PreProcessor makePreProcessor(String inputField, String... expectedExtractedValues) { public void testProcessedFieldTargetMeanEncoding() {
return new OneHotEncoding(inputField, testProcessedField(
Arrays.stream(expectedExtractedValues).collect(Collectors.toMap(Function.identity(), (s) -> s + "_column")), new TargetMeanEncoding(randomAlphaOfLength(10),
true); randomAlphaOfLength(10),
MapBuilder.<String, Double>newMapBuilder().put("bar", 1.0).put("1", 0.5).put("false", 0.0).map(),
0.8,
randomBoolean()),
new Object[]{"bar", 1, false, "unknown"},
new Object[][]{
new Object[]{1.0},
new Object[]{0.5},
new Object[]{0.0},
new Object[]{0.8},
});
}
/**
 * Feeds a string, an integer and a boolean input through an NGram
 * pre-processor and checks the extracted values for each.
 * Expected results per the data below: "bar" gives "b", "a", "r"; 1 gives "1"
 * followed by two nulls; false gives "f", "a", "l".
 *
 * NOTE(review): the positional arguments look like nGrams={1}, start=0,
 * length=3, custom=random — confirm against the NGram constructor.
 */
public void testProcessedFieldNGramEncoding() {
testProcessedField(
new NGram(randomAlphaOfLength(10),
randomAlphaOfLength(10),
new int[]{1},
0,
3,
randomBoolean()),
new Object[]{"bar", 1, false},
new Object[][]{
new Object[]{"b", "a", "r"},
// Input 1 yields a single character; the remaining two slots are null.
new Object[]{"1", null, null},
new Object[]{"f", "a", "l"}
});
}
/**
 * Feeds a string, an integer and a boolean input through a one-hot
 * pre-processor whose hot map has the keys "bar", "1" and "false".
 * The expected rows put the hot column for "1" first, "bar" second and
 * "false" third, i.e. sorted key order rather than the insertion order used
 * by the helper.
 * NOTE(review): this presumably reflects OneHotEncoding keeping its hot map
 * in a sorted map — confirm.
 */
public void testProcessedFieldOneHot() {
testProcessedField(
makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "1", "false"),
new Object[]{"bar", 1, false},
new Object[][]{
// Column order is "1", "bar", "false".
new Object[]{0, 1, 0},
new Object[]{1, 0, 0},
new Object[]{0, 0, 1},
});
}
/**
 * Shared driver for the per-encoder tests: wraps the pre-processor in a
 * ProcessedField and checks, for each input, that the extracted values equal
 * the matching row of expected outputs.
 *
 * @param preProcessor the pre-processor under test
 * @param inputs one raw field value per case
 * @param expectedOutputs expected result array for each input, same length and order as {@code inputs}
 */
public void testProcessedField(PreProcessor preProcessor, Object[] inputs, Object[][] expectedOutputs) {
ProcessedField processedField = new ProcessedField(preProcessor);
// Sanity check on the test data itself; a plain Java assert, so only active with -ea.
assert inputs.length == expectedOutputs.length;
for (int i = 0; i < inputs.length; i++) {
Object input = inputs[i];
// The extractor ignores the requested field name and always returns a single-value field holding this input.
Object[] result = processedField.value(makeHit(input), (s) -> makeExtractedField(new Object[] { input }));
assertThat(
"Input [" + input + "] Expected " + Arrays.toString(expectedOutputs[i]) + " but received " + Arrays.toString(result),
result,
equalTo(expectedOutputs[i]));
}
}
private static PreProcessor makeOneHotPreProcessor(String inputField, String... expectedExtractedValues) {
Map<String, String> map = new LinkedHashMap<>();
for (String v : expectedExtractedValues) {
map.put(v, v + "_column");
}
return new OneHotEncoding(inputField, map,true);
} }
private static ExtractedField makeExtractedField(Object[] value) { private static ExtractedField makeExtractedField(Object[] value) {
@ -70,7 +138,11 @@ public class ProcessedFieldTests extends ESTestCase {
} }
private static SearchHit makeHit() { private static SearchHit makeHit() {
return new SearchHitBuilder(42).addField("a_keyword", "bar").build(); return makeHit("bar");
}
private static SearchHit makeHit(Object value) {
return new SearchHitBuilder(42).addField("a_keyword", value).build();
} }
} }