[7.x] [ML] adding docs + hlrc for data frame analysis feature_processors (#61149) (#61493)

* [ML] adding docs + hlrc for data frame analysis feature_processors (#61149)

Adds HLRC and some docs for the new feature_processors field in Data frame analytics.

Co-authored-by: Przemysław Witek <przemyslaw.witek@elastic.co>
Co-authored-by: Lisa Cawley <lcawley@elastic.co>
This commit is contained in:
Benjamin Trent 2020-08-24 12:56:21 -04:00 committed by GitHub
parent d05649bfae
commit 1ae2923632
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 254 additions and 116 deletions

View File

@ -18,6 +18,8 @@
*/
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
@ -26,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
@ -53,7 +56,9 @@ public class Classification implements DataFrameAnalysis {
static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Classification, Void> PARSER =
new ConstructingObjectParser<>(
NAME.getPreferredName(),
@ -70,7 +75,8 @@ public class Classification implements DataFrameAnalysis {
(Double) a[8],
(Integer) a[9],
(Long) a[10],
(ClassAssignmentObjective) a[11]));
(ClassAssignmentObjective) a[11],
(List<PreProcessor>) a[12]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@ -86,6 +92,10 @@ public class Classification implements DataFrameAnalysis {
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
PARSER.declareString(
ConstructingObjectParser.optionalConstructorArg(), ClassAssignmentObjective::fromString, CLASS_ASSIGNMENT_OBJECTIVE);
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(PreProcessor.class, n, c),
(classification) -> {},
FEATURE_PROCESSORS);
}
private final String dependentVariable;
@ -100,12 +110,13 @@ public class Classification implements DataFrameAnalysis {
private final ClassAssignmentObjective classAssignmentObjective;
private final Integer numTopClasses;
private final Long randomizeSeed;
private final List<PreProcessor> featureProcessors;
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed,
@Nullable ClassAssignmentObjective classAssignmentObjective) {
@Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable List<PreProcessor> featureProcessors) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
@ -118,6 +129,7 @@ public class Classification implements DataFrameAnalysis {
this.classAssignmentObjective = classAssignmentObjective;
this.numTopClasses = numTopClasses;
this.randomizeSeed = randomizeSeed;
this.featureProcessors = featureProcessors;
}
@Override
@ -173,6 +185,10 @@ public class Classification implements DataFrameAnalysis {
return numTopClasses;
}
/**
 * Returns the feature processors configured for this classification analysis.
 *
 * @return the list of feature processors exactly as it was supplied, or {@code null} when none were set
 */
public List<PreProcessor> getFeatureProcessors() {
return featureProcessors;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -210,6 +226,9 @@ public class Classification implements DataFrameAnalysis {
if (numTopClasses != null) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
if (featureProcessors != null) {
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
}
builder.endObject();
return builder;
}
@ -217,7 +236,7 @@ public class Classification implements DataFrameAnalysis {
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective);
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective, featureProcessors);
}
@Override
@ -236,7 +255,8 @@ public class Classification implements DataFrameAnalysis {
&& Objects.equals(trainingPercent, that.trainingPercent)
&& Objects.equals(randomizeSeed, that.randomizeSeed)
&& Objects.equals(numTopClasses, that.numTopClasses)
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective);
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
&& Objects.equals(featureProcessors, that.featureProcessors);
}
@Override
@ -270,6 +290,7 @@ public class Classification implements DataFrameAnalysis {
private Integer numTopClasses;
private Long randomizeSeed;
private ClassAssignmentObjective classAssignmentObjective;
private List<PreProcessor> featureProcessors;
private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@ -330,10 +351,15 @@ public class Classification implements DataFrameAnalysis {
return this;
}
/**
 * Sets the feature processors that the built {@link Classification} will carry.
 *
 * @param featureProcessors the processors to use; the reference is stored as-is (no copy is made) and may be {@code null}
 * @return this builder, for call chaining
 */
public Builder setFeatureProcessors(List<PreProcessor> featureProcessors) {
this.featureProcessors = featureProcessors;
return this;
}
public Classification build() {
return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
classAssignmentObjective);
classAssignmentObjective, featureProcessors);
}
}
}

View File

