* [ML][Inference] Feature pre-processing objects and functions (#46777) To support inference on pre-trained machine learning models, some basic feature encoding will be necessary. I am using a named object serialization approach so new encodings/pre-processing steps could be added in the future. This PR lays down the ground work for 3 basic encodings: * One Hot * Target Mean * Frequency More feature encodings or pre-processings could be added in the future: * Handling missing columns * Standardization * Label encoding * etc.... * fixing compilation for namedxcontent tests
This commit is contained in:
parent
81cbd3fba4
commit
05fb7be571
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference;
|
||||
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
|
||||
|
||||
@Override
|
||||
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
|
||||
// PreProcessing
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(OneHotEncoding.NAME),
|
||||
OneHotEncoding::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(TargetMeanEncoding.NAME),
|
||||
TargetMeanEncoding::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(FrequencyEncoding.NAME),
|
||||
FrequencyEncoding::fromXContent));
|
||||
return namedXContent;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,161 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
|
||||
/**
|
||||
* PreProcessor for frequency encoding a set of categorical values for a given field.
|
||||
*/
|
||||
public class FrequencyEncoding implements PreProcessor {
|
||||
|
||||
public static final String NAME = "frequency_encoding";
|
||||
public static final ParseField FIELD = new ParseField("field");
|
||||
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
|
||||
public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<FrequencyEncoding, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
true,
|
||||
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2]));
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
|
||||
FREQUENCY_MAP);
|
||||
}
|
||||
|
||||
public static FrequencyEncoding fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final String field;
|
||||
private final String featureName;
|
||||
private final Map<String, Double> frequencyMap;
|
||||
|
||||
public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap) {
|
||||
this.field = Objects.requireNonNull(field);
|
||||
this.featureName = Objects.requireNonNull(featureName);
|
||||
this.frequencyMap = Collections.unmodifiableMap(Objects.requireNonNull(frequencyMap));
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Field name on which to frequency encode
|
||||
*/
|
||||
public String getField() {
|
||||
return field;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Map of Value: frequency for the frequency encoding
|
||||
*/
|
||||
public Map<String, Double> getFrequencyMap() {
|
||||
return frequencyMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The encoded feature name
|
||||
*/
|
||||
public String getFeatureName() {
|
||||
return featureName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FIELD.getPreferredName(), field);
|
||||
builder.field(FEATURE_NAME.getPreferredName(), featureName);
|
||||
builder.field(FREQUENCY_MAP.getPreferredName(), frequencyMap);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
FrequencyEncoding that = (FrequencyEncoding) o;
|
||||
return Objects.equals(field, that.field)
|
||||
&& Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(frequencyMap, that.frequencyMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, featureName, frequencyMap);
|
||||
}
|
||||
|
||||
public Builder builder(String field) {
|
||||
return new Builder(field);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String field;
|
||||
private String featureName;
|
||||
private Map<String, Double> frequencyMap = new HashMap<>();
|
||||
|
||||
public Builder(String field) {
|
||||
this.field = field;
|
||||
}
|
||||
|
||||
public Builder setField(String field) {
|
||||
this.field = field;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setFeatureName(String featureName) {
|
||||
this.featureName = featureName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setFrequencyMap(Map<String, Double> frequencyMap) {
|
||||
this.frequencyMap = new HashMap<>(frequencyMap);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder addFrequency(String valueName, double frequency) {
|
||||
this.frequencyMap.put(valueName, frequency);
|
||||
return this;
|
||||
}
|
||||
|
||||
public FrequencyEncoding build() {
|
||||
return new FrequencyEncoding(field, featureName, frequencyMap);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,138 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* PreProcessor for one hot encoding a set of categorical values for a given field.
|
||||
*/
|
||||
public class OneHotEncoding implements PreProcessor {
|
||||
|
||||
public static final String NAME = "one_hot_encoding";
|
||||
public static final ParseField FIELD = new ParseField("field");
|
||||
public static final ParseField HOT_MAP = new ParseField("hot_map");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<OneHotEncoding, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
true,
|
||||
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1]));
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
|
||||
}
|
||||
|
||||
public static OneHotEncoding fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final String field;
|
||||
private final Map<String, String> hotMap;
|
||||
|
||||
public OneHotEncoding(String field, Map<String, String> hotMap) {
|
||||
this.field = Objects.requireNonNull(field);
|
||||
this.hotMap = Collections.unmodifiableMap(Objects.requireNonNull(hotMap));
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Field name on which to one hot encode
|
||||
*/
|
||||
public String getField() {
|
||||
return field;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Map of Value: ColumnName for the one hot encoding
|
||||
*/
|
||||
public Map<String, String> getHotMap() {
|
||||
return hotMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FIELD.getPreferredName(), field);
|
||||
builder.field(HOT_MAP.getPreferredName(), hotMap);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
OneHotEncoding that = (OneHotEncoding) o;
|
||||
return Objects.equals(field, that.field)
|
||||
&& Objects.equals(hotMap, that.hotMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, hotMap);
|
||||
}
|
||||
|
||||
public Builder builder(String field) {
|
||||
return new Builder(field);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String field;
|
||||
private Map<String, String> hotMap = new HashMap<>();
|
||||
|
||||
public Builder(String field) {
|
||||
this.field = field;
|
||||
}
|
||||
|
||||
public Builder setField(String field) {
|
||||
this.field = field;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setHotMap(Map<String, String> hotMap) {
|
||||
this.hotMap = new HashMap<>(hotMap);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder addOneHot(String valueName, String oneHotFeatureName) {
|
||||
this.hotMap.put(valueName, oneHotFeatureName);
|
||||
return this;
|
||||
}
|
||||
|
||||
public OneHotEncoding build() {
|
||||
return new OneHotEncoding(field, hotMap);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
|
||||
|
||||
/**
 * Describes a pre-processor for a defined machine learning model
 *
 * Implementations (e.g. one-hot, target-mean, and frequency encodings) are
 * serialized as named XContent objects; extending {@link ToXContentObject}
 * lets each implementation render itself as a self-contained object.
 */
public interface PreProcessor extends ToXContentObject {

    /**
     * @return The name of the pre-processor
     */
    String getName();
}
|
|
@ -0,0 +1,183 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
|
||||
/**
|
||||
* PreProcessor for target mean encoding a set of categorical values for a given field.
|
||||
*/
|
||||
public class TargetMeanEncoding implements PreProcessor {
|
||||
|
||||
public static final String NAME = "target_mean_encoding";
|
||||
public static final ParseField FIELD = new ParseField("field");
|
||||
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
|
||||
public static final ParseField TARGET_MEANS = new ParseField("target_means");
|
||||
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<TargetMeanEncoding, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
true,
|
||||
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3]));
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
|
||||
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
|
||||
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
|
||||
TARGET_MEANS);
|
||||
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
|
||||
}
|
||||
|
||||
public static TargetMeanEncoding fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final String field;
|
||||
private final String featureName;
|
||||
private final Map<String, Double> meanMap;
|
||||
private final double defaultValue;
|
||||
|
||||
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
|
||||
this.field = Objects.requireNonNull(field);
|
||||
this.featureName = Objects.requireNonNull(featureName);
|
||||
this.meanMap = Collections.unmodifiableMap(Objects.requireNonNull(meanMap));
|
||||
this.defaultValue = Objects.requireNonNull(defaultValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Field name on which to target mean encode
|
||||
*/
|
||||
public String getField() {
|
||||
return field;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Map of Value: targetMean for the target mean encoding
|
||||
*/
|
||||
public Map<String, Double> getMeanMap() {
|
||||
return meanMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The default value to set when a previously unobserved value is seen
|
||||
*/
|
||||
public double getDefaultValue() {
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The feature name for the encoded value
|
||||
*/
|
||||
public String getFeatureName() {
|
||||
return featureName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(FIELD.getPreferredName(), field);
|
||||
builder.field(FEATURE_NAME.getPreferredName(), featureName);
|
||||
builder.field(TARGET_MEANS.getPreferredName(), meanMap);
|
||||
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TargetMeanEncoding that = (TargetMeanEncoding) o;
|
||||
return Objects.equals(field, that.field)
|
||||
&& Objects.equals(featureName, that.featureName)
|
||||
&& Objects.equals(meanMap, that.meanMap)
|
||||
&& Objects.equals(defaultValue, that.defaultValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, featureName, meanMap, defaultValue);
|
||||
}
|
||||
|
||||
public Builder builder(String field) {
|
||||
return new Builder(field);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String field;
|
||||
private String featureName;
|
||||
private Map<String, Double> meanMap = new HashMap<>();
|
||||
private double defaultValue;
|
||||
|
||||
public Builder(String field) {
|
||||
this.field = field;
|
||||
}
|
||||
|
||||
public String getField() {
|
||||
return field;
|
||||
}
|
||||
|
||||
public Builder setField(String field) {
|
||||
this.field = field;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setFeatureName(String featureName) {
|
||||
this.featureName = featureName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setMeanMap(Map<String, Double> meanMap) {
|
||||
this.meanMap = meanMap;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder addMeanMapEntry(String valueName, double meanEncoding) {
|
||||
this.meanMap.put(valueName, meanEncoding);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setDefaultValue(double defaultValue) {
|
||||
this.defaultValue = defaultValue;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TargetMeanEncoding build() {
|
||||
return new TargetMeanEncoding(field, featureName, meanMap, defaultValue);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider
|
||||
org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider
|
||||
org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider
|
||||
org.elasticsearch.client.transform.TransformNamedXContentProvider
|
||||
|
|
|
@ -65,6 +65,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar
|
|||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
|
||||
import org.elasticsearch.client.transform.transforms.SyncConfig;
|
||||
import org.elasticsearch.client.transform.transforms.TimeSyncConfig;
|
||||
import org.elasticsearch.common.CheckedFunction;
|
||||
|
@ -95,6 +98,7 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.test.InternalAggregationTestCase;
|
||||
import org.elasticsearch.test.rest.yaml.restspec.ClientYamlSuiteRestApi;
|
||||
import org.elasticsearch.test.rest.yaml.restspec.ClientYamlSuiteRestSpec;
|
||||
|
||||
import org.hamcrest.Matchers;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -676,7 +680,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(37, namedXContents.size());
|
||||
assertEquals(40, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
|
@ -686,7 +690,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
categories.put(namedXContent.categoryClass, counter + 1);
|
||||
}
|
||||
}
|
||||
assertEquals("Had: " + categories, 9, categories.size());
|
||||
assertEquals("Had: " + categories, 10, categories.size());
|
||||
assertEquals(Integer.valueOf(3), categories.get(Aggregation.class));
|
||||
assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
|
||||
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
|
||||
|
@ -733,6 +737,8 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
ConfusionMatrixMetric.NAME,
|
||||
MeanSquaredErrorMetric.NAME,
|
||||
RSquaredMetric.NAME));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
|
||||
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
|
||||
}
|
||||
|
||||
public void testApiNamingConventions() throws Exception {
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
|
||||
public class FrequencyEncodingTests extends AbstractXContentTestCase<FrequencyEncoding> {
|
||||
|
||||
@Override
|
||||
protected FrequencyEncoding doParseInstance(XContentParser parser) throws IOException {
|
||||
return FrequencyEncoding.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FrequencyEncoding createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static FrequencyEncoding createRandom() {
|
||||
int valuesSize = randomIntBetween(1, 10);
|
||||
Map<String, Double> valueMap = new HashMap<>();
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
|
||||
}
|
||||
return new FrequencyEncoding(randomAlphaOfLength(10), randomAlphaOfLength(10), valueMap);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
|
||||
public class OneHotEncodingTests extends AbstractXContentTestCase<OneHotEncoding> {
|
||||
|
||||
@Override
|
||||
protected OneHotEncoding doParseInstance(XContentParser parser) throws IOException {
|
||||
return OneHotEncoding.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OneHotEncoding createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static OneHotEncoding createRandom() {
|
||||
int valuesSize = randomIntBetween(1, 10);
|
||||
Map<String, String> valueMap = new HashMap<>();
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
|
||||
}
|
||||
return new OneHotEncoding(randomAlphaOfLength(10), valueMap);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
|
||||
public class TargetMeanEncodingTests extends AbstractXContentTestCase<TargetMeanEncoding> {
|
||||
|
||||
@Override
|
||||
protected TargetMeanEncoding doParseInstance(XContentParser parser) throws IOException {
|
||||
return TargetMeanEncoding.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TargetMeanEncoding createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static TargetMeanEncoding createRandom() {
|
||||
int valuesSize = randomIntBetween(1, 10);
|
||||
Map<String, Double> valueMap = new HashMap<>();
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
|
||||
}
|
||||
return new TargetMeanEncoding(randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
valueMap,
|
||||
randomDoubleBetween(0.0, 1.0, false));
|
||||
}
|
||||
|
||||
}
|
|
@ -145,6 +145,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.P
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
|
||||
import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage;
|
||||
import org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage;
|
||||
|
@ -472,6 +476,10 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
|||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, ScoreByThresholdResult::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
|
||||
ConfusionMatrix.Result::new),
|
||||
// ML - Inference
|
||||
new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new),
|
||||
new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new),
|
||||
new NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(), TargetMeanEncoding::new),
|
||||
|
||||
// monitoring
|
||||
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
|
||||
|
||||
@Override
|
||||
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
|
||||
// PreProcessing Lenient
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME,
|
||||
OneHotEncoding::fromXContentLenient));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
|
||||
TargetMeanEncoding::fromXContentLenient));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME,
|
||||
FrequencyEncoding::fromXContentLenient));
|
||||
|
||||
// PreProcessing Strict
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
|
||||
OneHotEncoding::fromXContentStrict));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
|
||||
TargetMeanEncoding::fromXContentStrict));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME,
|
||||
FrequencyEncoding::fromXContentStrict));
|
||||
|
||||
return namedXContent;
|
||||
}
|
||||
|
||||
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
|
||||
// PreProcessing
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(),
|
||||
OneHotEncoding::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(),
|
||||
TargetMeanEncoding::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(),
|
||||
FrequencyEncoding::new));
|
||||
|
||||
return namedWriteables;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,146 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
|
||||
/**
 * PreProcessor for frequency encoding a set of categorical values for a given field.
 *
 * The configured {@code frequencyMap} maps each known categorical value of {@code field}
 * to a pre-computed frequency; {@link #process(Map)} writes that frequency under
 * {@code featureName}, defaulting to 0.0 for values absent from the map.
 */
