[7.x] [ML][Inference] Feature pre-processing objects and functions (#46777) (#47040)

* [ML][Inference] Feature pre-processing objects and functions (#46777)

To support inference on pre-trained machine learning models, some basic feature encoding will be necessary. I am using a named object serialization approach so new encodings/pre-processing steps could be added in the future. 

This PR lays down the ground work for 3 basic encodings:

* HotOne
* Target Mean
* Frequency

More feature encodings or pre-processings could be added in the future:

* Handling missing columns
* Standardization
* Label encoding
* etc....

* fixing compilation for namedxcontent tests
This commit is contained in:
Benjamin Trent 2019-09-25 08:16:24 -04:00 committed by GitHub
parent 81cbd3fba4
commit 05fb7be571
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1782 additions and 2 deletions

View File

@ -0,0 +1,48 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import java.util.ArrayList;
import java.util.List;
public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
@Override
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
// PreProcessing
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(OneHotEncoding.NAME),
OneHotEncoding::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(TargetMeanEncoding.NAME),
TargetMeanEncoding::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(FrequencyEncoding.NAME),
FrequencyEncoding::fromXContent));
return namedXContent;
}
}

View File

@ -0,0 +1,161 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* PreProcessor for frequency encoding a set of categorical values for a given field.
*/
public class FrequencyEncoding implements PreProcessor {
public static final String NAME = "frequency_encoding";
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<FrequencyEncoding, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
FREQUENCY_MAP);
}
public static FrequencyEncoding fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final String field;
private final String featureName;
private final Map<String, Double> frequencyMap;
public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap) {
this.field = Objects.requireNonNull(field);
this.featureName = Objects.requireNonNull(featureName);
this.frequencyMap = Collections.unmodifiableMap(Objects.requireNonNull(frequencyMap));
}
/**
* @return Field name on which to frequency encode
*/
public String getField() {
return field;
}
/**
* @return Map of Value: frequency for the frequency encoding
*/
public Map<String, Double> getFrequencyMap() {
return frequencyMap;
}
/**
* @return The encoded feature name
*/
public String getFeatureName() {
return featureName;
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(FREQUENCY_MAP.getPreferredName(), frequencyMap);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FrequencyEncoding that = (FrequencyEncoding) o;
return Objects.equals(field, that.field)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(frequencyMap, that.frequencyMap);
}
@Override
public int hashCode() {
return Objects.hash(field, featureName, frequencyMap);
}
public Builder builder(String field) {
return new Builder(field);
}
public static class Builder {
private String field;
private String featureName;
private Map<String, Double> frequencyMap = new HashMap<>();
public Builder(String field) {
this.field = field;
}
public Builder setField(String field) {
this.field = field;
return this;
}
public Builder setFeatureName(String featureName) {
this.featureName = featureName;
return this;
}
public Builder setFrequencyMap(Map<String, Double> frequencyMap) {
this.frequencyMap = new HashMap<>(frequencyMap);
return this;
}
public Builder addFrequency(String valueName, double frequency) {
this.frequencyMap.put(valueName, frequency);
return this;
}
public FrequencyEncoding build() {
return new FrequencyEncoding(field, featureName, frequencyMap);
}
}
}

View File

@ -0,0 +1,138 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* PreProcessor for one hot encoding a set of categorical values for a given field.
*/
public class OneHotEncoding implements PreProcessor {
public static final String NAME = "one_hot_encoding";
public static final ParseField FIELD = new ParseField("field");
public static final ParseField HOT_MAP = new ParseField("hot_map");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<OneHotEncoding, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
}
public static OneHotEncoding fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final String field;
private final Map<String, String> hotMap;
public OneHotEncoding(String field, Map<String, String> hotMap) {
this.field = Objects.requireNonNull(field);
this.hotMap = Collections.unmodifiableMap(Objects.requireNonNull(hotMap));
}
/**
* @return Field name on which to one hot encode
*/
public String getField() {
return field;
}
/**
* @return Map of Value: ColumnName for the one hot encoding
*/
public Map<String, String> getHotMap() {
return hotMap;
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(HOT_MAP.getPreferredName(), hotMap);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
OneHotEncoding that = (OneHotEncoding) o;
return Objects.equals(field, that.field)
&& Objects.equals(hotMap, that.hotMap);
}
@Override
public int hashCode() {
return Objects.hash(field, hotMap);
}
public Builder builder(String field) {
return new Builder(field);
}
public static class Builder {
private String field;
private Map<String, String> hotMap = new HashMap<>();
public Builder(String field) {
this.field = field;
}
public Builder setField(String field) {
this.field = field;
return this;
}
public Builder setHotMap(Map<String, String> hotMap) {
this.hotMap = new HashMap<>(hotMap);
return this;
}
public Builder addOneHot(String valueName, String oneHotFeatureName) {
this.hotMap.put(valueName, oneHotFeatureName);
return this;
}
public OneHotEncoding build() {
return new OneHotEncoding(field, hotMap);
}
}
}

