* [ML] add new `custom` field to trained model processors (#59542) This commit adds the new configurable field `custom`. `custom` indicates if the preprocessor was submitted by a user or automatically created by the analytics job. Eventually, this field will be used in calculating feature importance. When `custom` is true, the feature importance for the processed fields is calculated. When `false` the current behavior is the same (we calculate the importance for the originating field/feature). This also adds new required methods to the preprocessor interface. If users are to supply their own preprocessors in the analytics job configuration, we need to know the input and output field names.
This commit is contained in:
parent
3a228906a9
commit
a28547c4b4
|
@ -40,18 +40,20 @@ public class FrequencyEncoding implements PreProcessor {
|
|||
public static final ParseField FIELD = new ParseField("field");
|
||||
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
|
||||
public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
|
||||
public static final ParseField CUSTOM = new ParseField("custom");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<FrequencyEncoding, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
true,
|
||||
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2]));
|
||||
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Boolean)a[3]));
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
|
||||
FREQUENCY_MAP);
|
||||
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
|
||||
}
|
||||
|
||||
public static FrequencyEncoding fromXContent(XContentParser parser) {
|
||||
|
@ -61,11 +63,13 @@ public class FrequencyEncoding implements PreProcessor {
|
|||
private final String field;
|
||||
private final String featureName;
|
||||
private final Map<String, Double> frequencyMap;
|
||||
private final Boolean custom;
|
||||
|
||||
public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap) {
|
||||
FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap, Boolean custom) {
|
||||
this.field = Objects.requireNonNull(field);
|
||||
this.featureName = Objects.requireNonNull(featureName);
|
||||
this.frequencyMap = Collections.unmodifiableMap(Objects.requireNonNull(frequencyMap));
|
||||
this.custom = custom;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -94,12 +98,19 @@ public class FrequencyEncoding implements PreProcessor {
|
|||
return NAME;
|
||||
}
|
||||
|
||||
public Boolean getCustom() {
|
||||
return custom;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FIELD.getPreferredName(), field);
|
||||
builder.field(FEATURE_NAME.getPreferredName(), featureName);
|
||||
builder.field(FREQUENCY_MAP.getPreferredName(), frequencyMap);
|
||||
if (custom != null) {
|
||||
builder.field(CUSTOM.getPreferredName(), custom);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -111,12 +122,13 @@ public class FrequencyEncoding implements PreProcessor {
|
|||
FrequencyEncoding that = (FrequencyEncoding) o;
|
||||
return Objects.equals(field, that.field)
|
||||
&& Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(custom, that.custom)
|
||||
&& Objects.equals(frequencyMap, that.frequencyMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, featureName, frequencyMap);
|
||||
return Objects.hash(field, featureName, frequencyMap, custom);
|
||||
}
|
||||
|
||||
public Builder builder(String field) {
|
||||
|
@ -128,6 +140,7 @@ public class FrequencyEncoding implements PreProcessor {
|
|||
private String field;
|
||||
private String featureName;
|
||||
private Map<String, Double> frequencyMap = new HashMap<>();
|
||||
private Boolean custom;
|
||||
|
||||
public Builder(String field) {
|
||||
this.field = field;
|
||||
|
@ -153,8 +166,13 @@ public class FrequencyEncoding implements PreProcessor {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setCustom(boolean custom) {
|
||||
this.custom = custom;
|
||||
return this;
|
||||
}
|
||||
|
||||
public FrequencyEncoding build() {
|
||||
return new FrequencyEncoding(field, featureName, frequencyMap);
|
||||
return new FrequencyEncoding(field, featureName, frequencyMap, custom);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -38,15 +38,17 @@ public class OneHotEncoding implements PreProcessor {
|
|||
public static final String NAME = "one_hot_encoding";
|
||||
public static final ParseField FIELD = new ParseField("field");
|
||||
public static final ParseField HOT_MAP = new ParseField("hot_map");
|
||||
public static final ParseField CUSTOM = new ParseField("custom");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<OneHotEncoding, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
true,
|
||||
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1]));
|
||||
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1], (Boolean)a[2]));
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
|
||||
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
|
||||
}
|
||||
|
||||
public static OneHotEncoding fromXContent(XContentParser parser) {
|
||||
|
@ -55,12 +57,13 @@ public class OneHotEncoding implements PreProcessor {
|
|||
|
||||
private final String field;
|
||||
private final Map<String, String> hotMap;
|
||||
private final Boolean custom;
|
||||
|
||||
public OneHotEncoding(String field, Map<String, String> hotMap) {
|
||||
OneHotEncoding(String field, Map<String, String> hotMap, Boolean custom) {
|
||||
this.field = Objects.requireNonNull(field);
|
||||
this.hotMap = Collections.unmodifiableMap(Objects.requireNonNull(hotMap));
|
||||
this.custom = custom;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Field name on which to one hot encode
|
||||
*/
|
||||
|
@ -80,11 +83,18 @@ public class OneHotEncoding implements PreProcessor {
|
|||
return NAME;
|
||||
}
|
||||
|
||||
public Boolean getCustom() {
|
||||
return custom;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FIELD.getPreferredName(), field);
|
||||
builder.field(HOT_MAP.getPreferredName(), hotMap);
|
||||
if (custom != null) {
|
||||
builder.field(CUSTOM.getPreferredName(), custom);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -95,12 +105,13 @@ public class OneHotEncoding implements PreProcessor {
|
|||
if (o == null || getClass() != o.getClass()) return false;
|
||||
OneHotEncoding that = (OneHotEncoding) o;
|
||||
return Objects.equals(field, that.field)
|
||||
&& Objects.equals(hotMap, that.hotMap);
|
||||
&& Objects.equals(hotMap, that.hotMap)
|
||||
&& Objects.equals(custom, that.custom);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, hotMap);
|
||||
return Objects.hash(field, hotMap, custom);
|
||||
}
|
||||
|
||||
public Builder builder(String field) {
|
||||
|
@ -111,6 +122,7 @@ public class OneHotEncoding implements PreProcessor {
|
|||
|
||||
private String field;
|
||||
private Map<String, String> hotMap = new HashMap<>();
|
||||
private Boolean custom;
|
||||
|
||||
public Builder(String field) {
|
||||
this.field = field;
|
||||
|
@ -131,8 +143,13 @@ public class OneHotEncoding implements PreProcessor {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setCustom(boolean custom) {
|
||||
this.custom = custom;
|
||||
return this;
|
||||
}
|
||||
|
||||
public OneHotEncoding build() {
|
||||
return new OneHotEncoding(field, hotMap);
|
||||
return new OneHotEncoding(field, hotMap, custom);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,12 +41,13 @@ public class TargetMeanEncoding implements PreProcessor {
|
|||
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
|
||||
public static final ParseField TARGET_MAP = new ParseField("target_map");
|
||||
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
|
||||
public static final ParseField CUSTOM = new ParseField("custom");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<TargetMeanEncoding, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
true,
|
||||
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3]));
|
||||
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3], (Boolean)a[4]));
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
|
||||
|
@ -54,6 +55,7 @@ public class TargetMeanEncoding implements PreProcessor {
|
|||
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
|
||||
TARGET_MAP);
|
||||
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
|
||||
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
|
||||
}
|
||||
|
||||
public static TargetMeanEncoding fromXContent(XContentParser parser) {
|
||||
|
@ -64,12 +66,14 @@ public class TargetMeanEncoding implements PreProcessor {
|
|||
private final String featureName;
|
||||
private final Map<String, Double> meanMap;
|
||||
private final double defaultValue;
|
||||
private final Boolean custom;
|
||||
|
||||
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
|
||||
TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue, Boolean custom) {
|
||||
this.field = Objects.requireNonNull(field);
|
||||
this.featureName = Objects.requireNonNull(featureName);
|
||||
this.meanMap = Collections.unmodifiableMap(Objects.requireNonNull(meanMap));
|
||||
this.defaultValue = Objects.requireNonNull(defaultValue);
|
||||
this.custom = custom;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -100,6 +104,10 @@ public class TargetMeanEncoding implements PreProcessor {
|
|||
return featureName;
|
||||
}
|
||||
|
||||
public Boolean getCustom() {
|
||||
return custom;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
|
@ -112,6 +120,9 @@ public class TargetMeanEncoding implements PreProcessor {
|
|||
builder.field(FEATURE_NAME.getPreferredName(), featureName);
|
||||
builder.field(TARGET_MAP.getPreferredName(), meanMap);
|
||||
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
|
||||
if (custom != null) {
|
||||
builder.field(CUSTOM.getPreferredName(), custom);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -124,12 +135,13 @@ public class TargetMeanEncoding implements PreProcessor {
|
|||
return Objects.equals(field, that.field)
|
||||
&& Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(meanMap, that.meanMap)
|
||||
&& Objects.equals(defaultValue, that.defaultValue);
|
||||
&& Objects.equals(defaultValue, that.defaultValue)
|
||||
&& Objects.equals(custom, that.custom);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, featureName, meanMap, defaultValue);
|
||||
return Objects.hash(field, featureName, meanMap, defaultValue, custom);
|
||||
}
|
||||
|
||||
public Builder builder(String field) {
|
||||
|
@ -142,6 +154,7 @@ public class TargetMeanEncoding implements PreProcessor {
|
|||
private String featureName;
|
||||
private Map<String, Double> meanMap = new HashMap<>();
|
||||
private double defaultValue;
|
||||
private Boolean custom;
|
||||
|
||||
public Builder(String field) {
|
||||
this.field = field;
|
||||
|
@ -176,8 +189,13 @@ public class TargetMeanEncoding implements PreProcessor {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setCustom(boolean custom) {
|
||||
this.custom = custom;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TargetMeanEncoding build() {
|
||||
return new TargetMeanEncoding(field, featureName, meanMap, defaultValue);
|
||||
return new TargetMeanEncoding(field, featureName, meanMap, defaultValue, custom);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,6 +55,9 @@ public class FrequencyEncodingTests extends AbstractXContentTestCase<FrequencyEn
|
|||
for (int i = 0; i < valuesSize; i++) {
|
||||
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
|
||||
}
|
||||
return new FrequencyEncoding(randomAlphaOfLength(10), randomAlphaOfLength(10), valueMap);
|
||||
return new FrequencyEncoding(randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
valueMap,
|
||||
randomBoolean() ? null : randomBoolean());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,7 +55,7 @@ public class OneHotEncodingTests extends AbstractXContentTestCase<OneHotEncoding
|
|||
for (int i = 0; i < valuesSize; i++) {
|
||||
valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
|
||||
}
|
||||
return new OneHotEncoding(randomAlphaOfLength(10), valueMap);
|
||||
return new OneHotEncoding(randomAlphaOfLength(10), valueMap, randomBoolean() ? null : randomBoolean());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -58,7 +58,8 @@ public class TargetMeanEncodingTests extends AbstractXContentTestCase<TargetMean
|
|||
return new TargetMeanEncoding(randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
valueMap,
|
||||
randomDoubleBetween(0.0, 1.0, false));
|
||||
randomDoubleBetween(0.0, 1.0, false),
|
||||
randomBoolean() ? null : randomBoolean());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -94,6 +94,10 @@ The field name to encode.
|
|||
`frequency_map`::
|
||||
(Required, object map of string:double)
|
||||
Object that maps the field value to the frequency encoded value.
|
||||
|
||||
`custom`::
|
||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=custom-preprocessor]
|
||||
|
||||
======
|
||||
//End frequency encoding
|
||||
|
||||
|
@ -112,6 +116,10 @@ The field name to encode.
|
|||
`hot_map`::
|
||||
(Required, object map of strings)
|
||||
String map of "field_value: one_hot_column_name".
|
||||
|
||||
`custom`::
|
||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=custom-preprocessor]
|
||||
|
||||
======
|
||||
//End one hot encoding
|
||||
|
||||
|
@ -138,6 +146,10 @@ The field name to encode.
|
|||
`target_map`:::
|
||||
(Required, object map of string:double)
|
||||
Object that maps the field value to the target mean value.
|
||||
|
||||
`custom`::
|
||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=custom-preprocessor]
|
||||
|
||||
======
|
||||
//End target mean encoding
|
||||
=====
|
||||
|
|
|
@ -291,6 +291,15 @@ Specifies whether the feature influence calculation is enabled. Defaults to
|
|||
`true`.
|
||||
end::compute-feature-influence[]
|
||||
|
||||
tag::custom-preprocessor[]
|
||||
(Optional, boolean)
|
||||
Boolean value indicating if the analytics job created the preprocessor
|
||||
or if a user provided it. This adjusts the feature importance calculation.
|
||||
When `true`, the feature importance calculation returns importance for the
|
||||
processed feature. When `false`, the total importance of the original field
|
||||
is returned. Default is `false`.
|
||||
end::custom-preprocessor[]
|
||||
|
||||
tag::custom-rules[]
|
||||
An array of custom rule objects, which enable you to customize the way detectors
|
||||
operate. For example, a rule may dictate to the detector conditions under which
|
||||
|
|
|
@ -220,6 +220,16 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
|
|||
return data[row * colDim + col];
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> inputFields() {
|
||||
return Collections.singletonList(fieldName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> outputFields() {
|
||||
return Collections.singletonList(destField);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Map<String, Object> fields) {
|
||||
Object field = fields.get(fieldName);
|
||||
|
@ -241,6 +251,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
|
|||
return Collections.singletonMap(destField, fieldName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCustom() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long size = SHALLOW_SIZE;
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -18,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
|
@ -33,6 +35,7 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
public static final ParseField FIELD = new ParseField("field");
|
||||
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
|
||||
public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
|
||||
public static final ParseField CUSTOM = new ParseField("custom");
|
||||
|
||||
public static final ConstructingObjectParser<FrequencyEncoding, Void> STRICT_PARSER = createParser(false);
|
||||
public static final ConstructingObjectParser<FrequencyEncoding, Void> LENIENT_PARSER = createParser(true);
|
||||
|
@ -42,12 +45,13 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
ConstructingObjectParser<FrequencyEncoding, Void> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2]));
|
||||
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Boolean)a[3]));
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
|
||||
FREQUENCY_MAP);
|
||||
parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -62,17 +66,24 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
private final String field;
|
||||
private final String featureName;
|
||||
private final Map<String, Double> frequencyMap;
|
||||
private final boolean custom;
|
||||
|
||||
public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap) {
|
||||
public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap, Boolean custom) {
|
||||
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
|
||||
this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME);
|
||||
this.frequencyMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(frequencyMap, FREQUENCY_MAP));
|
||||
this.custom = custom == null ? false : custom;
|
||||
}
|
||||
|
||||
public FrequencyEncoding(StreamInput in) throws IOException {
|
||||
this.field = in.readString();
|
||||
this.featureName = in.readString();
|
||||
this.frequencyMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readDouble));
|
||||
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
this.custom = in.readBoolean();
|
||||
} else {
|
||||
this.custom = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -101,11 +112,26 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
return Collections.singletonMap(featureName, field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCustom() {
|
||||
return custom;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> inputFields() {
|
||||
return Collections.singletonList(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> outputFields() {
|
||||
return Collections.singletonList(featureName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Map<String, Object> fields) {
|
||||
Object value = fields.get(field);
|
||||
|
@ -125,6 +151,9 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
out.writeString(field);
|
||||
out.writeString(featureName);
|
||||
out.writeMap(frequencyMap, StreamOutput::writeString, StreamOutput::writeDouble);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeBoolean(custom);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -133,6 +162,7 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
builder.field(FIELD.getPreferredName(), field);
|
||||
builder.field(FEATURE_NAME.getPreferredName(), featureName);
|
||||
builder.field(FREQUENCY_MAP.getPreferredName(), frequencyMap);
|
||||
builder.field(CUSTOM.getPreferredName(), custom);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -144,12 +174,13 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
|
|||
FrequencyEncoding that = (FrequencyEncoding) o;
|
||||
return Objects.equals(field, that.field)
|
||||
&& Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(frequencyMap, that.frequencyMap);
|
||||
&& Objects.equals(frequencyMap, that.frequencyMap)
|
||||
&& Objects.equals(custom, that.custom);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, featureName, frequencyMap);
|
||||
return Objects.hash(field, featureName, frequencyMap, custom);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -16,10 +17,12 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
|||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
|
@ -31,6 +34,7 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
public static final ParseField NAME = new ParseField("one_hot_encoding");
|
||||
public static final ParseField FIELD = new ParseField("field");
|
||||
public static final ParseField HOT_MAP = new ParseField("hot_map");
|
||||
public static final ParseField CUSTOM = new ParseField("custom");
|
||||
|
||||
public static final ConstructingObjectParser<OneHotEncoding, Void> STRICT_PARSER = createParser(false);
|
||||
public static final ConstructingObjectParser<OneHotEncoding, Void> LENIENT_PARSER = createParser(true);
|
||||
|
@ -40,9 +44,10 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
ConstructingObjectParser<OneHotEncoding, Void> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1]));
|
||||
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1], (Boolean)a[2]));
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
|
||||
parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -56,15 +61,22 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
|
||||
private final String field;
|
||||
private final Map<String, String> hotMap;
|
||||
private final boolean custom;
|
||||
|
||||
public OneHotEncoding(String field, Map<String, String> hotMap) {
|
||||
public OneHotEncoding(String field, Map<String, String> hotMap, Boolean custom) {
|
||||
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
|
||||
this.hotMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP));
|
||||
this.custom = custom == null ? false : custom;
|
||||
}
|
||||
|
||||
public OneHotEncoding(StreamInput in) throws IOException {
|
||||
this.field = in.readString();
|
||||
this.hotMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString));
|
||||
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
this.custom = in.readBoolean();
|
||||
} else {
|
||||
this.custom = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -83,7 +95,12 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
|
||||
@Override
|
||||
public Map<String, String> reverseLookup() {
|
||||
return hotMap.entrySet().stream().collect(Collectors.toMap(HashMap.Entry::getValue, (entry) -> field));
|
||||
return hotMap.values().stream().collect(Collectors.toMap(Function.identity(), (value) -> field));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCustom() {
|
||||
return custom;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -91,6 +108,16 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> inputFields() {
|
||||
return Collections.singletonList(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> outputFields() {
|
||||
return new ArrayList<>(hotMap.values());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Map<String, Object> fields) {
|
||||
Object value = fields.get(field);
|
||||
|
@ -112,6 +139,9 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(field);
|
||||
out.writeMap(hotMap, StreamOutput::writeString, StreamOutput::writeString);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeBoolean(custom);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -119,6 +149,7 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
builder.startObject();
|
||||
builder.field(FIELD.getPreferredName(), field);
|
||||
builder.field(HOT_MAP.getPreferredName(), hotMap);
|
||||
builder.field(CUSTOM.getPreferredName(), custom);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -129,12 +160,13 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
|
|||
if (o == null || getClass() != o.getClass()) return false;
|
||||
OneHotEncoding that = (OneHotEncoding) o;
|
||||
return Objects.equals(field, that.field)
|
||||
&& Objects.equals(hotMap, that.hotMap);
|
||||
&& Objects.equals(hotMap, that.hotMap)
|
||||
&& Objects.equals(custom, that.custom);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, hotMap);
|
||||
return Objects.hash(field, hotMap, custom);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -9,6 +9,7 @@ import org.apache.lucene.util.Accountable;
|
|||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
|
@ -17,6 +18,16 @@ import java.util.Map;
|
|||
*/
|
||||
public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable {
|
||||
|
||||
/**
|
||||
* The expected input fields
|
||||
*/
|
||||
List<String> inputFields();
|
||||
|
||||
/**
|
||||
* @return The resulting output fields
|
||||
*/
|
||||
List<String> outputFields();
|
||||
|
||||
/**
|
||||
* Process the given fields and their values and return the modified map.
|
||||
*
|
||||
|
@ -29,4 +40,12 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
|
|||
* @return Reverse lookup map to match resulting features to their original feature name
|
||||
*/
|
||||
Map<String, String> reverseLookup();
|
||||
|
||||
/**
|
||||
* @return Is the pre-processor a custom one provided by the user, or automatically created?
|
||||
* This changes how feature importance is calculated, as fields generated by custom processors get individual feature
|
||||
* importance calculations.
|
||||
*/
|
||||
boolean isCustom();
|
||||
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -18,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
|
@ -33,6 +35,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
|
||||
public static final ParseField TARGET_MAP = new ParseField("target_map");
|
||||
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
|
||||
public static final ParseField CUSTOM = new ParseField("custom");
|
||||
|
||||
public static final ConstructingObjectParser<TargetMeanEncoding, Void> STRICT_PARSER = createParser(false);
|
||||
public static final ConstructingObjectParser<TargetMeanEncoding, Void> LENIENT_PARSER = createParser(true);
|
||||
|
@ -42,13 +45,14 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
ConstructingObjectParser<TargetMeanEncoding, Void> parser = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
lenient,
|
||||
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3]));
|
||||
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3], (Boolean)a[4]));
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
|
||||
parser.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
|
||||
TARGET_MAP);
|
||||
parser.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
|
||||
parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -64,12 +68,14 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
private final String featureName;
|
||||
private final Map<String, Double> meanMap;
|
||||
private final double defaultValue;
|
||||
private final boolean custom;
|
||||
|
||||
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
|
||||
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue, Boolean custom) {
|
||||
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
|
||||
this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME);
|
||||
this.meanMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(meanMap, TARGET_MAP));
|
||||
this.defaultValue = ExceptionsHelper.requireNonNull(defaultValue, DEFAULT_VALUE);
|
||||
this.custom = custom == null ? false : custom;
|
||||
}
|
||||
|
||||
public TargetMeanEncoding(StreamInput in) throws IOException {
|
||||
|
@ -77,6 +83,11 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
this.featureName = in.readString();
|
||||
this.meanMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readDouble));
|
||||
this.defaultValue = in.readDouble();
|
||||
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
this.custom = in.readBoolean();
|
||||
} else {
|
||||
this.custom = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -112,11 +123,26 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
return Collections.singletonMap(featureName, field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCustom() {
|
||||
return custom;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> inputFields() {
|
||||
return Collections.singletonList(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> outputFields() {
|
||||
return Collections.singletonList(featureName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Map<String, Object> fields) {
|
||||
Object value = fields.get(field);
|
||||
|
@ -137,6 +163,9 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
out.writeString(featureName);
|
||||
out.writeMap(meanMap, StreamOutput::writeString, StreamOutput::writeDouble);
|
||||
out.writeDouble(defaultValue);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeBoolean(custom);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -146,6 +175,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
builder.field(FEATURE_NAME.getPreferredName(), featureName);
|
||||
builder.field(TARGET_MAP.getPreferredName(), meanMap);
|
||||
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
|
||||
builder.field(CUSTOM.getPreferredName(), custom);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -158,12 +188,13 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
|
|||
return Objects.equals(field, that.field)
|
||||
&& Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(meanMap, that.meanMap)
|
||||
&& Objects.equals(defaultValue, that.defaultValue);
|
||||
&& Objects.equals(defaultValue, that.defaultValue)
|
||||
&& Objects.equals(custom, that.custom);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, featureName, meanMap, defaultValue);
|
||||
return Objects.hash(field, featureName, meanMap, defaultValue, custom);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -17,6 +17,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding> {
|
||||
|
@ -37,7 +38,10 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
|
|||
for (int i = 0; i < valuesSize; i++) {
|
||||
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
|
||||
}
|
||||
return new FrequencyEncoding(randomAlphaOfLength(10), randomAlphaOfLength(10), valueMap);
|
||||
return new FrequencyEncoding(randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
valueMap,
|
||||
randomBoolean() ? null : randomBoolean());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -51,7 +55,7 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
|
|||
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
|
||||
v -> randomDoubleBetween(0.0, 1.0, false)));
|
||||
String encodedFeatureName = "encoded";
|
||||
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap);
|
||||
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap, false);
|
||||
Object fieldValue = randomFrom(values);
|
||||
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName,
|
||||
equalTo(valueMap.get(fieldValue.toString())));
|
||||
|
@ -65,4 +69,15 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
|
|||
testProcess(encoding, fieldValues, matchers);
|
||||
}
|
||||
|
||||
public void testInputOutputFields() {
|
||||
String field = randomAlphaOfLength(10);
|
||||
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
|
||||
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
|
||||
v -> randomDoubleBetween(0.0, 1.0, false)));
|
||||
String encodedFeatureName = randomAlphaOfLength(10);
|
||||
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap, false);
|
||||
assertThat(encoding.inputFields(), containsInAnyOrder(field));
|
||||
assertThat(encoding.outputFields(), containsInAnyOrder(encodedFeatureName));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ import java.util.Map;
|
|||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
|
||||
|
@ -37,7 +38,9 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
|
|||
for (int i = 0; i < valuesSize; i++) {
|
||||
valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
|
||||
}
|
||||
return new OneHotEncoding(randomAlphaOfLength(10), valueMap);
|
||||
return new OneHotEncoding(randomAlphaOfLength(10),
|
||||
valueMap,
|
||||
randomBoolean() ? randomBoolean() : null);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -49,7 +52,7 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
|
|||
String field = "categorical";
|
||||
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0);
|
||||
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString()));
|
||||
OneHotEncoding encoding = new OneHotEncoding(field, valueMap);
|
||||
OneHotEncoding encoding = new OneHotEncoding(field, valueMap, false);
|
||||
Object fieldValue = randomFrom(values);
|
||||
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
|
||||
|
||||
|
@ -67,4 +70,14 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
|
|||
testProcess(encoding, fieldValues, matchers);
|
||||
}
|
||||
|
||||
public void testInputOutputFields() {
|
||||
String field = randomAlphaOfLength(10);
|
||||
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0);
|
||||
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString()));
|
||||
OneHotEncoding encoding = new OneHotEncoding(field, valueMap, false);
|
||||
assertThat(encoding.inputFields(), containsInAnyOrder(field));
|
||||
assertThat(encoding.outputFields(),
|
||||
containsInAnyOrder(values.stream().map(v -> "Column_" + v.toString()).toArray(String[]::new)));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncoding> {
|
||||
|
@ -40,7 +41,8 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
|
|||
return new TargetMeanEncoding(randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
valueMap,
|
||||
randomDoubleBetween(0.0, 1.0, false));
|
||||
randomDoubleBetween(0.0, 1.0, false),
|
||||
randomBoolean() ? randomBoolean() : null);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -55,7 +57,7 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
|
|||
v -> randomDoubleBetween(0.0, 1.0, false)));
|
||||
String encodedFeatureName = "encoded";
|
||||
Double defaultvalue = randomDouble();
|
||||
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue);
|
||||
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue, false);
|
||||
Object fieldValue = randomFrom(values);
|
||||
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName,
|
||||
equalTo(valueMap.get(fieldValue.toString())));
|
||||
|
@ -68,4 +70,16 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
|
|||
testProcess(encoding, fieldValues, matchers);
|
||||
}
|
||||
|
||||
public void testInputOutputFields() {
|
||||
String field = randomAlphaOfLength(10);
|
||||
String encodedFeatureName = randomAlphaOfLength(10);
|
||||
Double defaultvalue = randomDouble();
|
||||
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0);
|
||||
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
|
||||
v -> randomDoubleBetween(0.0, 1.0, false)));
|
||||
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue, false);
|
||||
assertThat(encoding.inputFields(), containsInAnyOrder(field));
|
||||
assertThat(encoding.outputFields(), containsInAnyOrder(encodedFeatureName));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -74,7 +74,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
|
||||
.setParsedDefinition(new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false)))
|
||||
.setTrainedModel(buildClassification(true)))
|
||||
.setVersion(Version.CURRENT)
|
||||
.setLicenseLevel(License.OperationMode.PLATINUM.description())
|
||||
|
@ -85,7 +85,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
|
||||
.setParsedDefinition(new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false)))
|
||||
.setTrainedModel(buildRegression()))
|
||||
.setVersion(Version.CURRENT)
|
||||
.setEstimatedOperations(0)
|
||||
|
@ -203,7 +203,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
|
||||
.setParsedDefinition(new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false)))
|
||||
.setTrainedModel(buildMultiClassClassification()))
|
||||
.setVersion(Version.CURRENT)
|
||||
.setLicenseLevel(License.OperationMode.PLATINUM.description())
|
||||
|
@ -320,7 +320,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId)
|
||||
.setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
|
||||
.setParsedDefinition(new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding, false)))
|
||||
.setTrainedModel(buildRegression()))
|
||||
.setVersion(Version.CURRENT)
|
||||
.setEstimatedOperations(0)
|
||||
|
|
|
@ -67,7 +67,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
String modelId = "classification_model";
|
||||
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
|
||||
InferenceDefinition definition = InferenceDefinition.builder()
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
|
||||
.setTrainedModel(buildClassificationInference(false))
|
||||
.build();
|
||||
|
||||
|
@ -99,7 +99,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
|
||||
// Test with labels
|
||||
definition = InferenceDefinition.builder()
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
|
||||
.setTrainedModel(buildClassificationInference(true))
|
||||
.build();
|
||||
model = new LocalModel(modelId,
|
||||
|
@ -142,7 +142,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
String modelId = "classification_model";
|
||||
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
|
||||
InferenceDefinition definition = InferenceDefinition.builder()
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
|
||||
.setTrainedModel(buildClassificationInference(true))
|
||||
.build();
|
||||
|
||||
|
@ -200,7 +200,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
|
||||
InferenceDefinition trainedModelDefinition = InferenceDefinition.builder()
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
|
||||
.setTrainedModel(buildRegressionInference())
|
||||
.build();
|
||||
LocalModel model = new LocalModel("regression_model",
|
||||
|
@ -228,7 +228,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
|
||||
InferenceDefinition trainedModelDefinition = InferenceDefinition.builder()
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
|
||||
.setTrainedModel(buildRegressionInference())
|
||||
.build();
|
||||
LocalModel model = new LocalModel(
|
||||
|
@ -260,7 +260,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
String modelId = "classification_model";
|
||||
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
|
||||
InferenceDefinition definition = InferenceDefinition.builder()
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
.setPreProcessors(Collections.singletonList(new OneHotEncoding("categorical", oneHotMap(), false)))
|
||||
.setTrainedModel(buildClassificationInference(false))
|
||||
.build();
|
||||
|
||||
|
|
Loading…
Reference in New Issue