public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

    public static final ParseField NAME = new ParseField("frequency_encoding");
    public static final ParseField FIELD = new ParseField("field");
    public static final ParseField FEATURE_NAME = new ParseField("feature_name");
    public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");

    // Two parser instances that differ only in how unknown fields are handled:
    // strict rejects them, lenient (used when reading from the index) ignores them.
    public static final ConstructingObjectParser<FrequencyEncoding, Void> STRICT_PARSER = createParser(false);
    public static final ConstructingObjectParser<FrequencyEncoding, Void> LENIENT_PARSER = createParser(true);

    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<FrequencyEncoding, Void> createParser(boolean lenient) {
        ConstructingObjectParser<FrequencyEncoding, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2]));
        parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
        parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
        // frequency_map is parsed as an object whose values are all read as doubles
        parser.declareObject(ConstructingObjectParser.constructorArg(),
            (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
            FREQUENCY_MAP);
        return parser;
    }

    /** Parses a {@link FrequencyEncoding}, failing on unknown fields. */
    public static FrequencyEncoding fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    /** Parses a {@link FrequencyEncoding}, ignoring unknown fields. */
    public static FrequencyEncoding fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    private final String field;
    private final String featureName;
    private final Map<String, Double> frequencyMap;

    /**
     * @param field        the input (categorical) field name; must not be null
     * @param featureName  the output feature name to write; must not be null
     * @param frequencyMap value-to-frequency mapping; must not be null, stored unmodifiable
     */
    public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap) {
        this.field = ExceptionsHelper.requireNonNull(field, FIELD);
        this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME);
        this.frequencyMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(frequencyMap, FREQUENCY_MAP));
    }

    /**
     * Wire deserialization constructor. Read order must mirror {@link #writeTo(StreamOutput)}.
     */
    public FrequencyEncoding(StreamInput in) throws IOException {
        this.field = in.readString();
        this.featureName = in.readString();
        this.frequencyMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readDouble));
    }

    /**
     * @return Field name on which to frequency encode
     */
    public String getField() {
        return field;
    }

    /**
     * @return Map of Value: frequency for the frequency encoding
     */
    public Map<String, Double> getFrequencyMap() {
        return frequencyMap;
    }

    /**
     * @return The encoded feature name
     */
    public String getFeatureName() {
        return featureName;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public void process(Map<String, Object> fields) {
        // NOTE(review): assumes the value at `field` is a String; a non-String value
        // would throw ClassCastException here — confirm upstream guarantees this.
        String value = (String)fields.get(field);
        if (value == null) {
            // missing input field: leave the feature unset rather than writing a default
            return;
        }
        // unseen categories encode to 0.0
        fields.put(featureName, frequencyMap.getOrDefault(value, 0.0));
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Write order must mirror the StreamInput constructor.
        out.writeString(field);
        out.writeString(featureName);
        out.writeMap(frequencyMap, StreamOutput::writeString, StreamOutput::writeDouble);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD.getPreferredName(), field);
        builder.field(FEATURE_NAME.getPreferredName(), featureName);
        builder.field(FREQUENCY_MAP.getPreferredName(), frequencyMap);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        FrequencyEncoding that = (FrequencyEncoding) o;
        return Objects.equals(field, that.field)
            && Objects.equals(featureName, that.featureName)
            && Objects.equals(frequencyMap, that.frequencyMap);
    }

    @Override
    public int hashCode() {
        return Objects.hash(field, featureName, frequencyMap);
    }

}
|
|
@ -0,0 +1,12 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
/**
 * To be used in conjunction with a lenient parser.
 *
 * Marker interface: pre-processors registered under this type are parsed
 * ignoring unknown fields (e.g. when reading configuration back from the index).
 */