View File

@ -0,0 +1,33 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.xcontent.ToXContentObject;
/**
* Describes a pre-processor for a defined machine learning model
*/
public interface PreProcessor extends ToXContentObject {
/**
* @return The name of the pre-processor
*/
String getName();
}

View File

@ -0,0 +1,183 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* PreProcessor for target mean encoding a set of categorical values for a given field.
*/
public class TargetMeanEncoding implements PreProcessor {
public static final String NAME = "target_mean_encoding";
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField TARGET_MEANS = new ParseField("target_means");
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<TargetMeanEncoding, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
TARGET_MEANS);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
}
public static TargetMeanEncoding fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final String field;
private final String featureName;
private final Map<String, Double> meanMap;
private final double defaultValue;
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
this.field = Objects.requireNonNull(field);
this.featureName = Objects.requireNonNull(featureName);
this.meanMap = Collections.unmodifiableMap(Objects.requireNonNull(meanMap));
this.defaultValue = Objects.requireNonNull(defaultValue);
}
/**
* @return Field name on which to target mean encode
*/
public String getField() {
return field;
}
/**
* @return Map of Value: targetMean for the target mean encoding
*/
public Map<String, Double> getMeanMap() {
return meanMap;
}
/**
* @return The default value to set when a previously unobserved value is seen
*/
public double getDefaultValue() {
return defaultValue;
}
/**
* @return The feature name for the encoded value
*/
public String getFeatureName() {
return featureName;
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(TARGET_MEANS.getPreferredName(), meanMap);
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TargetMeanEncoding that = (TargetMeanEncoding) o;
return Objects.equals(field, that.field)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(meanMap, that.meanMap)
&& Objects.equals(defaultValue, that.defaultValue);
}
@Override
public int hashCode() {
return Objects.hash(field, featureName, meanMap, defaultValue);
}
public Builder builder(String field) {
return new Builder(field);
}
public static class Builder {
private String field;
private String featureName;
private Map<String, Double> meanMap = new HashMap<>();
private double defaultValue;
public Builder(String field) {
this.field = field;
}
public String getField() {
return field;
}
public Builder setField(String field) {
this.field = field;
return this;
}
public Builder setFeatureName(String featureName) {
this.featureName = featureName;
return this;
}
public Builder setMeanMap(Map<String, Double> meanMap) {
this.meanMap = meanMap;
return this;
}
public Builder addMeanMapEntry(String valueName, double meanEncoding) {
this.meanMap.put(valueName, meanEncoding);
return this;
}
public Builder setDefaultValue(double defaultValue) {
this.defaultValue = defaultValue;
return this;
}
public TargetMeanEncoding build() {
return new TargetMeanEncoding(field, featureName, meanMap, defaultValue);
}
}
}

View File

@ -1,4 +1,5 @@
org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider
org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider
org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider
org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider
org.elasticsearch.client.transform.TransformNamedXContentProvider

View File

