mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-03-09 14:34:43 +00:00
* [ML] adding docs + hlrc for data frame analysis feature_processors (#61149) Adds HLRC and some docs for the new feature_processors field in Data frame analytics. Co-authored-by: Przemysław Witek <przemyslaw.witek@elastic.co> Co-authored-by: Lisa Cawley <lcawley@elastic.co>
This commit is contained in:
parent
d05649bfae
commit
1ae2923632
@ -18,6 +18,8 @@
|
|||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml.dataframe;
|
package org.elasticsearch.client.ml.dataframe;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
@ -26,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
|
|||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
@ -53,7 +56,9 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
|
static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
|
||||||
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||||
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||||
|
static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
private static final ConstructingObjectParser<Classification, Void> PARSER =
|
private static final ConstructingObjectParser<Classification, Void> PARSER =
|
||||||
new ConstructingObjectParser<>(
|
new ConstructingObjectParser<>(
|
||||||
NAME.getPreferredName(),
|
NAME.getPreferredName(),
|
||||||
@ -70,7 +75,8 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
(Double) a[8],
|
(Double) a[8],
|
||||||
(Integer) a[9],
|
(Integer) a[9],
|
||||||
(Long) a[10],
|
(Long) a[10],
|
||||||
(ClassAssignmentObjective) a[11]));
|
(ClassAssignmentObjective) a[11],
|
||||||
|
(List<PreProcessor>) a[12]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||||
@ -86,6 +92,10 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
||||||
PARSER.declareString(
|
PARSER.declareString(
|
||||||
ConstructingObjectParser.optionalConstructorArg(), ClassAssignmentObjective::fromString, CLASS_ASSIGNMENT_OBJECTIVE);
|
ConstructingObjectParser.optionalConstructorArg(), ClassAssignmentObjective::fromString, CLASS_ASSIGNMENT_OBJECTIVE);
|
||||||
|
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||||
|
(p, c, n) -> p.namedObject(PreProcessor.class, n, c),
|
||||||
|
(classification) -> {},
|
||||||
|
FEATURE_PROCESSORS);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final String dependentVariable;
|
private final String dependentVariable;
|
||||||
@ -100,12 +110,13 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
private final ClassAssignmentObjective classAssignmentObjective;
|
private final ClassAssignmentObjective classAssignmentObjective;
|
||||||
private final Integer numTopClasses;
|
private final Integer numTopClasses;
|
||||||
private final Long randomizeSeed;
|
private final Long randomizeSeed;
|
||||||
|
private final List<PreProcessor> featureProcessors;
|
||||||
|
|
||||||
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||||
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
|
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
|
||||||
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
||||||
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed,
|
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed,
|
||||||
@Nullable ClassAssignmentObjective classAssignmentObjective) {
|
@Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable List<PreProcessor> featureProcessors) {
|
||||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||||
this.lambda = lambda;
|
this.lambda = lambda;
|
||||||
this.gamma = gamma;
|
this.gamma = gamma;
|
||||||
@ -118,6 +129,7 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
this.classAssignmentObjective = classAssignmentObjective;
|
this.classAssignmentObjective = classAssignmentObjective;
|
||||||
this.numTopClasses = numTopClasses;
|
this.numTopClasses = numTopClasses;
|
||||||
this.randomizeSeed = randomizeSeed;
|
this.randomizeSeed = randomizeSeed;
|
||||||
|
this.featureProcessors = featureProcessors;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -173,6 +185,10 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
return numTopClasses;
|
return numTopClasses;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<PreProcessor> getFeatureProcessors() {
|
||||||
|
return featureProcessors;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
@ -210,6 +226,9 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
if (numTopClasses != null) {
|
if (numTopClasses != null) {
|
||||||
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
|
||||||
}
|
}
|
||||||
|
if (featureProcessors != null) {
|
||||||
|
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
|
||||||
|
}
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
@ -217,7 +236,7 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
|
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
|
||||||
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective);
|
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective, featureProcessors);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -236,7 +255,8 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||||
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
||||||
&& Objects.equals(numTopClasses, that.numTopClasses)
|
&& Objects.equals(numTopClasses, that.numTopClasses)
|
||||||
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective);
|
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
|
||||||
|
&& Objects.equals(featureProcessors, that.featureProcessors);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -270,6 +290,7 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
private Integer numTopClasses;
|
private Integer numTopClasses;
|
||||||
private Long randomizeSeed;
|
private Long randomizeSeed;
|
||||||
private ClassAssignmentObjective classAssignmentObjective;
|
private ClassAssignmentObjective classAssignmentObjective;
|
||||||
|
private List<PreProcessor> featureProcessors;
|
||||||
|
|
||||||
private Builder(String dependentVariable) {
|
private Builder(String dependentVariable) {
|
||||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||||
@ -330,10 +351,15 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setFeatureProcessors(List<PreProcessor> featureProcessors) {
|
||||||
|
this.featureProcessors = featureProcessors;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public Classification build() {
|
public Classification build() {
|
||||||
return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
|
return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
|
||||||
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
|
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
|
||||||
classAssignmentObjective);
|
classAssignmentObjective, featureProcessors);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,8 @@
|
|||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml.dataframe;
|
package org.elasticsearch.client.ml.dataframe;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
@ -26,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
|
|||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
@ -55,7 +58,9 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||||
static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
|
static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
|
||||||
static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
|
static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
|
||||||
|
static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
private static final ConstructingObjectParser<Regression, Void> PARSER =
|
private static final ConstructingObjectParser<Regression, Void> PARSER =
|
||||||
new ConstructingObjectParser<>(
|
new ConstructingObjectParser<>(
|
||||||
NAME.getPreferredName(),
|
NAME.getPreferredName(),
|
||||||
@ -72,7 +77,8 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
(Double) a[8],
|
(Double) a[8],
|
||||||
(Long) a[9],
|
(Long) a[9],
|
||||||
(LossFunction) a[10],
|
(LossFunction) a[10],
|
||||||
(Double) a[11]
|
(Double) a[11],
|
||||||
|
(List<PreProcessor>) a[12]
|
||||||
));
|
));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
@ -88,6 +94,10 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
||||||
PARSER.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION);
|
PARSER.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION);
|
||||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
|
||||||
|
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||||
|
(p, c, n) -> p.namedObject(PreProcessor.class, n, c),
|
||||||
|
(regression) -> {},
|
||||||
|
FEATURE_PROCESSORS);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final String dependentVariable;
|
private final String dependentVariable;
|
||||||
@ -102,12 +112,13 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
private final Long randomizeSeed;
|
private final Long randomizeSeed;
|
||||||
private final LossFunction lossFunction;
|
private final LossFunction lossFunction;
|
||||||
private final Double lossFunctionParameter;
|
private final Double lossFunctionParameter;
|
||||||
|
private final List<PreProcessor> featureProcessors;
|
||||||
|
|
||||||
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||||
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
|
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
|
||||||
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
||||||
@Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction,
|
@Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction,
|
||||||
@Nullable Double lossFunctionParameter) {
|
@Nullable Double lossFunctionParameter, @Nullable List<PreProcessor> featureProcessors) {
|
||||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||||
this.lambda = lambda;
|
this.lambda = lambda;
|
||||||
this.gamma = gamma;
|
this.gamma = gamma;
|
||||||
@ -120,6 +131,7 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
this.randomizeSeed = randomizeSeed;
|
this.randomizeSeed = randomizeSeed;
|
||||||
this.lossFunction = lossFunction;
|
this.lossFunction = lossFunction;
|
||||||
this.lossFunctionParameter = lossFunctionParameter;
|
this.lossFunctionParameter = lossFunctionParameter;
|
||||||
|
this.featureProcessors = featureProcessors;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -175,6 +187,10 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
return lossFunctionParameter;
|
return lossFunctionParameter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<PreProcessor> getFeatureProcessors() {
|
||||||
|
return featureProcessors;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
@ -212,6 +228,9 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
if (lossFunctionParameter != null) {
|
if (lossFunctionParameter != null) {
|
||||||
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
|
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
|
||||||
}
|
}
|
||||||
|
if (featureProcessors != null) {
|
||||||
|
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
|
||||||
|
}
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
@ -219,7 +238,7 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
|
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
|
||||||
predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter);
|
predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter, featureProcessors);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -238,7 +257,8 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||||
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
||||||
&& Objects.equals(lossFunction, that.lossFunction)
|
&& Objects.equals(lossFunction, that.lossFunction)
|
||||||
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
|
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter)
|
||||||
|
&& Objects.equals(featureProcessors, that.featureProcessors);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -259,6 +279,7 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
private Long randomizeSeed;
|
private Long randomizeSeed;
|
||||||
private LossFunction lossFunction;
|
private LossFunction lossFunction;
|
||||||
private Double lossFunctionParameter;
|
private Double lossFunctionParameter;
|
||||||
|
private List<PreProcessor> featureProcessors;
|
||||||
|
|
||||||
private Builder(String dependentVariable) {
|
private Builder(String dependentVariable) {
|
||||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||||
@ -319,9 +340,15 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setFeatureProcessors(List<PreProcessor> featureProcessors) {
|
||||||
|
this.featureProcessors = featureProcessors;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public Regression build() {
|
public Regression build() {
|
||||||
return new Regression(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
|
return new Regression(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
|
||||||
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter);
|
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter,
|
||||||
|
featureProcessors);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ public class OneHotEncoding implements PreProcessor {
|
|||||||
return Objects.hash(field, hotMap, custom);
|
return Objects.hash(field, hotMap, custom);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Builder builder(String field) {
|
public static Builder builder(String field) {
|
||||||
return new Builder(field);
|
return new Builder(field);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,6 +179,7 @@ import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
|||||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
||||||
import org.elasticsearch.client.ml.inference.TrainedModelInput;
|
import org.elasticsearch.client.ml.inference.TrainedModelInput;
|
||||||
import org.elasticsearch.client.ml.inference.TrainedModelStats;
|
import org.elasticsearch.client.ml.inference.TrainedModelStats;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
|
||||||
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
|
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
|
||||||
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
|
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
|
||||||
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
||||||
@ -3003,6 +3004,9 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||||||
.setRandomizeSeed(1234L) // <10>
|
.setRandomizeSeed(1234L) // <10>
|
||||||
.setClassAssignmentObjective(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) // <11>
|
.setClassAssignmentObjective(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) // <11>
|
||||||
.setNumTopClasses(1) // <12>
|
.setNumTopClasses(1) // <12>
|
||||||
|
.setFeatureProcessors(Arrays.asList(OneHotEncoding.builder("categorical_feature") // <13>
|
||||||
|
.addOneHot("cat", "cat_column")
|
||||||
|
.build()))
|
||||||
.build();
|
.build();
|
||||||
// end::put-data-frame-analytics-classification
|
// end::put-data-frame-analytics-classification
|
||||||
|
|
||||||
@ -3019,6 +3023,9 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||||||
.setRandomizeSeed(1234L) // <10>
|
.setRandomizeSeed(1234L) // <10>
|
||||||
.setLossFunction(Regression.LossFunction.MSE) // <11>
|
.setLossFunction(Regression.LossFunction.MSE) // <11>
|
||||||
.setLossFunctionParameter(1.0) // <12>
|
.setLossFunctionParameter(1.0) // <12>
|
||||||
|
.setFeatureProcessors(Arrays.asList(OneHotEncoding.builder("categorical_feature") // <13>
|
||||||
|
.addOneHot("cat", "cat_column")
|
||||||
|
.build()))
|
||||||
.build();
|
.build();
|
||||||
// end::put-data-frame-analytics-regression
|
// end::put-data-frame-analytics-regression
|
||||||
|
|
||||||
|
@ -18,10 +18,20 @@
|
|||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml.dataframe;
|
package org.elasticsearch.client.ml.dataframe;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||||
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
public class ClassificationTests extends AbstractXContentTestCase<Classification> {
|
public class ClassificationTests extends AbstractXContentTestCase<Classification> {
|
||||||
|
|
||||||
@ -38,9 +48,20 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
|||||||
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
||||||
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
|
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
|
||||||
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
|
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
|
||||||
|
.setFeatureProcessors(randomBoolean() ? null :
|
||||||
|
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||||
|
OneHotEncodingTests.createRandom(),
|
||||||
|
TargetMeanEncodingTests.createRandom()))
|
||||||
|
.limit(randomIntBetween(1, 10))
|
||||||
|
.collect(Collectors.toList()))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||||
|
return field -> field.startsWith("feature_processors");
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Classification createTestInstance() {
|
protected Classification createTestInstance() {
|
||||||
return randomClassification();
|
return randomClassification();
|
||||||
@ -55,4 +76,11 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
|||||||
protected boolean supportsUnknownFields() {
|
protected boolean supportsUnknownFields() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
|
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||||
|
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
return new NamedXContentRegistry(namedXContent);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,7 @@
|
|||||||
package org.elasticsearch.client.ml.dataframe;
|
package org.elasticsearch.client.ml.dataframe;
|
||||||
|
|
||||||
import org.elasticsearch.Version;
|
import org.elasticsearch.Version;
|
||||||
|
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||||
@ -101,6 +102,7 @@ public class DataFrameAnalyticsConfigTests extends AbstractXContentTestCase<Data
|
|||||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||||
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||||
return new NamedXContentRegistry(namedXContent);
|
return new NamedXContentRegistry(namedXContent);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,10 +18,20 @@
|
|||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml.dataframe;
|
package org.elasticsearch.client.ml.dataframe;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||||
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
||||||
|
|
||||||
@ -37,9 +47,20 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
|||||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||||
.setLossFunction(randomBoolean() ? null : randomFrom(Regression.LossFunction.values()))
|
.setLossFunction(randomBoolean() ? null : randomFrom(Regression.LossFunction.values()))
|
||||||
.setLossFunctionParameter(randomBoolean() ? null : randomDoubleBetween(1.0, Double.MAX_VALUE, true))
|
.setLossFunctionParameter(randomBoolean() ? null : randomDoubleBetween(1.0, Double.MAX_VALUE, true))
|
||||||
|
.setFeatureProcessors(randomBoolean() ? null :
|
||||||
|
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||||
|
OneHotEncodingTests.createRandom(),
|
||||||
|
TargetMeanEncodingTests.createRandom()))
|
||||||
|
.limit(randomIntBetween(1, 10))
|
||||||
|
.collect(Collectors.toList()))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||||
|
return field -> field.startsWith("feature_processors");
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Regression createTestInstance() {
|
protected Regression createTestInstance() {
|
||||||
return randomRegression();
|
return randomRegression();
|
||||||
@ -54,4 +75,11 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
|||||||
protected boolean supportsUnknownFields() {
|
protected boolean supportsUnknownFields() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
|
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||||
|
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
return new NamedXContentRegistry(namedXContent);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -124,6 +124,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
|
|||||||
<10> The seed to be used by the random generator that picks which rows are used in training.
|
<10> The seed to be used by the random generator that picks which rows are used in training.
|
||||||
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
|
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
|
||||||
<12> The number of top classes to be reported in the results. Defaults to 2.
|
<12> The number of top classes to be reported in the results. Defaults to 2.
|
||||||
|
<13> Custom feature processors that will create new features for analysis from the included document
|
||||||
|
fields. Note that automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
|
||||||
|
|
||||||
===== Regression
|
===== Regression
|
||||||
|
|
||||||
@ -146,6 +148,8 @@ include-tagged::{doc-tests-file}[{api}-regression]
|
|||||||
<10> The seed to be used by the random generator that picks which rows are used in training.
|
<10> The seed to be used by the random generator that picks which rows are used in training.
|
||||||
<11> The loss function used for regression. Defaults to `mse`.
|
<11> The loss function used for regression. Defaults to `mse`.
|
||||||
<12> An optional parameter to the loss function.
|
<12> An optional parameter to the loss function.
|
||||||
|
<13> Custom feature processors that will create new features for analysis from the included document
|
||||||
|
fields. Note that automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
|
||||||
|
|
||||||
==== Analyzed fields
|
==== Analyzed fields
|
||||||
|
|
||||||
|
@ -115,6 +115,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]
|
|||||||
(Optional, double)
|
(Optional, double)
|
||||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=feature-bag-fraction]
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=feature-bag-fraction]
|
||||||
|
|
||||||
|
`feature_processors`::::
|
||||||
|
(Optional, list)
|
||||||
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-feature-processors]
|
||||||
|
|
||||||
`gamma`::::
|
`gamma`::::
|
||||||
(Optional, double)
|
(Optional, double)
|
||||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=gamma]
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=gamma]
|
||||||
@ -215,6 +219,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]
|
|||||||
(Optional, double)
|
(Optional, double)
|
||||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=feature-bag-fraction]
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=feature-bag-fraction]
|
||||||
|
|
||||||
|
`feature_processors`::::
|
||||||
|
(Optional, list)
|
||||||
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-feature-processors]
|
||||||
|
|
||||||
`gamma`::::
|
`gamma`::::
|
||||||
(Optional, double)
|
(Optional, double)
|
||||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=gamma]
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=gamma]
|
||||||
|
@ -517,6 +517,14 @@ Specifies the rate at which the `eta` increases for each new tree that is added
|
|||||||
forest. For example, a rate of `1.05` increases `eta` by 5%.
|
forest. For example, a rate of `1.05` increases `eta` by 5%.
|
||||||
end::dfas-eta-growth[]
|
end::dfas-eta-growth[]
|
||||||
|
|
||||||
|
tag::dfas-feature-processors[]
|
||||||
|
A collection of feature preprocessors that modify one or more included fields.
|
||||||
|
The analysis uses the resulting one or more features instead of the
|
||||||
|
original document field. Multiple `feature_processors` entries can refer to the
|
||||||
|
same document fields.
|
||||||
|
Note that automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs.
|
||||||
|
end::dfas-feature-processors[]
|
||||||
|
|
||||||
tag::dfas-iteration[]
|
tag::dfas-iteration[]
|
||||||
The number of iterations on the analysis.
|
The number of iterations on the analysis.
|
||||||
end::dfas-iteration[]
|
end::dfas-iteration[]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user