@ -18,6 +18,8 @@
*/
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
@ -26,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
@ -55,7 +58,9 @@ public class Regression implements DataFrameAnalysis {
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Regression, Void> PARSER =
new ConstructingObjectParser<>(
NAME.getPreferredName(),
@ -72,7 +77,8 @@ public class Regression implements DataFrameAnalysis {
(Double) a[8],
(Long) a[9],
(LossFunction) a[10],
(Double) a[11]
(Double) a[11],
(List<PreProcessor>) a[12]
));
static {
@ -88,6 +94,10 @@ public class Regression implements DataFrameAnalysis {
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
PARSER.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(PreProcessor.class, n, c),
(regression) -> {},
FEATURE_PROCESSORS);
}
private final String dependentVariable;
@ -102,12 +112,13 @@ public class Regression implements DataFrameAnalysis {
private final Long randomizeSeed;
private final LossFunction lossFunction;
private final Double lossFunctionParameter;
private final List<PreProcessor> featureProcessors;
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
@Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction,
@Nullable Double lossFunctionParameter) {
@Nullable Double lossFunctionParameter, @Nullable List<PreProcessor> featureProcessors) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
this.lambda = lambda;
this.gamma = gamma;
@ -120,6 +131,7 @@ public class Regression implements DataFrameAnalysis {
this.randomizeSeed = randomizeSeed;
this.lossFunction = lossFunction;
this.lossFunctionParameter = lossFunctionParameter;
this.featureProcessors = featureProcessors;
}
@Override
@ -175,6 +187,10 @@ public class Regression implements DataFrameAnalysis {
return lossFunctionParameter;
}
/**
 * Returns the feature processors configured for this regression analysis.
 *
 * @return the list of feature processors exactly as it was supplied, or {@code null} when none were set
 */
public List<PreProcessor> getFeatureProcessors() {
return featureProcessors;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -212,6 +228,9 @@ public class Regression implements DataFrameAnalysis {
if (lossFunctionParameter != null) {
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
}
if (featureProcessors != null) {
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
}
builder.endObject();
return builder;
}
@ -219,7 +238,7 @@ public class Regression implements DataFrameAnalysis {
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter);
predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter, featureProcessors);
}
@Override
@ -238,7 +257,8 @@ public class Regression implements DataFrameAnalysis {
&& Objects.equals(trainingPercent, that.trainingPercent)
&& Objects.equals(randomizeSeed, that.randomizeSeed)
&& Objects.equals(lossFunction, that.lossFunction)
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter)
&& Objects.equals(featureProcessors, that.featureProcessors);
}
@Override
@ -259,6 +279,7 @@ public class Regression implements DataFrameAnalysis {
private Long randomizeSeed;
private LossFunction lossFunction;
private Double lossFunctionParameter;
private List<PreProcessor> featureProcessors;
private Builder(String dependentVariable) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@ -319,9 +340,15 @@ public class Regression implements DataFrameAnalysis {
return this;
}
/**
 * Sets the feature processors that the built {@link Regression} will carry.
 *
 * @param featureProcessors the processors to use; the reference is stored as-is (no copy is made) and may be {@code null}
 * @return this builder, for call chaining
 */
public Builder setFeatureProcessors(List<PreProcessor> featureProcessors) {
this.featureProcessors = featureProcessors;
return this;
}
public Regression build() {
return new Regression(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter);
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter,
featureProcessors);
}
}

View File