@ -65,6 +65,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.client.transform.transforms.SyncConfig;
import org.elasticsearch.client.transform.transforms.TimeSyncConfig;
import org.elasticsearch.common.CheckedFunction;
@ -95,6 +98,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.InternalAggregationTestCase;
import org.elasticsearch.test.rest.yaml.restspec.ClientYamlSuiteRestApi;
import org.elasticsearch.test.rest.yaml.restspec.ClientYamlSuiteRestSpec;
import org.hamcrest.Matchers;
import org.junit.Before;
@ -676,7 +680,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(37, namedXContents.size());
assertEquals(40, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -686,7 +690,7 @@ public class RestHighLevelClientTests extends ESTestCase {
categories.put(namedXContent.categoryClass, counter + 1);
}
}
assertEquals("Had: " + categories, 9, categories.size());
assertEquals("Had: " + categories, 10, categories.size());
assertEquals(Integer.valueOf(3), categories.get(Aggregation.class));
assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
@ -733,6 +737,8 @@ public class RestHighLevelClientTests extends ESTestCase {
ConfusionMatrixMetric.NAME,
MeanSquaredErrorMetric.NAME,
RSquaredMetric.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
}
public void testApiNamingConventions() throws Exception {

View File

@ -0,0 +1,60 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;
public class FrequencyEncodingTests extends AbstractXContentTestCase<FrequencyEncoding> {
@Override
protected FrequencyEncoding doParseInstance(XContentParser parser) throws IOException {
return FrequencyEncoding.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> !field.isEmpty();
}
@Override
protected FrequencyEncoding createTestInstance() {
return createRandom();
}
public static FrequencyEncoding createRandom() {
int valuesSize = randomIntBetween(1, 10);
Map<String, Double> valueMap = new HashMap<>();
for (int i = 0; i < valuesSize; i++) {
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
}
return new FrequencyEncoding(randomAlphaOfLength(10), randomAlphaOfLength(10), valueMap);
}
}

View File

@ -0,0 +1,61 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;
public class OneHotEncodingTests extends AbstractXContentTestCase<OneHotEncoding> {
@Override
protected OneHotEncoding doParseInstance(XContentParser parser) throws IOException {
return OneHotEncoding.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> !field.isEmpty();
}
@Override
protected OneHotEncoding createTestInstance() {
return createRandom();
}
public static OneHotEncoding createRandom() {
int valuesSize = randomIntBetween(1, 10);
Map<String, String> valueMap = new HashMap<>();
for (int i = 0; i < valuesSize; i++) {
valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
}
return new OneHotEncoding(randomAlphaOfLength(10), valueMap);
}
}

View File

@ -0,0 +1,64 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;
public class TargetMeanEncodingTests extends AbstractXContentTestCase<TargetMeanEncoding> {
@Override
protected TargetMeanEncoding doParseInstance(XContentParser parser) throws IOException {
return TargetMeanEncoding.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> !field.isEmpty();
}
@Override
protected TargetMeanEncoding createTestInstance() {
return createRandom();
}
public static TargetMeanEncoding createRandom() {
int valuesSize = randomIntBetween(1, 10);
Map<String, Double> valueMap = new HashMap<>();
for (int i = 0; i < valuesSize; i++) {
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
}
return new TargetMeanEncoding(randomAlphaOfLength(10),
randomAlphaOfLength(10),
valueMap,
randomDoubleBetween(0.0, 1.0, false));
}
}

View File

@ -145,6 +145,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.P
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage;
import org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage;
@ -472,6 +476,10 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, ScoreByThresholdResult::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
ConfusionMatrix.Result::new),
// ML - Inference
new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new),
new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new),
new NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(), TargetMeanEncoding::new),
// monitoring
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),

View File

@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import java.util.ArrayList;
import java.util.List;
public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
@Override
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
// PreProcessing Lenient
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME,
OneHotEncoding::fromXContentLenient));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
TargetMeanEncoding::fromXContentLenient));
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME,
FrequencyEncoding::fromXContentLenient));
// PreProcessing Strict
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
OneHotEncoding::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
TargetMeanEncoding::fromXContentStrict));
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME,
FrequencyEncoding::fromXContentStrict));
return namedXContent;
}
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
// PreProcessing
namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(),
OneHotEncoding::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(),
TargetMeanEncoding::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(),
FrequencyEncoding::new));
return namedWriteables;
}
}

View File