public interface LenientlyParsedPreProcessor extends PreProcessor {
}
|
|
@ -0,0 +1,130 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
 * PreProcessor for one hot encoding a set of categorical values for a given field.
 *
 * The configured {@code hotMap} maps each known categorical value of {@code field}
 * to an output column name; {@link #process(Map)} writes 1 into the column matching
 * the observed value and 0 into every other mapped column.
 */
public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

    public static final ParseField NAME = new ParseField("one_hot_encoding");
    public static final ParseField FIELD = new ParseField("field");
    public static final ParseField HOT_MAP = new ParseField("hot_map");

    // Strict parser rejects unknown fields; lenient parser ignores them.
    public static final ConstructingObjectParser<OneHotEncoding, Void> STRICT_PARSER = createParser(false);
    public static final ConstructingObjectParser<OneHotEncoding, Void> LENIENT_PARSER = createParser(true);

    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<OneHotEncoding, Void> createParser(boolean lenient) {
        ConstructingObjectParser<OneHotEncoding, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1]));
        parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
        // hot_map is a flat object of string values: categorical value -> column name
        parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
        return parser;
    }

    /** Parses a {@link OneHotEncoding}, failing on unknown fields. */
    public static OneHotEncoding fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    /** Parses a {@link OneHotEncoding}, ignoring unknown fields. */
    public static OneHotEncoding fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    private final String field;
    private final Map<String, String> hotMap;

    /**
     * @param field  the input (categorical) field name; must not be null
     * @param hotMap value-to-column-name mapping; must not be null, stored unmodifiable
     */
    public OneHotEncoding(String field, Map<String, String> hotMap) {
        this.field = ExceptionsHelper.requireNonNull(field, FIELD);
        this.hotMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP));
    }

    /**
     * Wire deserialization constructor. Read order must mirror {@link #writeTo(StreamOutput)}.
     */
    public OneHotEncoding(StreamInput in) throws IOException {
        this.field = in.readString();
        this.hotMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString));
    }

    /**
     * @return Field name on which to one hot encode
     */
    public String getField() {
        return field;
    }

    /**
     * @return Map of Value: ColumnName for the one hot encoding
     */
    public Map<String, String> getHotMap() {
        return hotMap;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public void process(Map<String, Object> fields) {
        // NOTE(review): assumes the value at `field` is a String; a non-String value
        // would throw ClassCastException here — confirm upstream guarantees this.
        String value = (String)fields.get(field);
        if (value == null) {
            // missing input field: leave all one-hot columns unset
            return;
        }
        // exactly one mapped column receives 1; all others receive 0
        hotMap.forEach((val, col) -> {
            int encoding = value.equals(val) ? 1 : 0;
            fields.put(col, encoding);
        });
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Write order must mirror the StreamInput constructor.
        out.writeString(field);
        out.writeMap(hotMap, StreamOutput::writeString, StreamOutput::writeString);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD.getPreferredName(), field);
        builder.field(HOT_MAP.getPreferredName(), hotMap);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        OneHotEncoding that = (OneHotEncoding) o;
        return Objects.equals(field, that.field)
            && Objects.equals(hotMap, that.hotMap);
    }

    @Override
    public int hashCode() {
        return Objects.hash(field, hotMap);
    }

}
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
 * Describes a pre-processor for a defined machine learning model
 * This processor should take a set of fields and return the modified set of fields.
 *
 * Implementations are both named XContent objects (for configuration parsing)
 * and named writeables (for wire serialization between nodes).
 */