@ -114,7 +114,7 @@ public class OneHotEncoding implements PreProcessor {
return Objects.hash(field, hotMap, custom);
}
public Builder builder(String field) {
public static Builder builder(String field) {
return new Builder(field);
}

View File

@ -179,6 +179,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
@ -3003,6 +3004,9 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setRandomizeSeed(1234L) // <10>
.setClassAssignmentObjective(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) // <11>
.setNumTopClasses(1) // <12>
.setFeatureProcessors(Arrays.asList(OneHotEncoding.builder("categorical_feature") // <13>
.addOneHot("cat", "cat_column")
.build()))
.build();
// end::put-data-frame-analytics-classification
@ -3019,6 +3023,9 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setRandomizeSeed(1234L) // <10>
.setLossFunction(Regression.LossFunction.MSE) // <11>
.setLossFunctionParameter(1.0) // <12>
.setFeatureProcessors(Arrays.asList(OneHotEncoding.builder("categorical_feature") // <13>
.addOneHot("cat", "cat_column")
.build()))
.build();
// end::put-data-frame-analytics-regression

View File

@ -18,10 +18,20 @@
*/
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class ClassificationTests extends AbstractXContentTestCase<Classification> {
@ -38,9 +48,20 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
.setRandomizeSeed(randomBoolean() ? null : randomLong())
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
.setFeatureProcessors(randomBoolean() ? null :
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
OneHotEncodingTests.createRandom(),
TargetMeanEncodingTests.createRandom()))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()))
.build();
}
/**
 * Returns a predicate that matches every field path beginning with {@code feature_processors}.
 * NOTE(review): presumably the base test case uses this to avoid injecting random unknown fields
 * underneath the feature processor named objects - confirm against AbstractXContentTestCase.
 */
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> field.startsWith("feature_processors");
}
@Override
protected Classification createTestInstance() {
return randomClassification();
@ -55,4 +76,11 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
protected boolean supportsUnknownFields() {
return true;
}
/**
 * Builds the registry for this test from the named XContent parsers supplied by
 * {@link MlInferenceNamedXContentProvider}.
 * NOTE(review): presumably required so the {@code feature_processors} named objects can be
 * resolved while parsing - confirm the provider registers the PreProcessor implementations.
 */
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
}

View File

@ -20,6 +20,7 @@
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.Version;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
@ -101,6 +102,7 @@ public class DataFrameAnalyticsConfigTests extends AbstractXContentTestCase<Data
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
}

View File

@ -18,10 +18,20 @@
*/
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class RegressionTests extends AbstractXContentTestCase<Regression> {
@ -37,9 +47,20 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.setLossFunction(randomBoolean() ? null : randomFrom(Regression.LossFunction.values()))
.setLossFunctionParameter(randomBoolean() ? null : randomDoubleBetween(1.0, Double.MAX_VALUE, true))
.setFeatureProcessors(randomBoolean() ? null :
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
OneHotEncodingTests.createRandom(),
TargetMeanEncodingTests.createRandom()))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()))
.build();
}
/**
 * Returns a predicate that matches every field path beginning with {@code feature_processors}.
 * NOTE(review): presumably the base test case uses this to avoid injecting random unknown fields
 * underneath the feature processor named objects - confirm against AbstractXContentTestCase.
 */
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> field.startsWith("feature_processors");
}
@Override
protected Regression createTestInstance() {
return randomRegression();
@ -54,4 +75,11 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
protected boolean supportsUnknownFields() {
return true;
}
/**
 * Builds the registry for this test from the named XContent parsers supplied by
 * {@link MlInferenceNamedXContentProvider}.
 * NOTE(review): presumably required so the {@code feature_processors} named objects can be
 * resolved while parsing - confirm the provider registers the PreProcessor implementations.
 */
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
}

View File

@ -124,6 +124,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
<10> The seed to be used by the random generator that picks which rows are used in training.
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
<12> The number of top classes to be reported in the results. Defaults to 2.
<13> Custom feature processors that create new features for analysis from the included document
fields. Note that automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
===== Regression
@ -146,6 +148,8 @@ include-tagged::{doc-tests-file}[{api}-regression]
<10> The seed to be used by the random generator that picks which rows are used in training.
<11> The loss function used for regression. Defaults to `mse`.
<12> An optional parameter to the loss function.
<13> Custom feature processors that create new features for analysis from the included document
fields. Note that automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
==== Analyzed fields

View File

@ -115,6 +115,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=feature-bag-fraction]
`feature_processors`::::
(Optional, list)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-feature-processors]
`gamma`::::
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=gamma]
@ -215,6 +219,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=feature-bag-fraction]
`feature_processors`::::
(Optional, list)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-feature-processors]
`gamma`::::
(Optional, double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=gamma]

View File

@ -517,6 +517,14 @@ Specifies the rate at which the `eta` increases for each new tree that is added
forest. For example, a rate of `1.05` increases `eta` by 5%.
end::dfas-eta-growth[]
tag::dfas-feature-processors[]
A collection of feature preprocessors that modify one or more included fields.
The analysis uses the resulting one or more features instead of the
original document field. Multiple `feature_processors` entries can refer to the
same document fields.
Note that automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs.
end::dfas-feature-processors[]
tag::dfas-iteration[]
The number of iterations on the analysis.
end::dfas-iteration[]