@ -0,0 +1,146 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* PreProcessor for frequency encoding a set of categorical values for a given field.
*/
public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
public static final ParseField NAME = new ParseField("frequency_encoding");
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
public static final ConstructingObjectParser<FrequencyEncoding, Void> STRICT_PARSER = createParser(false);
public static final ConstructingObjectParser<FrequencyEncoding, Void> LENIENT_PARSER = createParser(true);
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<FrequencyEncoding, Void> createParser(boolean lenient) {
ConstructingObjectParser<FrequencyEncoding, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2]));
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
parser.declareObject(ConstructingObjectParser.constructorArg(),
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
FREQUENCY_MAP);
return parser;
}
public static FrequencyEncoding fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null);
}
public static FrequencyEncoding fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null);
}
private final String field;
private final String featureName;
private final Map<String, Double> frequencyMap;
public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME);
this.frequencyMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(frequencyMap, FREQUENCY_MAP));
}
public FrequencyEncoding(StreamInput in) throws IOException {
this.field = in.readString();
this.featureName = in.readString();
this.frequencyMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readDouble));
}
/**
* @return Field name on which to frequency encode
*/
public String getField() {
return field;
}
/**
* @return Map of Value: frequency for the frequency encoding
*/
public Map<String, Double> getFrequencyMap() {
return frequencyMap;
}
/**
* @return The encoded feature name
*/
public String getFeatureName() {
return featureName;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public void process(Map<String, Object> fields) {
String value = (String)fields.get(field);
if (value == null) {
return;
}
fields.put(featureName, frequencyMap.getOrDefault(value, 0.0));
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(field);
out.writeString(featureName);
out.writeMap(frequencyMap, StreamOutput::writeString, StreamOutput::writeDouble);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(FREQUENCY_MAP.getPreferredName(), frequencyMap);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FrequencyEncoding that = (FrequencyEncoding) o;
return Objects.equals(field, that.field)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(frequencyMap, that.frequencyMap);
}
@Override
public int hashCode() {
return Objects.hash(field, featureName, frequencyMap);
}
}

View File

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
/**
* To be used in conjunction with a lenient parser.
*/
public interface LenientlyParsedPreProcessor extends PreProcessor {
}

View File

@ -0,0 +1,130 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
/**
* PreProcessor for one hot encoding a set of categorical values for a given field.
*/
public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
public static final ParseField NAME = new ParseField("one_hot_encoding");
public static final ParseField FIELD = new ParseField("field");
public static final ParseField HOT_MAP = new ParseField("hot_map");
public static final ConstructingObjectParser<OneHotEncoding, Void> STRICT_PARSER = createParser(false);
public static final ConstructingObjectParser<OneHotEncoding, Void> LENIENT_PARSER = createParser(true);
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<OneHotEncoding, Void> createParser(boolean lenient) {
ConstructingObjectParser<OneHotEncoding, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1]));
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
return parser;
}
public static OneHotEncoding fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null);
}
public static OneHotEncoding fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null);
}
private final String field;
private final Map<String, String> hotMap;
public OneHotEncoding(String field, Map<String, String> hotMap) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.hotMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(hotMap, HOT_MAP));
}
public OneHotEncoding(StreamInput in) throws IOException {
this.field = in.readString();
this.hotMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString));
}
/**
* @return Field name on which to one hot encode
*/
public String getField() {
return field;
}
/**
* @return Map of Value: ColumnName for the one hot encoding
*/
public Map<String, String> getHotMap() {
return hotMap;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public void process(Map<String, Object> fields) {
String value = (String)fields.get(field);
if (value == null) {
return;
}
hotMap.forEach((val, col) -> {
int encoding = value.equals(val) ? 1 : 0;
fields.put(col, encoding);
});
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(field);
out.writeMap(hotMap, StreamOutput::writeString, StreamOutput::writeString);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(HOT_MAP.getPreferredName(), hotMap);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
OneHotEncoding that = (OneHotEncoding) o;
return Objects.equals(field, that.field)
&& Objects.equals(hotMap, that.hotMap);
}
@Override
public int hashCode() {
return Objects.hash(field, hotMap);
}
}

View File

@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import java.util.Map;
/**
* Describes a pre-processor for a defined machine learning model
* This processor should take a set of fields and return the modified set of fields.
*/
public interface PreProcessor extends NamedXContentObject, NamedWriteable {
/**
* Process the given fields and their values and return the modified map.
*
* NOTE: The passed map object is mutated directly
* @param fields The fields and their values to process
*/
void process(Map<String, Object> fields);
}

View File

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
/**
* To be used in conjunction with a strict parser.
*/
public interface StrictlyParsedPreProcessor extends PreProcessor {
}

View File