public interface PreProcessor extends NamedXContentObject, NamedWriteable {

    /**
     * Process the given fields and their values and return the modified map.
     *
     * NOTE: The passed map object is mutated directly
     * @param fields The fields and their values to process
     */
    void process(Map<String, Object> fields);
}
|
|
@ -0,0 +1,12 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
/**
 * To be used in conjunction with a strict parser.
 *
 * Marker interface: pre-processors registered under this type are parsed
 * rejecting unknown fields (e.g. when accepting user-supplied configuration).
 */
public interface StrictlyParsedPreProcessor extends PreProcessor {
}
|
|
@ -0,0 +1,161 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
|
||||
/**
 * PreProcessor for target mean encoding a set of categorical values for a given field.
 *
 * The configured {@code meanMap} maps each known categorical value of {@code field}
 * to its pre-computed target mean; {@link #process(Map)} writes that mean under
 * {@code featureName}, falling back to {@code defaultValue} for unseen values.
 */
public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {

    public static final ParseField NAME = new ParseField("target_mean_encoding");
    public static final ParseField FIELD = new ParseField("field");
    public static final ParseField FEATURE_NAME = new ParseField("feature_name");
    public static final ParseField TARGET_MEANS = new ParseField("target_means");
    public static final ParseField DEFAULT_VALUE = new ParseField("default_value");

    // Strict parser rejects unknown fields; lenient parser ignores them.
    public static final ConstructingObjectParser<TargetMeanEncoding, Void> STRICT_PARSER = createParser(false);
    public static final ConstructingObjectParser<TargetMeanEncoding, Void> LENIENT_PARSER = createParser(true);

    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<TargetMeanEncoding, Void> createParser(boolean lenient) {
        ConstructingObjectParser<TargetMeanEncoding, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3]));
        parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
        parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
        // target_means is parsed as an object whose values are all read as doubles
        parser.declareObject(ConstructingObjectParser.constructorArg(),
            (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
            TARGET_MEANS);
        parser.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
        return parser;
    }

    /** Parses a {@link TargetMeanEncoding}, failing on unknown fields. */
    public static TargetMeanEncoding fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    /** Parses a {@link TargetMeanEncoding}, ignoring unknown fields. */
    public static TargetMeanEncoding fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    private final String field;
    private final String featureName;
    private final Map<String, Double> meanMap;
    private final double defaultValue;

    /**
     * @param field        the input (categorical) field name; must not be null
     * @param featureName  the output feature name to write; must not be null
     * @param meanMap      value-to-target-mean mapping; must not be null, stored unmodifiable
     * @param defaultValue mean to use for values absent from {@code meanMap}; must not be null
     */
    public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
        this.field = ExceptionsHelper.requireNonNull(field, FIELD);
        this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME);
        this.meanMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(meanMap, TARGET_MEANS));
        this.defaultValue = ExceptionsHelper.requireNonNull(defaultValue, DEFAULT_VALUE);
    }

    /**
     * Wire deserialization constructor. Read order must mirror {@link #writeTo(StreamOutput)}.
     */
    public TargetMeanEncoding(StreamInput in) throws IOException {
        this.field = in.readString();
        this.featureName = in.readString();
        this.meanMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readDouble));
        this.defaultValue = in.readDouble();
    }

    /**
     * @return Field name on which to target mean encode
     */
    public String getField() {
        return field;
    }

    /**
     * @return Map of Value: targetMean for the target mean encoding
     */
    public Map<String, Double> getMeanMap() {
        return meanMap;
    }

    /**
     * @return The default value to set when a previously unobserved value is seen
     */
    public Double getDefaultValue() {
        return defaultValue;
    }

    /**
     * @return The feature name for the encoded value
     */
    public String getFeatureName() {
        return featureName;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public void process(Map<String, Object> fields) {
        // NOTE(review): assumes the value at `field` is a String; a non-String value
        // would throw ClassCastException here — confirm upstream guarantees this.
        String value = (String)fields.get(field);
        if (value == null) {
            // missing input field: leave the feature unset rather than writing the default
            return;
        }
        fields.put(featureName, meanMap.getOrDefault(value, defaultValue));
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Write order must mirror the StreamInput constructor.
        out.writeString(field);
        out.writeString(featureName);
        out.writeMap(meanMap, StreamOutput::writeString, StreamOutput::writeDouble);
        out.writeDouble(defaultValue);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD.getPreferredName(), field);
        builder.field(FEATURE_NAME.getPreferredName(), featureName);
        builder.field(TARGET_MEANS.getPreferredName(), meanMap);
        builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TargetMeanEncoding that = (TargetMeanEncoding) o;
        return Objects.equals(field, that.field)
            && Objects.equals(featureName, that.featureName)
            && Objects.equals(meanMap, that.meanMap)
            && Objects.equals(defaultValue, that.defaultValue);
    }

    @Override
    public int hashCode() {
        return Objects.hash(field, featureName, meanMap, defaultValue);
    }

}
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.utils;
|
||||
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
|
||||
/**
 * Simple interface for XContent Objects that are named.
 *
 * This affords more general handling when serializing and de-serializing this type of XContent when it is used in a NamedObjects
 * parser.
 */
