[ML] fix custom feature processor extraction bugs around boolean fields and custom one_hot feature output order (#64937) (#65009)

This commit fixes two problems:

- When extracting a doc value, boolean scalars are now accepted as input.
- The output order of processed feature names is now deterministic. Previously, custom one_hot output fields were ordered non-deterministically, which could cause subtle bugs (see the sketch below).
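To make the second fix concrete, here is a minimal standalone sketch (my own illustration with made-up names, not code from this commit) of why deriving output feature names from a HashMap is order-unstable, while a key-sorted TreeMap view is deterministic:

import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

public class OneHotOrderSketch {
    public static void main(String[] args) {
        // Hypothetical hot map: input value -> output feature column name.
        Map<String, String> hotMap = new HashMap<>();
        hotMap.put("cat", "animal_cat_column");
        hotMap.put("dog", "animal_dog_column");
        hotMap.put("fish", "animal_fish_column");

        // HashMap iteration order is not part of its contract, so columns derived
        // by iterating it may come out in different orders across JVMs or versions.
        System.out.println("hash-map order:   " + hotMap.values());

        // Copying into a TreeMap sorts entries by key, so the derived order is stable.
        System.out.println("key-sorted order: " + new TreeMap<>(hotMap).values());
    }
}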
Benjamin Trent 2020-11-12 11:15:57 -05:00 committed by GitHub
parent e40d7e02ea
commit b888f36388
7 changed files with 130 additions and 25 deletions

View File

@@ -134,6 +134,13 @@ public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
if (length > MAX_LENGTH) {
throw ExceptionsHelper.badRequestException("[{}] must not be greater than [{}]", LENGTH.getPreferredName(), MAX_LENGTH);
}
if (Arrays.stream(this.nGrams).anyMatch(i -> i > length)) {
throw ExceptionsHelper.badRequestException(
"[{}] and [{}] are invalid; all ngrams must be shorter than or equal to length [{}]",
NGRAMS.getPreferredName(),
LENGTH.getPreferredName(),
length);
}
this.custom = custom;
}
@@ -293,6 +300,9 @@ public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
for (int nGram : nGrams) {
totalNgrams += (length - (nGram - 1));
}
if (totalNgrams <= 0) {
return Collections.emptyList();
}
List<String> ngramOutputs = new ArrayList<>(totalNgrams);
for (int nGram : nGrams) {

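For context on the two guards added above: the number of n-gram output features is the sum of (length - (nGram - 1)) over the configured n-gram sizes, so any n-gram longer than the character window contributes a non-positive term. A rough arithmetic sketch (assumed helper names, not the NGram class itself):

import java.util.Arrays;

public class NGramCountSketch {
    // Feature count for a character window of `length` and the given n-gram sizes.
    static int totalNgrams(int length, int[] nGrams) {
        return Arrays.stream(nGrams).map(n -> length - (n - 1)).sum();
    }

    public static void main(String[] args) {
        System.out.println(totalNgrams(3, new int[]{1, 2})); // 5: three unigrams plus two bigrams
        System.out.println(totalNgrams(2, new int[]{3}));    // 0: no trigram fits, so no output fields
        System.out.println(totalNgrams(1, new int[]{3}));    // -1: without the new checks this could
                                                             // reach new ArrayList<>() with a negative size
    }
}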
View File

@@ -23,6 +23,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -68,13 +69,13 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
public OneHotEncoding(String field, Map<String, String> hotMap, Boolean custom) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.hotMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP));
this.custom = custom == null ? false : custom;
this.hotMap = Collections.unmodifiableMap(new TreeMap<>(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP)));
this.custom = custom != null && custom;
}
public OneHotEncoding(StreamInput in) throws IOException {
this.field = in.readString();
this.hotMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString));
this.hotMap = Collections.unmodifiableMap(new TreeMap<>(in.readMap(StreamInput::readString, StreamInput::readString)));
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.custom = in.readBoolean();
} else {

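A small illustration of the TreeMap wrapping used in both constructors above (again my own sketch, not part of the diff): once the hot map is copied into a sorted map, its iteration order depends only on the keys, not on how the caller or the wire format happened to order the entries:

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.TreeMap;

public class HotMapNormalizeSketch {
    // Same pattern as the diff: defensive copy into a sorted map, exposed read-only.
    static Map<String, String> normalize(Map<String, String> hotMap) {
        return Collections.unmodifiableMap(new TreeMap<>(hotMap));
    }

    public static void main(String[] args) {
        Map<String, String> a = new LinkedHashMap<>();
        a.put("dog", "dog_column");
        a.put("cat", "cat_column");

        Map<String, String> b = new LinkedHashMap<>();
        b.put("cat", "cat_column");
        b.put("dog", "dog_column");

        // Both print [cat, dog]: insertion order no longer matters.
        System.out.println(normalize(a).keySet());
        System.out.println(normalize(b).keySet());
    }
}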
View File

@@ -42,11 +42,12 @@ public class NGramTests extends PreProcessingTests<NGram> {
}
public static NGram createRandom(Boolean isCustom) {
int possibleLength = randomIntBetween(1, 10);
return new NGram(
randomAlphaOfLength(10),
IntStream.generate(() -> randomIntBetween(1, 5)).limit(5).boxed().collect(Collectors.toList()),
IntStream.generate(() -> randomIntBetween(1, Math.min(possibleLength, 5))).limit(5).boxed().collect(Collectors.toList()),
randomBoolean() ? null : randomIntBetween(0, 10),
randomBoolean() ? null : randomIntBetween(1, 10),
randomBoolean() ? null : possibleLength,
isCustom,
randomBoolean() ? null : randomAlphaOfLength(10));
}

View File

@@ -9,10 +9,12 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
import org.hamcrest.Matcher;
import org.junit.Before;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;
import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
import static org.hamcrest.Matchers.equalTo;
public abstract class PreProcessingTests<T extends PreProcessor> extends AbstractSerializingTestCase<T> {
@@ -41,6 +43,22 @@ public abstract class PreProcessingTests<T extends PreProcessor> extends AbstractSerializingTestCase<T> {
);
}
public void testInputOutputFieldOrderConsistency() throws IOException {
xContentTester(this::createParser, this::createXContextTestInstance, getToXContentParams(), this::doParseInstance)
.numberOfTestRuns(NUMBER_OF_TEST_RUNS)
.supportsUnknownFields(supportsUnknownFields())
.shuffleFieldsExceptions(getShuffleFieldsExceptions())
.randomFieldsExcludeFilter(getRandomFieldsExcludeFilter())
.assertEqualsConsumer(this::assertFieldConsistency)
.assertToXContentEquivalence(false)
.test();
}
private void assertFieldConsistency(T lft, T rgt) {
assertThat(lft.inputFields(), equalTo(rgt.inputFields()));
assertThat(lft.outputFields(), equalTo(rgt.outputFields()));
}
public void testWithMissingField() {
Map<String, Object> fields = randomFieldValues();
PreProcessor preProcessor = this.createTestInstance();

View File

@@ -352,7 +352,7 @@ public class DataFrameDataExtractor {
return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis);
}
private static boolean isValidValue(Object value) {
public static boolean isValidValue(Object value) {
// We should allow a number, string or a boolean.
// It is possible for a field to be categorical and have a `keyword` mapping, but be any of these
// three types, in the same index.

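The body of isValidValue is not shown in this hunk; judging from the comment above, a minimal sketch of what such a check amounts to (an assumption, not the actual Elasticsearch implementation) would be:

public final class ValueCheckSketch {
    // Accept scalar numbers, strings, and booleans; reject anything else (including null).
    static boolean isValidValue(Object value) {
        return value instanceof Number || value instanceof String || value instanceof Boolean;
    }

    public static void main(String[] args) {
        System.out.println(isValidValue(42));     // true
        System.out.println(isValidValue("bar"));  // true
        System.out.println(isValidValue(false));  // true: booleans now also pass the ProcessedField
                                                  // filter, which previously took only strings and numbers
        System.out.println(isValidValue(null));   // false
    }
}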
View File

@@ -16,6 +16,8 @@ import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import static org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor.isValidValue;
public class ProcessedField {
private final PreProcessor preProcessor;
@@ -36,8 +38,9 @@ public class ProcessedField {
}
public Object[] value(SearchHit hit, Function<String, ExtractedField> fieldExtractor) {
Map<String, Object> inputs = new HashMap<>(preProcessor.inputFields().size(), 1.0f);
for (String field : preProcessor.inputFields()) {
List<String> inputFields = getInputFieldNames();
Map<String, Object> inputs = new HashMap<>(inputFields.size(), 1.0f);
for (String field : inputFields) {
ExtractedField extractedField = fieldExtractor.apply(field);
if (extractedField == null) {
return new Object[0];
@@ -47,7 +50,7 @@ public class ProcessedField {
continue;
}
final Object value = values[0];
if (values.length == 1 && (value instanceof String || value instanceof Number)) {
if (values.length == 1 && (isValidValue(value))) {
inputs.put(field, value);
}
}

View File

@@ -5,16 +5,20 @@
*/
package org.elasticsearch.xpack.ml.extractor;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import java.util.Arrays;
import java.util.Collections;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.LinkedHashMap;
import java.util.Map;
import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.emptyArray;
@@ -30,7 +34,7 @@ public class ProcessedFieldTests extends ESTestCase {
public void testOneHotGetters() {
String inputField = "foo";
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(inputField, "bar", "baz"));
assertThat(processedField.getInputFieldNames(), hasItems(inputField));
assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column"));
assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer")));
@@ -39,28 +43,92 @@ public class ProcessedFieldTests extends ESTestCase {
}
public void testMissingExtractor() {
String inputField = "foo";
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> null), emptyArray());
}
public void testMissingInputValues() {
String inputField = "foo";
ExtractedField extractedField = makeExtractedField(new Object[0]);
ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue())));
}
public void testProcessedField() {
ProcessedField processedField = new ProcessedField(makePreProcessor("foo", "bar", "baz"));
assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" })), arrayContaining(1, 0));
assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "baz" })), arrayContaining(0, 1));
public void testProcessedFieldFrequencyEncoding() {
testProcessedField(
new FrequencyEncoding(randomAlphaOfLength(10),
randomAlphaOfLength(10),
MapBuilder.<String, Double>newMapBuilder().put("bar", 1.0).put("1", 0.5).put("false", 0.0).map(),
randomBoolean()),
new Object[]{"bar", 1, false},
new Object[][]{
new Object[]{1.0},
new Object[]{0.5},
new Object[]{0.0},
});
}
private static PreProcessor makePreProcessor(String inputField, String... expectedExtractedValues) {
return new OneHotEncoding(inputField,
Arrays.stream(expectedExtractedValues).collect(Collectors.toMap(Function.identity(), (s) -> s + "_column")),
true);
public void testProcessedFieldTargetMeanEncoding() {
testProcessedField(
new TargetMeanEncoding(randomAlphaOfLength(10),
randomAlphaOfLength(10),
MapBuilder.<String, Double>newMapBuilder().put("bar", 1.0).put("1", 0.5).put("false", 0.0).map(),
0.8,
randomBoolean()),
new Object[]{"bar", 1, false, "unknown"},
new Object[][]{
new Object[]{1.0},
new Object[]{0.5},
new Object[]{0.0},
new Object[]{0.8},
});
}
public void testProcessedFieldNGramEncoding() {
testProcessedField(
new NGram(randomAlphaOfLength(10),
randomAlphaOfLength(10),
new int[]{1},
0,
3,
randomBoolean()),
new Object[]{"bar", 1, false},
new Object[][]{
new Object[]{"b", "a", "r"},
new Object[]{"1", null, null},
new Object[]{"f", "a", "l"}
});
}
public void testProcessedFieldOneHot() {
testProcessedField(
makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "1", "false"),
new Object[]{"bar", 1, false},
new Object[][]{
new Object[]{0, 1, 0},
new Object[]{1, 0, 0},
new Object[]{0, 0, 1},
});
}
public void testProcessedField(PreProcessor preProcessor, Object[] inputs, Object[][] expectedOutputs) {
ProcessedField processedField = new ProcessedField(preProcessor);
assert inputs.length == expectedOutputs.length;
for (int i = 0; i < inputs.length; i++) {
Object input = inputs[i];
Object[] result = processedField.value(makeHit(input), (s) -> makeExtractedField(new Object[] { input }));
assertThat(
"Input [" + input + "] Expected " + Arrays.toString(expectedOutputs[i]) + " but received " + Arrays.toString(result),
result,
equalTo(expectedOutputs[i]));
}
}
private static PreProcessor makeOneHotPreProcessor(String inputField, String... expectedExtractedValues) {
Map<String, String> map = new LinkedHashMap<>();
for (String v : expectedExtractedValues) {
map.put(v, v + "_column");
}
return new OneHotEncoding(inputField, map, true);
}
private static ExtractedField makeExtractedField(Object[] value) {
@@ -70,7 +138,11 @@ public class ProcessedFieldTests extends ESTestCase {
}
private static SearchHit makeHit() {
return new SearchHitBuilder(42).addField("a_keyword", "bar").build();
return makeHit("bar");
}
private static SearchHit makeHit(Object value) {
return new SearchHitBuilder(42).addField("a_keyword", value).build();
}
}