@ -0,0 +1,161 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* PreProcessor for target mean encoding a set of categorical values for a given field.
*/
public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
public static final ParseField NAME = new ParseField("target_mean_encoding");
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField TARGET_MEANS = new ParseField("target_means");
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
public static final ConstructingObjectParser<TargetMeanEncoding, Void> STRICT_PARSER = createParser(false);
public static final ConstructingObjectParser<TargetMeanEncoding, Void> LENIENT_PARSER = createParser(true);
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<TargetMeanEncoding, Void> createParser(boolean lenient) {
ConstructingObjectParser<TargetMeanEncoding, Void> parser = new ConstructingObjectParser<>(
NAME.getPreferredName(),
lenient,
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3]));
parser.declareString(ConstructingObjectParser.constructorArg(), FIELD);
parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
parser.declareObject(ConstructingObjectParser.constructorArg(),
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
TARGET_MEANS);
parser.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
return parser;
}
public static TargetMeanEncoding fromXContentStrict(XContentParser parser) {
return STRICT_PARSER.apply(parser, null);
}
public static TargetMeanEncoding fromXContentLenient(XContentParser parser) {
return LENIENT_PARSER.apply(parser, null);
}
private final String field;
private final String featureName;
private final Map<String, Double> meanMap;
private final double defaultValue;
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME);
this.meanMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(meanMap, TARGET_MEANS));
this.defaultValue = ExceptionsHelper.requireNonNull(defaultValue, DEFAULT_VALUE);
}
public TargetMeanEncoding(StreamInput in) throws IOException {
this.field = in.readString();
this.featureName = in.readString();
this.meanMap = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readDouble));
this.defaultValue = in.readDouble();
}
/**
* @return Field name on which to target mean encode
*/
public String getField() {
return field;
}
/**
* @return Map of Value: targetMean for the target mean encoding
*/
public Map<String, Double> getMeanMap() {
return meanMap;
}
/**
* @return The default value to set when a previously unobserved value is seen
*/
public Double getDefaultValue() {
return defaultValue;
}
/**
* @return The feature name for the encoded value
*/
public String getFeatureName() {
return featureName;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public void process(Map<String, Object> fields) {
String value = (String)fields.get(field);
if (value == null) {
return;
}
fields.put(featureName, meanMap.getOrDefault(value, defaultValue));
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(field);
out.writeString(featureName);
out.writeMap(meanMap, StreamOutput::writeString, StreamOutput::writeDouble);
out.writeDouble(defaultValue);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(TARGET_MEANS.getPreferredName(), meanMap);
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TargetMeanEncoding that = (TargetMeanEncoding) o;
return Objects.equals(field, that.field)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(meanMap, that.meanMap)
&& Objects.equals(defaultValue, that.defaultValue);
}
@Override
public int hashCode() {
return Objects.hash(field, featureName, meanMap, defaultValue);
}
}

View File

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.utils;
import org.elasticsearch.common.xcontent.ToXContentObject;
/**
* Simple interface for XContent Objects that are named.
*
* This affords more general handling when serializing and de-serializing this type of XContent when it is used in a NamedObjects
* parser.
*/
public interface NamedXContentObject extends ToXContentObject {
/**
* @return The name of the XContentObject that is to be serialized
*/
String getName();
}

View File