public interface NamedXContentObject extends ToXContentObject {
    /**
     * @return The name of the XContentObject that is to be serialized
     */
    String getName();
}
|
|
@ -0,0 +1,171 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
//TODO these tests are temporary until the named objects are actually used by an encompassing class (i.e. ModelInferer)
public class NamedXContentObjectsTests extends AbstractXContentTestCase<NamedXContentObjectsTests.NamedObjectContainer> {

    /**
     * Minimal container used to exercise named-object round tripping of
     * {@link PreProcessor} implementations through a {@link NamedXContentRegistry}.
     */
    static class NamedObjectContainer implements ToXContentObject {

        static ParseField PRE_PROCESSORS = new ParseField("pre_processors");

        static final ObjectParser<NamedObjectContainer, Void> STRICT_PARSER = createParser(false);
        static final ObjectParser<NamedObjectContainer, Void> LENIENT_PARSER = createParser(true);

        @SuppressWarnings("unchecked")
        private static ObjectParser<NamedObjectContainer, Void> createParser(boolean lenient) {
            ObjectParser<NamedObjectContainer, Void> parser = new ObjectParser<>(
                "named_xcontent_object_container_test",
                lenient,
                NamedObjectContainer::new);
            // dispatch to the lenient or strict named-object category to match the parser flavor;
            // the "ordered" callback fires when the named objects arrive as an array
            parser.declareNamedObjects(NamedObjectContainer::setPreProcessors,
                (p, c, n) ->
                    lenient ? p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
                        p.namedObject(StrictlyParsedPreProcessor.class, n, null),
                (noc) -> noc.setUseExplicitPreprocessorOrder(true), PRE_PROCESSORS);
            return parser;
        }

