[7.x] [ML] add new `custom` field to trained model processors (#59542) (#59700)

* [ML] add new `custom` field to trained model processors (#59542)

This commit adds the new configurable field `custom`.

`custom` indicates if the preprocessor was submitted by a user or automatically created by the analytics job.

Eventually, this field will be used in calculating feature importance. When `custom` is true, the feature importance for
the processed fields is calculated. When `false` the current behavior is the same (we calculate the importance for the originating field/feature).

This also adds new required methods to the preprocessor interface. If users are to supply their own preprocessors
in the analytics job configuration, we need to know the input and output field names.
This commit is contained in:
Benjamin Trent 2020-07-16 10:57:38 -04:00 committed by GitHub
parent 3a228906a9
commit a28547c4b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 405 additions and 157 deletions

View File

@ -40,18 +40,20 @@ public class FrequencyEncoding implements PreProcessor {
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
public static final ParseField CUSTOM = new ParseField("custom");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<FrequencyEncoding, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2]));
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Boolean)a[3]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
FREQUENCY_MAP);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
}
public static FrequencyEncoding fromXContent(XContentParser parser) {
@ -61,11 +63,13 @@ public class FrequencyEncoding implements PreProcessor {
private final String field;
private final String featureName;
private final Map<String, Double> frequencyMap;
private final Boolean custom;
public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap) {
FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap, Boolean custom) {
this.field = Objects.requireNonNull(field);
this.featureName = Objects.requireNonNull(featureName);
this.frequencyMap = Collections.unmodifiableMap(Objects.requireNonNull(frequencyMap));
this.custom = custom;
}
/**
@ -94,12 +98,19 @@ public class FrequencyEncoding implements PreProcessor {
return NAME;
}
public Boolean getCustom() {
return custom;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(FREQUENCY_MAP.getPreferredName(), frequencyMap);
if (custom != null) {
builder.field(CUSTOM.getPreferredName(), custom);
}
builder.endObject();
return builder;
}
@ -111,12 +122,13 @@ public class FrequencyEncoding implements PreProcessor {
FrequencyEncoding that = (FrequencyEncoding) o;
return Objects.equals(field, that.field)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(custom, that.custom)
&& Objects.equals(frequencyMap, that.frequencyMap);
}
@Override
public int hashCode() {
return Objects.hash(field, featureName, frequencyMap);
return Objects.hash(field, featureName, frequencyMap, custom);
}
public Builder builder(String field) {
@ -128,6 +140,7 @@ public class FrequencyEncoding implements PreProcessor {
private String field;
private String featureName;
private Map<String, Double> frequencyMap = new HashMap<>();
private Boolean custom;
public Builder(String field) {
this.field = field;
@ -153,8 +166,13 @@ public class FrequencyEncoding implements PreProcessor {
return this;
}
public Builder setCustom(boolean custom) {
this.custom = custom;
return this;
}
public FrequencyEncoding build() {
return new FrequencyEncoding(field, featureName, frequencyMap);
return new FrequencyEncoding(field, featureName, frequencyMap, custom);
}
}

View File

@ -38,15 +38,17 @@ public class OneHotEncoding implements PreProcessor {
public static final String NAME = "one_hot_encoding";
public static final ParseField FIELD = new ParseField("field");
public static final ParseField HOT_MAP = new ParseField("hot_map");
public static final ParseField CUSTOM = new ParseField("custom");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<OneHotEncoding, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1]));
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1], (Boolean)a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
}
public static OneHotEncoding fromXContent(XContentParser parser) {
@ -55,12 +57,13 @@ public class OneHotEncoding implements PreProcessor {
private final String field;
private final Map<String, String> hotMap;
private final Boolean custom;
public OneHotEncoding(String field, Map<String, String> hotMap) {
OneHotEncoding(String field, Map<String, String> hotMap, Boolean custom) {
this.field = Objects.requireNonNull(field);
this.hotMap = Collections.unmodifiableMap(Objects.requireNonNull(hotMap));
this.custom = custom;
}
/**
* @return Field name on which to one hot encode
*/
@ -80,11 +83,18 @@ public class OneHotEncoding implements PreProcessor {
return NAME;
}
public Boolean getCustom() {
return custom;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(HOT_MAP.getPreferredName(), hotMap);
if (custom != null) {
builder.field(CUSTOM.getPreferredName(), custom);
}
builder.endObject();
return builder;
}
@ -95,12 +105,13 @@ public class OneHotEncoding implements PreProcessor {
if (o == null || getClass() != o.getClass()) return false;
OneHotEncoding that = (OneHotEncoding) o;
return Objects.equals(field, that.field)
&& Objects.equals(hotMap, that.hotMap);
&& Objects.equals(hotMap, that.hotMap)
&& Objects.equals(custom, that.custom);
}
@Override
public int hashCode() {
return Objects.hash(field, hotMap);
return Objects.hash(field, hotMap, custom);
}
public Builder builder(String field) {
@ -111,6 +122,7 @@ public class OneHotEncoding implements PreProcessor {
private String field;
private Map<String, String> hotMap = new HashMap<>();
private Boolean custom;
public Builder(String field) {
this.field = field;
@ -131,8 +143,13 @@ public class OneHotEncoding implements PreProcessor {
return this;
}
public Builder setCustom(boolean custom) {
this.custom = custom;
return this;
}
public OneHotEncoding build() {
return new OneHotEncoding(field, hotMap);
return new OneHotEncoding(field, hotMap, custom);
}
}
}

View File

@ -41,12 +41,13 @@ public class TargetMeanEncoding implements PreProcessor {
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField TARGET_MAP = new ParseField("target_map");
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
public static final ParseField CUSTOM = new ParseField("custom");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<TargetMeanEncoding, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3]));
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3], (Boolean)a[4]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
@ -54,6 +55,7 @@ public class TargetMeanEncoding implements PreProcessor {
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
TARGET_MAP);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
}
public static TargetMeanEncoding fromXContent(XContentParser parser) {
@ -64,12 +66,14 @@ public class TargetMeanEncoding implements PreProcessor {
private final String featureName;
private final Map<String, Double> meanMap;
private final double defaultValue;
private final Boolean custom;
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue, Boolean custom) {
this.field = Objects.requireNonNull(field);
this.featureName = Objects.requireNonNull(featureName);
this.meanMap = Collections.unmodifiableMap(Objects.requireNonNull(meanMap));
this.defaultValue = Objects.requireNonNull(defaultValue);
this.custom = custom;
}
/**
@ -100,6 +104,10 @@ public class TargetMeanEncoding implements PreProcessor {
return featureName;
}
public Boolean getCustom() {
return custom;
}
@Override
public String getName() {
return NAME;
@ -112,6 +120,9 @@ public class TargetMeanEncoding implements PreProcessor {
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(TARGET_MAP.getPreferredName(), meanMap);
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
if (custom != null) {
builder.field(CUSTOM.getPreferredName(), custom);
}
builder.endObject();
return builder;
}
@ -124,12 +135,13 @@ public class TargetMeanEncoding implements PreProcessor {
return Objects.equals(field, that.field)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(meanMap, that.meanMap)
&& Objects.equals(defaultValue, that.defaultValue);
&& Objects.equals(defaultValue, that.defaultValue)
&& Objects.equals(custom, that.custom);
}
@Override
public int hashCode() {
return Objects.hash(field, featureName, meanMap, defaultValue);
return Objects.hash(field, featureName, meanMap, defaultValue, custom);
}
public Builder builder(String field) {
@ -142,6 +154,7 @@ public class TargetMeanEncoding implements PreProcessor {
private String featureName;
private Map<String, Double> meanMap = new HashMap<>();
private double defaultValue;
private Boolean custom;
public Builder(String field) {
this.field = field;
@ -176,8 +189,13 @@ public class TargetMeanEncoding implements PreProcessor {
return this;
}
public Builder setCustom(boolean custom) {
this.custom = custom;
return this;
}
public TargetMeanEncoding build() {
return new TargetMeanEncoding(field, featureName, meanMap, defaultValue);
return new TargetMeanEncoding(field, featureName, meanMap, defaultValue, custom);
}
}
}

View File

@ -55,6 +55,9 @@ public class FrequencyEncodingTests extends AbstractXContentTestCase<FrequencyEn
for (int i = 0; i < valuesSize; i++) {
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
}
return new FrequencyEncoding(randomAlphaOfLength(10), randomAlphaOfLength(10), valueMap);
return new FrequencyEncoding(randomAlphaOfLength(10),
randomAlphaOfLength(10),
valueMap,
randomBoolean() ? null : randomBoolean());
}
}

View File

@ -55,7 +55,7 @@ public class OneHotEncodingTests extends AbstractXContentTestCase<OneHotEncoding
for (int i = 0; i < valuesSize; i++) {
valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
}
return new OneHotEncoding(randomAlphaOfLength(10), valueMap);
return new OneHotEncoding(randomAlphaOfLength(10), valueMap, randomBoolean() ? null : randomBoolean());
}
}

View File

@ -58,7 +58,8 @@ public class TargetMeanEncodingTests extends AbstractXContentTestCase<TargetMean
return new TargetMeanEncoding(randomAlphaOfLength(10),
randomAlphaOfLength(10),
valueMap,
randomDoubleBetween(0.0, 1.0, false));
randomDoubleBetween(0.0, 1.0, false),
randomBoolean() ? null : randomBoolean());
}
}

View File

@ -94,6 +94,10 @@ The field name to encode.
`frequency_map`::
(Required, object map of string:double)
Object that maps the field value to the frequency encoded value.
`custom`::
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=custom-preprocessor]
======
//End frequency encoding
@ -112,6 +116,10 @@ The field name to encode.
`hot_map`::
(Required, object map of strings)
String map of "field_value: one_hot_column_name".
`custom`::
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=custom-preprocessor]
======
//End one hot encoding
@ -138,6 +146,10 @@ The field name to encode.
`target_map`:::
(Required, object map of string:double)
Object that maps the field value to the target mean value.
`custom`::
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=custom-preprocessor]
======
//End target mean encoding
=====

View File

@ -291,6 +291,15 @@ Specifies whether the feature influence calculation is enabled. Defaults to
`true`.
end::compute-feature-influence[]
tag::custom-preprocessor[]
(Optional, boolean)
Boolean value indicating if the analytics job created the preprocessor
or if a user provided it. This adjusts the feature importance calculation.
When `true`, the feature importance calculation returns importance for the
processed feature. When `false`, the total importance of the original field
is returned. Default is `false`.
end::custom-preprocessor[]
tag::custom-rules[]
An array of custom rule objects, which enable you to customize the way detectors
operate. For example, a rule may dictate to the detector conditions under which

View File

@ -220,6 +220,16 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
return data[row * colDim + col];
}
@Override
public List<String> inputFields() {
return Collections.singletonList(fieldName);
}
@Override
public List<String> outputFields() {
return Collections.singletonList(destField);
}
@Override
public void process(Map<String, Object> fields) {
Object field = fields.get(fieldName);
@ -241,6 +251,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
return Collections.singletonMap(destField, fieldName);
}
@Override
public boolean isCustom() {
return false;
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@ -18,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -33,6 +35,7 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
public static final ParseField CUSTOM = new ParseField("custom");
public static final ConstructingObjectParser<FrequencyEncoding, Void> STRICT_PARSER = createParser(false);
public static final ConstructingObjectParser<FrequencyEncoding, Void> LENIENT_PARSER = createParser(true);
@ -42,12 +45,13 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
ConstructingObjectParser<FrequencyEncoding, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2]));
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Boolean)a[3]));
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
parser.declareObject(ConstructingObjectParser.constructorArg(),
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
FREQUENCY_MAP);
parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
return parser;
}
@ -62,17 +66,24 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
private final String field;
private final String featureName;
private final Map<String, Double> frequencyMap;
private final boolean custom;
public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap) {
public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap, Boolean custom) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME);
this.frequencyMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(frequencyMap, FREQUENCY_MAP));
this.custom = custom == null ? false : custom;
}
public FrequencyEncoding(StreamInput in) throws IOException {
this.field = in.readString();
this.featureName = in.readString();
this.frequencyMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readDouble));
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.custom = in.readBoolean();
} else {
this.custom = false;
}
}
/**
@ -101,11 +112,26 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
return Collections.singletonMap(featureName, field);
}
@Override
public boolean isCustom() {
return custom;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public List<String> inputFields() {
return Collections.singletonList(field);
}
@Override
public List<String> outputFields() {
return Collections.singletonList(featureName);
}
@Override
public void process(Map<String, Object> fields) {
Object value = fields.get(field);
@ -125,6 +151,9 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
out.writeString(field);
out.writeString(featureName);
out.writeMap(frequencyMap, StreamOutput::writeString, StreamOutput::writeDouble);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeBoolean(custom);
}
}
@Override
@ -133,6 +162,7 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
builder.field(FIELD.getPreferredName(), field);
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(FREQUENCY_MAP.getPreferredName(), frequencyMap);
builder.field(CUSTOM.getPreferredName(), custom);
builder.endObject();
return builder;
}
@ -144,12 +174,13 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
FrequencyEncoding that = (FrequencyEncoding) o;
return Objects.equals(field, that.field)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(frequencyMap, that.frequencyMap);
&& Objects.equals(frequencyMap, that.frequencyMap)
&& Objects.equals(custom, that.custom);
}
@Override
public int hashCode() {
return Objects.hash(field, featureName, frequencyMap);
return Objects.hash(field, featureName, frequencyMap, custom);
}
@Override

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@ -16,10 +17,12 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
@ -31,6 +34,7 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
public static final ParseField NAME = new ParseField("one_hot_encoding");
public static final ParseField FIELD = new ParseField("field");
public static final ParseField HOT_MAP = new ParseField("hot_map");
public static final ParseField CUSTOM = new ParseField("custom");
public static final ConstructingObjectParser<OneHotEncoding, Void> STRICT_PARSER = createParser(false);
public static final ConstructingObjectParser<OneHotEncoding, Void> LENIENT_PARSER = createParser(true);
@ -40,9 +44,10 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
ConstructingObjectParser<OneHotEncoding, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1]));
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1], (Boolean)a[2]));
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
return parser;
}
@ -56,15 +61,22 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
private final String field;
private final Map<String, String> hotMap;
private final boolean custom;
public OneHotEncoding(String field, Map<String, String> hotMap) {
public OneHotEncoding(String field, Map<String, String> hotMap, Boolean custom) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.hotMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP));
this.custom = custom == null ? false : custom;
}
public OneHotEncoding(StreamInput in) throws IOException {
this.field = in.readString();
this.hotMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString));
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.custom = in.readBoolean();
} else {
this.custom = false;
}
}
/**
@ -83,7 +95,12 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
@Override
public Map<String, String> reverseLookup() {
return hotMap.entrySet().stream().collect(Collectors.toMap(HashMap.Entry::getValue, (entry) -> field));
return hotMap.values().stream().collect(Collectors.toMap(Function.identity(), (value) -> field));
}
@Override
public boolean isCustom() {
return custom;
}
@Override
@ -91,6 +108,16 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
return NAME.getPreferredName();
}
@Override
public List<String> inputFields() {
return Collections.singletonList(field);
}
@Override
public List<String> outputFields() {
return new ArrayList<>(hotMap.values());
}
@Override
public void process(Map<String, Object> fields) {
Object value = fields.get(field);
@ -112,6 +139,9 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
public void writeTo(StreamOutput out) throws IOException {
out.writeString(field);
out.writeMap(hotMap, StreamOutput::writeString, StreamOutput::writeString);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeBoolean(custom);
}
}
@Override
@ -119,6 +149,7 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(HOT_MAP.getPreferredName(), hotMap);
builder.field(CUSTOM.getPreferredName(), custom);
builder.endObject();
return builder;
}
@ -129,12 +160,13 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
if (o == null || getClass() != o.getClass()) return false;
OneHotEncoding that = (OneHotEncoding) o;
return Objects.equals(field, that.field)
&& Objects.equals(hotMap, that.hotMap);
&& Objects.equals(hotMap, that.hotMap)
&& Objects.equals(custom, that.custom);
}
@Override
public int hashCode() {
return Objects.hash(field, hotMap);
return Objects.hash(field, hotMap, custom);
}
@Override

View File

@ -9,6 +9,7 @@ import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import java.util.List;
import java.util.Map;
/**
@ -17,6 +18,16 @@ import java.util.Map;
*/
public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable {
/**
* The expected input fields
*/
List<String> inputFields();
/**
* @return The resulting output fields
*/
List<String> outputFields();
/**
* Process the given fields and their values and return the modified map.
*
@ -29,4 +40,12 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
* @return Reverse lookup map to match resulting features to their original feature name
*/
Map<String, String> reverseLookup();
/**
* @return Is the pre-processor a custom one provided by the user, or automatically created?
* This changes how feature importance is calculated, as fields generated by custom processors get individual feature
* importance calculations.
*/
boolean isCustom();
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@ -18,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -33,6 +35,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField TARGET_MAP = new ParseField("target_map");
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
public static final ParseField CUSTOM = new ParseField("custom");
public static final ConstructingObjectParser<TargetMeanEncoding, Void> STRICT_PARSER = createParser(false);
public static final ConstructingObjectParser<TargetMeanEncoding, Void> LENIENT_PARSER = createParser(true);
@ -42,13 +45,14 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
ConstructingObjectParser<TargetMeanEncoding, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3]));
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3], (Boolean)a[4]));
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
parser.declareObject(ConstructingObjectParser.constructorArg(),
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
TARGET_MAP);
parser.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
return parser;
}
@ -64,12 +68,14 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
private final String featureName;
private final Map<String, Double> meanMap;
private final double defaultValue;
private final boolean custom;
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue, Boolean custom) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME);
this.meanMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(meanMap, TARGET_MAP));
this.defaultValue = ExceptionsHelper.requireNonNull(defaultValue, DEFAULT_VALUE);
this.custom = custom == null ? false : custom;
}
public TargetMeanEncoding(StreamInput in) throws IOException {
@ -77,6 +83,11 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
this.featureName = in.readString();
this.meanMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readDouble));
this.defaultValue = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.custom = in.readBoolean();
} else {
this.custom = false;
}
}
/**
@ -112,11 +123,26 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
return Collections.singletonMap(featureName, field);
}
@Override
public boolean isCustom() {
return custom;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public List<String> inputFields() {
return Collections.singletonList(field);
}
@Override
public List<String> outputFields() {
return Collections.singletonList(featureName);
}
@Override
public void process(Map<String, Object> fields) {
Object value = fields.get(field);
@ -137,6 +163,9 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
out.writeString(featureName);
out.writeMap(meanMap, StreamOutput::writeString, StreamOutput::writeDouble);
out.writeDouble(defaultValue);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeBoolean(custom);
}
}
@Override
@ -146,6 +175,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(TARGET_MAP.getPreferredName(), meanMap);
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
builder.field(CUSTOM.getPreferredName(), custom);
builder.endObject();
return builder;
}
@ -158,12 +188,13 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
return Objects.equals(field, that.field)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(meanMap, that.meanMap)
&& Objects.equals(defaultValue, that.defaultValue);
&& Objects.equals(defaultValue, that.defaultValue)
&& Objects.equals(custom, that.custom);
}
@Override
public int hashCode() {
return Objects.hash(field, featureName, meanMap, defaultValue);
return Objects.hash(field, featureName, meanMap, defaultValue, custom);
}
@Override

View File

@ -17,6 +17,7 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding> {
@ -37,7 +38,10 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
for (int i = 0; i < valuesSize; i++) {
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
}
return new FrequencyEncoding(randomAlphaOfLength(10), randomAlphaOfLength(10), valueMap);
return new FrequencyEncoding(randomAlphaOfLength(10),
randomAlphaOfLength(10),
valueMap,
randomBoolean() ? null : randomBoolean());
}
@Override
@ -51,7 +55,7 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = "encoded";
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap);
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap, false);
Object fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName,
equalTo(valueMap.get(fieldValue.toString())));
@ -65,4 +69,15 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
testProcess(encoding, fieldValues, matchers);
}
public void testInputOutputFields() {
String field = randomAlphaOfLength(10);
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = randomAlphaOfLength(10);
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap, false);
assertThat(encoding.inputFields(), containsInAnyOrder(field));
assertThat(encoding.outputFields(), containsInAnyOrder(encodedFeatureName));
}
}

View File

@ -17,6 +17,7 @@ import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
@ -37,7 +38,9 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
for (int i = 0; i < valuesSize; i++) {
valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
}
return new OneHotEncoding(randomAlphaOfLength(10), valueMap);
return new OneHotEncoding(randomAlphaOfLength(10),
valueMap,
randomBoolean() ? randomBoolean() : null);
}
@Override
@ -49,7 +52,7 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
String field = "categorical";
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0);
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString()));
OneHotEncoding encoding = new OneHotEncoding(field, valueMap);
OneHotEncoding encoding = new OneHotEncoding(field, valueMap, false);
Object fieldValue = randomFrom(values);
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
@ -67,4 +70,14 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
testProcess(encoding, fieldValues, matchers);
}
public void testInputOutputFields() {
String field = randomAlphaOfLength(10);
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0);
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString()));
OneHotEncoding encoding = new OneHotEncoding(field, valueMap, false);
assertThat(encoding.inputFields(), containsInAnyOrder(field));
assertThat(encoding.outputFields(),
containsInAnyOrder(values.stream().map(v -> "Column_" + v.toString()).toArray(String[]::new)));
}
}

View File

@ -17,6 +17,7 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncoding> {
@ -40,7 +41,8 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
return new TargetMeanEncoding(randomAlphaOfLength(10),
randomAlphaOfLength(10),
valueMap,
randomDoubleBetween(0.0, 1.0, false));
randomDoubleBetween(0.0, 1.0, false),
randomBoolean() ? randomBoolean() : null);
}
@Override
@ -55,7 +57,7 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = "encoded";
Double defaultvalue = randomDouble();
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue);
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue, false);
Object fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName,
equalTo(valueMap.get(fieldValue.toString())));
@ -68,4 +70,16 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
testProcess(encoding, fieldValues, matchers);
}
public void testInputOutputFields() {
String field = randomAlphaOfLength(10);
String encodedFeatureName = randomAlphaOfLength(10);
Double defaultvalue = randomDouble();
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0);
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
v -> randomDoubleBetween(0.0, 1.0, false)));
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue, false);
assertThat(encoding.inputFields(), containsInAnyOrder(field));
assertThat(encoding.outputFields(), containsInAnyOrder(encodedFeatureName));
}
}

View File

@ -74,7 +74,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2)
.setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
.setParsedDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false)))
.setTrainedModel(buildClassification(true)))
.setVersion(Version.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
@ -85,7 +85,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1)
.setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
.setParsedDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false)))
.setTrainedModel(buildRegression()))
.setVersion(Version.CURRENT)
.setEstimatedOperations(0)
@ -203,7 +203,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId)
.setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
.setParsedDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false)))
.setTrainedModel(buildMultiClassClassification()))
.setVersion(Version.CURRENT)
.setLicenseLevel(License.OperationMode.PLATINUM.description())
@ -320,7 +320,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId)
.setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
.setParsedDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding, false)))
.setTrainedModel(buildRegression()))
.setVersion(Version.CURRENT)
.setEstimatedOperations(0)

View File

@ -67,7 +67,7 @@ public class LocalModelTests extends ESTestCase {
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
InferenceDefinition definition = InferenceDefinition.builder()
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
.setTrainedModel(buildClassificationInference(false))
.build();
@ -99,7 +99,7 @@ public class LocalModelTests extends ESTestCase {
// Test with labels
definition = InferenceDefinition.builder()
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
.setTrainedModel(buildClassificationInference(true))
.build();
model = new LocalModel(modelId,
@ -142,7 +142,7 @@ public class LocalModelTests extends ESTestCase {
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
InferenceDefinition definition = InferenceDefinition.builder()
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
.setTrainedModel(buildClassificationInference(true))
.build();
@ -200,7 +200,7 @@ public class LocalModelTests extends ESTestCase {
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
InferenceDefinition trainedModelDefinition = InferenceDefinition.builder()
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
.setTrainedModel(buildRegressionInference())
.build();
LocalModel model = new LocalModel("regression_model",
@ -228,7 +228,7 @@ public class LocalModelTests extends ESTestCase {
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
InferenceDefinition trainedModelDefinition = InferenceDefinition.builder()
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
.setTrainedModel(buildRegressionInference())
.build();
LocalModel model = new LocalModel(
@ -260,7 +260,7 @@ public class LocalModelTests extends ESTestCase {
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
InferenceDefinition definition = InferenceDefinition.builder()
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
.setTrainedModel(buildClassificationInference(false))
.build();