@ -0,0 +1,171 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
//TODO these tests are temporary until the named objects are actually used by an encompassing class (i.e. ModelInferer)
public class NamedXContentObjectsTests extends AbstractXContentTestCase<NamedXContentObjectsTests.NamedObjectContainer> {
static class NamedObjectContainer implements ToXContentObject {
static ParseField PRE_PROCESSORS = new ParseField("pre_processors");
static final ObjectParser<NamedObjectContainer, Void> STRICT_PARSER = createParser(false);
static final ObjectParser<NamedObjectContainer, Void> LENIENT_PARSER = createParser(true);
@SuppressWarnings("unchecked")
private static ObjectParser<NamedObjectContainer, Void> createParser(boolean lenient) {
ObjectParser<NamedObjectContainer, Void> parser = new ObjectParser<>(
"named_xcontent_object_container_test",
lenient,
NamedObjectContainer::new);
parser.declareNamedObjects(NamedObjectContainer::setPreProcessors,
(p, c, n) ->
lenient ? p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
(noc) -> noc.setUseExplicitPreprocessorOrder(true), PRE_PROCESSORS);
return parser;
}
private boolean useExplicitPreprocessorOrder = false;
private List<? extends PreProcessor> preProcessors;
void setPreProcessors(List<? extends PreProcessor> preProcessors) {
this.preProcessors = preProcessors;
}
void setUseExplicitPreprocessorOrder(boolean value) {
this.useExplicitPreprocessorOrder = value;
}
static NamedObjectContainer fromXContent(XContentParser parser, boolean lenient) {
return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
writeNamedObjects(builder, params, useExplicitPreprocessorOrder, PRE_PROCESSORS.getPreferredName(), preProcessors);
builder.endObject();
return builder;
}
XContentBuilder writeNamedObjects(XContentBuilder builder,
Params params,
boolean useExplicitOrder,
String namedObjectsName,
List<? extends NamedXContentObject> namedObjects) throws IOException {
if (useExplicitOrder) {
builder.startArray(namedObjectsName);
} else {
builder.startObject(namedObjectsName);
}
for (NamedXContentObject object : namedObjects) {
if (useExplicitOrder) {
builder.startObject();
}
builder.field(object.getName(), object, params);
if (useExplicitOrder) {
builder.endObject();
}
}
if (useExplicitOrder) {
builder.endArray();
} else {
builder.endObject();
}
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
NamedObjectContainer that = (NamedObjectContainer) o;
return Objects.equals(preProcessors, that.preProcessors);
}
@Override
public int hashCode() {
return Objects.hash(preProcessors);
}
}
private boolean lenient;
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
public NamedObjectContainer createTestInstance() {
int max = randomIntBetween(1, 10);
List<PreProcessor> preProcessors = new ArrayList<>(max);
for (int i = 0; i < max; i++) {
preProcessors.add(randomFrom(FrequencyEncodingTests.createRandom(),
OneHotEncodingTests.createRandom(),
TargetMeanEncodingTests.createRandom()));
}
NamedObjectContainer container = new NamedObjectContainer();
container.setPreProcessors(preProcessors);
container.setUseExplicitPreprocessorOrder(true);
return container;
}
@Override
protected NamedObjectContainer doParseInstance(XContentParser parser) throws IOException {
return NamedObjectContainer.fromXContent(parser, lenient);
}
@Override
protected boolean supportsUnknownFields() {
return lenient;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
// We only want to add random fields to the root, or the root of the named objects
return field ->
(field.endsWith("frequency_encoding") ||
field.endsWith("one_hot_encoding") ||
field.endsWith("target_mean_encoding") ||
field.isEmpty()) == false;
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
}

View File

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.hamcrest.Matcher;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.equalTo;
public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding> {
@Override
protected FrequencyEncoding doParseInstance(XContentParser parser) throws IOException {
return lenient ? FrequencyEncoding.fromXContentLenient(parser) : FrequencyEncoding.fromXContentStrict(parser);
}
@Override
protected FrequencyEncoding createTestInstance() {
return createRandom();
}
public static FrequencyEncoding createRandom() {
int valuesSize = randomIntBetween(1, 10);
Map<String, Double> valueMap = new HashMap<>();
for (int i = 0; i < valuesSize; i++) {
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
}
return new FrequencyEncoding(randomAlphaOfLength(10), randomAlphaOfLength(10), valueMap);
}
@Override
protected Writeable.Reader<FrequencyEncoding> instanceReader() {
return FrequencyEncoding::new;
}
public void testProcessWithFieldPresent() {
String field = "categorical";
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Function.identity(),
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = "encoded";
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap);
String fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName, equalTo(valueMap.get(fieldValue)));
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
testProcess(encoding, fieldValues, matchers);
// Test where the value is some unknown Value
fieldValues = randomFieldValues(field, "unknownValue");
fieldValues.put(field, "unknownValue");
matchers = Collections.singletonMap(encodedFeatureName, equalTo(0.0));
testProcess(encoding, fieldValues, matchers);
}
}

View File