        private boolean useExplicitPreprocessorOrder = false;
        private List<? extends PreProcessor> preProcessors;

        void setPreProcessors(List<? extends PreProcessor> preProcessors) {
            this.preProcessors = preProcessors;
        }

        void setUseExplicitPreprocessorOrder(boolean value) {
            this.useExplicitPreprocessorOrder = value;
        }

        static NamedObjectContainer fromXContent(XContentParser parser, boolean lenient) {
            return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            builder.startObject();
            writeNamedObjects(builder, params, useExplicitPreprocessorOrder, PRE_PROCESSORS.getPreferredName(), preProcessors);
            builder.endObject();
            return builder;
        }

        /**
         * Serializes the named objects either as an array of single-entry objects
         * (explicit order preserved) or as one object keyed by name (order not preserved).
         */
        XContentBuilder writeNamedObjects(XContentBuilder builder,
                                          Params params,
                                          boolean useExplicitOrder,
                                          String namedObjectsName,
                                          List<? extends NamedXContentObject> namedObjects) throws IOException {
            if (useExplicitOrder) {
                builder.startArray(namedObjectsName);
            } else {
                builder.startObject(namedObjectsName);
            }
            for (NamedXContentObject object : namedObjects) {
                if (useExplicitOrder) {
                    // each array element is wrapped in its own object: [{ "name": {...} }, ...]
                    builder.startObject();
                }
                builder.field(object.getName(), object, params);
                if (useExplicitOrder) {
                    builder.endObject();
                }
            }
            if (useExplicitOrder) {
                builder.endArray();
            } else {
                builder.endObject();
            }
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            NamedObjectContainer that = (NamedObjectContainer) o;
            return Objects.equals(preProcessors, that.preProcessors);
        }

