[ML] fix custom feature processor extraction bugs around boolean fields and custom one_hot feature output order (#64937) (#65009)
This commit fixes two problems:
- When extracting a doc value, boolean scalars are now allowed as input.
- The output order of processed feature names is now deterministic. Custom one_hot fields were previously emitted in a non-deterministic order, which could cause subtle bugs.
parent e40d7e02ea
commit b888f36388
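For context on the ordering fix, here is a minimal, self-contained sketch (plain JDK code, not the Elasticsearch implementation; class and field names are illustrative) of why the hot map's iteration order matters: the one_hot output columns are derived by iterating the value-to-feature-name map, so the map's iteration order decides the feature vector layout.

import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

public class OneHotOrderSketch {
    public static void main(String[] args) {
        Map<String, String> hotMap = new HashMap<>();
        hotMap.put("cat", "animal_cat_column");
        hotMap.put("dog", "animal_dog_column");
        hotMap.put("fish", "animal_fish_column");

        // HashMap iteration order is unspecified, so two runs (or two nodes)
        // could emit the same features in different positions.
        System.out.println("unspecified order: " + hotMap.values());

        // Copying into a TreeMap sorts the keys and pins the output order,
        // which is what the OneHotEncoding constructor change below does.
        System.out.println("deterministic:     " + new TreeMap<>(hotMap).values());
    }
}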
@@ -134,6 +134,13 @@ public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
         if (length > MAX_LENGTH) {
             throw ExceptionsHelper.badRequestException("[{}] must be not be greater than [{}]", LENGTH.getPreferredName(), MAX_LENGTH);
         }
+        if (Arrays.stream(this.nGrams).anyMatch(i -> i > length)) {
+            throw ExceptionsHelper.badRequestException(
+                "[{}] and [{}] are invalid; all ngrams must be shorter than or equal to length [{}]",
+                NGRAMS.getPreferredName(),
+                LENGTH.getPreferredName(),
+                length);
+        }
         this.custom = custom;
     }
 
@@ -293,6 +300,9 @@ public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
         for (int nGram : nGrams) {
             totalNgrams += (length - (nGram - 1));
         }
+        if (totalNgrams <= 0) {
+            return Collections.emptyList();
+        }
         List<String> ngramOutputs = new ArrayList<>(totalNgrams);
 
         for (int nGram : nGrams) {
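A quick worked example of the arithmetic guarded above (plain Java, not the NGram class; the class name is illustrative): each n-gram size n over a window of length characters yields length - n + 1 grams, and the sum drops to zero or below when an n-gram is longer than the window. That is what the new empty-list guard catches before new ArrayList<>(totalNgrams) can be handed a negative capacity.

public class NGramCountSketch {
    // Sum of sliding-window counts: size n over `length` chars
    // contributes (length - n + 1) grams.
    static int totalNgrams(int length, int[] nGrams) {
        int total = 0;
        for (int n : nGrams) {
            total += length - (n - 1);
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(totalNgrams(3, new int[] { 1, 2 })); // 3 + 2 = 5
        System.out.println(totalNgrams(1, new int[] { 2 }));    // 0: no 2-gram fits
        System.out.println(totalNgrams(1, new int[] { 3 }));    // -1: would crash an ArrayList allocation
    }
}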
@@ -23,6 +23,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.TreeMap;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -68,13 +69,13 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
 
     public OneHotEncoding(String field, Map<String, String> hotMap, Boolean custom) {
         this.field = ExceptionsHelper.requireNonNull(field, FIELD);
-        this.hotMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP));
-        this.custom = custom == null ? false : custom;
+        this.hotMap = Collections.unmodifiableMap(new TreeMap<>(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP)));
+        this.custom = custom != null && custom;
     }
 
     public OneHotEncoding(StreamInput in) throws IOException {
         this.field = in.readString();
-        this.hotMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString));
+        this.hotMap = Collections.unmodifiableMap(new TreeMap<>(in.readMap(StreamInput::readString, StreamInput::readString)));
         if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
             this.custom = in.readBoolean();
         } else {
@@ -42,11 +42,12 @@ public class NGramTests extends PreProcessingTests<NGram> {
     }
 
     public static NGram createRandom(Boolean isCustom) {
+        int possibleLength = randomIntBetween(1, 10);
         return new NGram(
             randomAlphaOfLength(10),
-            IntStream.generate(() -> randomIntBetween(1, 5)).limit(5).boxed().collect(Collectors.toList()),
+            IntStream.generate(() -> randomIntBetween(1, Math.min(possibleLength, 5))).limit(5).boxed().collect(Collectors.toList()),
             randomBoolean() ? null : randomIntBetween(0, 10),
-            randomBoolean() ? null : randomIntBetween(1, 10),
+            randomBoolean() ? null : possibleLength,
             isCustom,
             randomBoolean() ? null : randomAlphaOfLength(10));
     }
 
@@ -9,10 +9,12 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.hamcrest.Matcher;
 import org.junit.Before;
 
+import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.function.Predicate;
 
+import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
 import static org.hamcrest.Matchers.equalTo;
 
 public abstract class PreProcessingTests<T extends PreProcessor> extends AbstractSerializingTestCase<T> {
@@ -41,6 +43,22 @@
         );
     }
 
+    public void testInputOutputFieldOrderConsistency() throws IOException {
+        xContentTester(this::createParser, this::createXContextTestInstance, getToXContentParams(), this::doParseInstance)
+            .numberOfTestRuns(NUMBER_OF_TEST_RUNS)
+            .supportsUnknownFields(supportsUnknownFields())
+            .shuffleFieldsExceptions(getShuffleFieldsExceptions())
+            .randomFieldsExcludeFilter(getRandomFieldsExcludeFilter())
+            .assertEqualsConsumer(this::assertFieldConsistency)
+            .assertToXContentEquivalence(false)
+            .test();
+    }
+
+    private void assertFieldConsistency(T lft, T rgt) {
+        assertThat(lft.inputFields(), equalTo(rgt.inputFields()));
+        assertThat(lft.outputFields(), equalTo(rgt.outputFields()));
+    }
+
     public void testWithMissingField() {
         Map<String, Object> fields = randomFieldValues();
         PreProcessor preProcessor = this.createTestInstance();
@@ -352,7 +352,7 @@ public class DataFrameDataExtractor {
         return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis);
     }
 
-    private static boolean isValidValue(Object value) {
+    public static boolean isValidValue(Object value) {
         // We should allow a number, string or a boolean.
         // It is possible for a field to be categorical and have a `keyword` mapping, but be any of these
         // three types, in the same index.
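The widened check, approximated as standalone code (the real method lives in DataFrameDataExtractor; this class and main() are just scaffolding mirroring the comment in the hunk above):

public class ValidValueSketch {
    // Numbers, strings, and (newly) booleans are acceptable scalar doc values.
    static boolean isValidValue(Object value) {
        return value instanceof Number || value instanceof String || value instanceof Boolean;
    }

    public static void main(String[] args) {
        System.out.println(isValidValue(false));      // true after this change
        System.out.println(isValidValue("bar"));      // true
        System.out.println(isValidValue(42));         // true
        System.out.println(isValidValue(new int[0])); // false: not a scalar
    }
}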
@@ -16,6 +16,8 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.function.Function;
 
+import static org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor.isValidValue;
+
 public class ProcessedField {
     private final PreProcessor preProcessor;
 
@@ -36,8 +38,9 @@ public class ProcessedField {
     }
 
     public Object[] value(SearchHit hit, Function<String, ExtractedField> fieldExtractor) {
-        Map<String, Object> inputs = new HashMap<>(preProcessor.inputFields().size(), 1.0f);
-        for (String field : preProcessor.inputFields()) {
+        List<String> inputFields = getInputFieldNames();
+        Map<String, Object> inputs = new HashMap<>(inputFields.size(), 1.0f);
+        for (String field : inputFields) {
             ExtractedField extractedField = fieldExtractor.apply(field);
             if (extractedField == null) {
                 return new Object[0];
@@ -47,7 +50,7 @@
                 continue;
             }
             final Object value = values[0];
-            if (values.length == 1 && (value instanceof String || value instanceof Number)) {
+            if (values.length == 1 && (isValidValue(value))) {
                 inputs.put(field, value);
             }
         }
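Putting both fixes together, a rough sketch of how a boolean doc value can now drive a one_hot lookup. The string-matching step is an assumption inferred from the "false" keys in the new tests below, not a quote of the production code:

import java.util.Map;
import java.util.TreeMap;

public class BooleanOneHotSketch {
    public static void main(String[] args) {
        // Sorted hot map, mirroring the TreeMap fix: "1" < "bar" < "false".
        Map<String, String> hotMap = new TreeMap<>(
            Map.of("bar", "bar_column", "1", "1_column", "false", "false_column"));

        Object docValue = false; // boolean scalar, now allowed through isValidValue
        for (Map.Entry<String, String> entry : hotMap.entrySet()) {
            int hot = entry.getKey().equals(String.valueOf(docValue)) ? 1 : 0;
            System.out.println(entry.getValue() + " = " + hot);
        }
        // Prints 1_column = 0, bar_column = 0, false_column = 1, always in this order.
    }
}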
@@ -5,16 +5,20 @@
  */
 package org.elasticsearch.xpack.ml.extractor;
 
+import org.elasticsearch.common.collect.MapBuilder;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
 import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
 
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.function.Function;
-import java.util.stream.Collectors;
+import java.util.LinkedHashMap;
+import java.util.Map;
 
 import static org.hamcrest.Matchers.arrayContaining;
 import static org.hamcrest.Matchers.emptyArray;
@@ -30,7 +34,7 @@ public class ProcessedFieldTests extends ESTestCase {
 
     public void testOneHotGetters() {
         String inputField = "foo";
-        ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
+        ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(inputField, "bar", "baz"));
         assertThat(processedField.getInputFieldNames(), hasItems(inputField));
         assertThat(processedField.getOutputFieldNames(), hasItems("bar_column", "baz_column"));
         assertThat(processedField.getOutputFieldType("bar_column"), equalTo(Collections.singleton("integer")));
@@ -39,28 +43,92 @@
     }
 
     public void testMissingExtractor() {
-        String inputField = "foo";
-        ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
+        ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "baz"));
         assertThat(processedField.value(makeHit(), (s) -> null), emptyArray());
     }
 
     public void testMissingInputValues() {
-        String inputField = "foo";
         ExtractedField extractedField = makeExtractedField(new Object[0]);
-        ProcessedField processedField = new ProcessedField(makePreProcessor(inputField, "bar", "baz"));
+        ProcessedField processedField = new ProcessedField(makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "baz"));
         assertThat(processedField.value(makeHit(), (s) -> extractedField), arrayContaining(is(nullValue()), is(nullValue())));
     }
 
-    public void testProcessedField() {
-        ProcessedField processedField = new ProcessedField(makePreProcessor("foo", "bar", "baz"));
-        assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "bar" })), arrayContaining(1, 0));
-        assertThat(processedField.value(makeHit(), (s) -> makeExtractedField(new Object[] { "baz" })), arrayContaining(0, 1));
+    public void testProcessedFieldFrequencyEncoding() {
+        testProcessedField(
+            new FrequencyEncoding(randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                MapBuilder.<String, Double>newMapBuilder().put("bar", 1.0).put("1", 0.5).put("false", 0.0).map(),
+                randomBoolean()),
+            new Object[]{"bar", 1, false},
+            new Object[][]{
+                new Object[]{1.0},
+                new Object[]{0.5},
+                new Object[]{0.0},
+            });
     }
 
-    private static PreProcessor makePreProcessor(String inputField, String... expectedExtractedValues) {
-        return new OneHotEncoding(inputField,
-            Arrays.stream(expectedExtractedValues).collect(Collectors.toMap(Function.identity(), (s) -> s + "_column")),
-            true);
+    public void testProcessedFieldTargetMeanEncoding() {
+        testProcessedField(
+            new TargetMeanEncoding(randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                MapBuilder.<String, Double>newMapBuilder().put("bar", 1.0).put("1", 0.5).put("false", 0.0).map(),
+                0.8,
+                randomBoolean()),
+            new Object[]{"bar", 1, false, "unknown"},
+            new Object[][]{
+                new Object[]{1.0},
+                new Object[]{0.5},
+                new Object[]{0.0},
+                new Object[]{0.8},
+            });
+    }
+
+    public void testProcessedFieldNGramEncoding() {
+        testProcessedField(
+            new NGram(randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                new int[]{1},
+                0,
+                3,
+                randomBoolean()),
+            new Object[]{"bar", 1, false},
+            new Object[][]{
+                new Object[]{"b", "a", "r"},
+                new Object[]{"1", null, null},
+                new Object[]{"f", "a", "l"}
+            });
+    }
+
+    public void testProcessedFieldOneHot() {
+        testProcessedField(
+            makeOneHotPreProcessor(randomAlphaOfLength(10), "bar", "1", "false"),
+            new Object[]{"bar", 1, false},
+            new Object[][]{
+                new Object[]{0, 1, 0},
+                new Object[]{1, 0, 0},
+                new Object[]{0, 0, 1},
+            });
+    }
+
+    public void testProcessedField(PreProcessor preProcessor, Object[] inputs, Object[][] expectedOutputs) {
+        ProcessedField processedField = new ProcessedField(preProcessor);
+        assert inputs.length == expectedOutputs.length;
+        for (int i = 0; i < inputs.length; i++) {
+            Object input = inputs[i];
+            Object[] result = processedField.value(makeHit(input), (s) -> makeExtractedField(new Object[] { input }));
+            assertThat(
+                "Input [" + input + "] Expected " + Arrays.toString(expectedOutputs[i]) + " but received " + Arrays.toString(result),
+                result,
+                equalTo(expectedOutputs[i]));
+        }
+    }
+
+    private static PreProcessor makeOneHotPreProcessor(String inputField, String... expectedExtractedValues) {
+        Map<String, String> map = new LinkedHashMap<>();
+        for (String v : expectedExtractedValues) {
+            map.put(v, v + "_column");
+        }
+        return new OneHotEncoding(inputField, map, true);
     }
 
     private static ExtractedField makeExtractedField(Object[] value) {
@@ -70,7 +138,11 @@ public class ProcessedFieldTests extends ESTestCase {
     }
 
     private static SearchHit makeHit() {
-        return new SearchHitBuilder(42).addField("a_keyword", "bar").build();
+        return makeHit("bar");
+    }
+
+    private static SearchHit makeHit(Object value) {
+        return new SearchHitBuilder(42).addField("a_keyword", value).build();
     }
 
 }