@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.hamcrest.Matcher;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.equalTo;
public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
@Override
protected OneHotEncoding doParseInstance(XContentParser parser) throws IOException {
return lenient ? OneHotEncoding.fromXContentLenient(parser) : OneHotEncoding.fromXContentStrict(parser);
}
@Override
protected OneHotEncoding createTestInstance() {
return createRandom();
}
public static OneHotEncoding createRandom() {
int valuesSize = randomIntBetween(1, 10);
Map<String, String> valueMap = new HashMap<>();
for (int i = 0; i < valuesSize; i++) {
valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
}
return new OneHotEncoding(randomAlphaOfLength(10), valueMap);
}
@Override
protected Writeable.Reader<OneHotEncoding> instanceReader() {
return OneHotEncoding::new;
}
public void testProcessWithFieldPresent() {
String field = "categorical";
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Function.identity(), v -> "Column_" + v));
OneHotEncoding encoding = new OneHotEncoding(field, valueMap);
String fieldValue = randomFrom(values);
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
Map<String, Matcher<? super Object>> matchers = values.stream().map(v -> "Column_" + v)
.collect(Collectors.toMap(
Function.identity(),
v -> v.equals("Column_" + fieldValue) ? equalTo(1) : equalTo(0)));
fieldValues.put(field, fieldValue);
testProcess(encoding, fieldValues, matchers);
// Test where the value is some unknown Value
fieldValues = randomFieldValues(field, "unknownValue");
matchers.put("Column_" + fieldValue, equalTo(0));
testProcess(encoding, fieldValues, matchers);
}
}

View File

@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.hamcrest.Matcher;
import org.junit.Before;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;
import static org.hamcrest.Matchers.equalTo;
public abstract class PreProcessingTests<T extends PreProcessor> extends AbstractSerializingTestCase<T> {
protected boolean lenient;
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
protected boolean supportsUnknownFields() {
return lenient;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> !field.isEmpty();
}
void testProcess(PreProcessor preProcessor, Map<String, Object> fieldValues, Map<String, Matcher<? super Object>> assertions) {
preProcessor.process(fieldValues);
assertions.forEach((fieldName, matcher) ->
assertThat(fieldValues.get(fieldName), matcher)
);
}
public void testWithMissingField() {
Map<String, Object> fields = randomFieldValues();
PreProcessor preProcessor = this.createTestInstance();
Map<String, Object> fieldsCopy = new HashMap<>(fields);
preProcessor.process(fields);
assertThat(fieldsCopy, equalTo(fields));
}
Map<String, Object> randomFieldValues() {
int numFields = randomIntBetween(1, 5);
Map<String, Object> fieldValues = new HashMap<>(numFields);
for (int k = 0; k < numFields; k++) {
fieldValues.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
}
return fieldValues;
}
Map<String, Object> randomFieldValues(String categoricalField, String catigoricalValue) {
Map<String, Object> fieldValues = randomFieldValues();
fieldValues.put(categoricalField, catigoricalValue);
return fieldValues;
}
}

View File

@ -0,0 +1,71 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.hamcrest.Matcher;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.equalTo;
public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncoding> {
@Override
protected TargetMeanEncoding doParseInstance(XContentParser parser) throws IOException {
return lenient ? TargetMeanEncoding.fromXContentLenient(parser) : TargetMeanEncoding.fromXContentStrict(parser);
}
@Override
protected TargetMeanEncoding createTestInstance() {
return createRandom();
}
public static TargetMeanEncoding createRandom() {
int valuesSize = randomIntBetween(1, 10);
Map<String, Double> valueMap = new HashMap<>();
for (int i = 0; i < valuesSize; i++) {
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
}
return new TargetMeanEncoding(randomAlphaOfLength(10),
randomAlphaOfLength(10),
valueMap,
randomDoubleBetween(0.0, 1.0, false));
}
@Override
protected Writeable.Reader<TargetMeanEncoding> instanceReader() {
return TargetMeanEncoding::new;
}
public void testProcessWithFieldPresent() {
String field = "categorical";
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Function.identity(),
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = "encoded";
Double defaultvalue = randomDouble();
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue);
String fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName, equalTo(valueMap.get(fieldValue)));
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
testProcess(encoding, fieldValues, matchers);
// Test where the value is some unknown Value
fieldValues = randomFieldValues(field, "unknownValue");
matchers = Collections.singletonMap(encodedFeatureName, equalTo(defaultvalue));
testProcess(encoding, fieldValues, matchers);
}
}

View File

@ -122,6 +122,7 @@ import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction;
import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields;
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
@ -950,6 +951,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return namedXContent;
}
}