        @Override
        public int hashCode() {
            return Objects.hash(preProcessors);
        }
    }

    // chosen per test so both parser flavors get coverage
    private boolean lenient;

    @Before
    public void chooseStrictOrLenient() {
        lenient = randomBoolean();
    }

    @Override
    public NamedObjectContainer createTestInstance() {
        int max = randomIntBetween(1, 10);
        List<PreProcessor> preProcessors = new ArrayList<>(max);
        for (int i = 0; i < max; i++) {
            preProcessors.add(randomFrom(FrequencyEncodingTests.createRandom(),
                OneHotEncodingTests.createRandom(),
                TargetMeanEncodingTests.createRandom()));
        }
        NamedObjectContainer container = new NamedObjectContainer();
        container.setPreProcessors(preProcessors);
        container.setUseExplicitPreprocessorOrder(true);
        return container;
    }

    @Override
    protected NamedObjectContainer doParseInstance(XContentParser parser) throws IOException {
        return NamedObjectContainer.fromXContent(parser, lenient);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // only the lenient parser tolerates injected unknown fields
        return lenient;
    }

    @Override
    protected Predicate<String> getRandomFieldsExcludeFilter() {
        // We only want to add random fields to the root, or the root of the named objects
        return field ->
            (field.endsWith("frequency_encoding") ||
                field.endsWith("one_hot_encoding") ||
                field.endsWith("target_mean_encoding") ||
                field.isEmpty()) == false;
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        // register the inference named objects plus the default search ones
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
        return new NamedXContentRegistry(namedXContent);
    }
}
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.hamcrest.Matcher;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding> {
|
||||
|
||||
@Override
|
||||
protected FrequencyEncoding doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? FrequencyEncoding.fromXContentLenient(parser) : FrequencyEncoding.fromXContentStrict(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FrequencyEncoding createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static FrequencyEncoding createRandom() {
|
||||
int valuesSize = randomIntBetween(1, 10);
|
||||
Map<String, Double> valueMap = new HashMap<>();
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
|
||||
}
|
||||
return new FrequencyEncoding(randomAlphaOfLength(10), randomAlphaOfLength(10), valueMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<FrequencyEncoding> instanceReader() {
|
||||
return FrequencyEncoding::new;
|
||||
}
|
||||
|
||||
public void testProcessWithFieldPresent() {
|
||||
String field = "categorical";
|
||||
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
|
||||
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Function.identity(),
|
||||
v -> randomDoubleBetween(0.0, 1.0, false)));
|
||||
String encodedFeatureName = "encoded";
|
||||
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap);
|
||||
String fieldValue = randomFrom(values);
|
||||
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName, equalTo(valueMap.get(fieldValue)));
|
||||
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
|
||||
testProcess(encoding, fieldValues, matchers);
|
||||
|
||||
// Test where the value is some unknown Value
|
||||
fieldValues = randomFieldValues(field, "unknownValue");
|
||||
fieldValues.put(field, "unknownValue");
|
||||
matchers = Collections.singletonMap(encodedFeatureName, equalTo(0.0));
|
||||
testProcess(encoding, fieldValues, matchers);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.hamcrest.Matcher;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
|
||||
|
||||
@Override
|
||||
protected OneHotEncoding doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? OneHotEncoding.fromXContentLenient(parser) : OneHotEncoding.fromXContentStrict(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OneHotEncoding createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static OneHotEncoding createRandom() {
|
||||
int valuesSize = randomIntBetween(1, 10);
|
||||
Map<String, String> valueMap = new HashMap<>();
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
|
||||
}
|
||||
return new OneHotEncoding(randomAlphaOfLength(10), valueMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<OneHotEncoding> instanceReader() {
|
||||
return OneHotEncoding::new;
|
||||
}
|
||||
|
||||
public void testProcessWithFieldPresent() {
|
||||
String field = "categorical";
|
||||
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
|
||||
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Function.identity(), v -> "Column_" + v));
|
||||
OneHotEncoding encoding = new OneHotEncoding(field, valueMap);
|
||||
String fieldValue = randomFrom(values);
|
||||
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
|
||||
|
||||
Map<String, Matcher<? super Object>> matchers = values.stream().map(v -> "Column_" + v)
|
||||
.collect(Collectors.toMap(
|
||||
Function.identity(),
|
||||
v -> v.equals("Column_" + fieldValue) ? equalTo(1) : equalTo(0)));
|
||||
|
||||
fieldValues.put(field, fieldValue);
|
||||
testProcess(encoding, fieldValues, matchers);
|
||||
|
||||
// Test where the value is some unknown Value
|
||||
fieldValues = randomFieldValues(field, "unknownValue");
|
||||
matchers.put("Column_" + fieldValue, equalTo(0));
|
||||
testProcess(encoding, fieldValues, matchers);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.hamcrest.Matcher;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public abstract class PreProcessingTests<T extends PreProcessor> extends AbstractSerializingTestCase<T> {
|
||||
|
||||
protected boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
void testProcess(PreProcessor preProcessor, Map<String, Object> fieldValues, Map<String, Matcher<? super Object>> assertions) {
|
||||
preProcessor.process(fieldValues);
|
||||
assertions.forEach((fieldName, matcher) ->
|
||||
assertThat(fieldValues.get(fieldName), matcher)
|
||||
);
|
||||
}
|
||||
|
||||
public void testWithMissingField() {
|
||||
Map<String, Object> fields = randomFieldValues();
|
||||
PreProcessor preProcessor = this.createTestInstance();
|
||||
Map<String, Object> fieldsCopy = new HashMap<>(fields);
|
||||
preProcessor.process(fields);
|
||||
assertThat(fieldsCopy, equalTo(fields));
|
||||
}
|
||||
|
||||
Map<String, Object> randomFieldValues() {
|
||||
int numFields = randomIntBetween(1, 5);
|
||||
Map<String, Object> fieldValues = new HashMap<>(numFields);
|
||||
for (int k = 0; k < numFields; k++) {
|
||||
fieldValues.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
|
||||
}
|
||||
return fieldValues;
|
||||
}
|
||||
|
||||
Map<String, Object> randomFieldValues(String categoricalField, String catigoricalValue) {
|
||||
Map<String, Object> fieldValues = randomFieldValues();
|
||||
fieldValues.put(categoricalField, catigoricalValue);
|
||||
return fieldValues;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.hamcrest.Matcher;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncoding> {
|
||||
|
||||
@Override
|
||||
protected TargetMeanEncoding doParseInstance(XContentParser parser) throws IOException {
|
||||
return lenient ? TargetMeanEncoding.fromXContentLenient(parser) : TargetMeanEncoding.fromXContentStrict(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TargetMeanEncoding createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
public static TargetMeanEncoding createRandom() {
|
||||
int valuesSize = randomIntBetween(1, 10);
|
||||
Map<String, Double> valueMap = new HashMap<>();
|
||||
for (int i = 0; i < valuesSize; i++) {
|
||||
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
|
||||
}
|
||||
return new TargetMeanEncoding(randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
valueMap,
|
||||
randomDoubleBetween(0.0, 1.0, false));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<TargetMeanEncoding> instanceReader() {
|
||||
return TargetMeanEncoding::new;
|
||||
}
|
||||
|
||||
public void testProcessWithFieldPresent() {
|
||||
String field = "categorical";
|
||||
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
|
||||
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Function.identity(),
|
||||
v -> randomDoubleBetween(0.0, 1.0, false)));
|
||||
String encodedFeatureName = "encoded";
|
||||
Double defaultvalue = randomDouble();
|
||||
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue);
|
||||
String fieldValue = randomFrom(values);
|
||||
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName, equalTo(valueMap.get(fieldValue)));
|
||||
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
|
||||
testProcess(encoding, fieldValues, matchers);
|
||||
|
||||
// Test where the value is some unknown Value
|
||||
fieldValues = randomFieldValues(field, "unknownValue");
|
||||
matchers = Collections.singletonMap(encodedFeatureName, equalTo(defaultvalue));
|
||||
testProcess(encoding, fieldValues, matchers);
|
||||
}
|
||||
|
||||
}
|
|
@ -122,6 +122,7 @@ import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction;
|
|||
import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
|
||||
|
@ -950,6 +951,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
|
|||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
return namedXContent;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue