Adds a new parameter to regression and classification analyses that enables computation of feature importance for the top most important features. The computation of the importance is based on the SHAP (SHapley Additive exPlanations) method. Backport of #50914
This commit is contained in:
parent
0178c7c5d0
commit
1d8cb3c741
|
@ -46,6 +46,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
static final ParseField ETA = new ParseField("eta");
|
static final ParseField ETA = new ParseField("eta");
|
||||||
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||||
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||||
|
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
|
||||||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||||
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||||
|
@ -62,10 +63,11 @@ public class Classification implements DataFrameAnalysis {
|
||||||
(Double) a[3],
|
(Double) a[3],
|
||||||
(Integer) a[4],
|
(Integer) a[4],
|
||||||
(Double) a[5],
|
(Double) a[5],
|
||||||
(String) a[6],
|
(Integer) a[6],
|
||||||
(Double) a[7],
|
(String) a[7],
|
||||||
(Integer) a[8],
|
(Double) a[8],
|
||||||
(Long) a[9]));
|
(Integer) a[9],
|
||||||
|
(Long) a[10]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||||
|
@ -74,6 +76,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
|
||||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||||
|
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
|
||||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
|
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
|
||||||
|
@ -86,13 +89,15 @@ public class Classification implements DataFrameAnalysis {
|
||||||
private final Double eta;
|
private final Double eta;
|
||||||
private final Integer maximumNumberTrees;
|
private final Integer maximumNumberTrees;
|
||||||
private final Double featureBagFraction;
|
private final Double featureBagFraction;
|
||||||
|
private final Integer numTopFeatureImportanceValues;
|
||||||
private final String predictionFieldName;
|
private final String predictionFieldName;
|
||||||
private final Double trainingPercent;
|
private final Double trainingPercent;
|
||||||
private final Integer numTopClasses;
|
private final Integer numTopClasses;
|
||||||
private final Long randomizeSeed;
|
private final Long randomizeSeed;
|
||||||
|
|
||||||
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction,
|
||||||
|
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
||||||
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
|
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
|
||||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||||
this.lambda = lambda;
|
this.lambda = lambda;
|
||||||
|
@ -100,6 +105,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
this.eta = eta;
|
this.eta = eta;
|
||||||
this.maximumNumberTrees = maximumNumberTrees;
|
this.maximumNumberTrees = maximumNumberTrees;
|
||||||
this.featureBagFraction = featureBagFraction;
|
this.featureBagFraction = featureBagFraction;
|
||||||
|
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||||
this.predictionFieldName = predictionFieldName;
|
this.predictionFieldName = predictionFieldName;
|
||||||
this.trainingPercent = trainingPercent;
|
this.trainingPercent = trainingPercent;
|
||||||
this.numTopClasses = numTopClasses;
|
this.numTopClasses = numTopClasses;
|
||||||
|
@ -135,6 +141,10 @@ public class Classification implements DataFrameAnalysis {
|
||||||
return featureBagFraction;
|
return featureBagFraction;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Integer getNumTopFeatureImportanceValues() {
|
||||||
|
return numTopFeatureImportanceValues;
|
||||||
|
}
|
||||||
|
|
||||||
public String getPredictionFieldName() {
|
public String getPredictionFieldName() {
|
||||||
return predictionFieldName;
|
return predictionFieldName;
|
||||||
}
|
}
|
||||||
|
@ -170,6 +180,9 @@ public class Classification implements DataFrameAnalysis {
|
||||||
if (featureBagFraction != null) {
|
if (featureBagFraction != null) {
|
||||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||||
}
|
}
|
||||||
|
if (numTopFeatureImportanceValues != null) {
|
||||||
|
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
||||||
|
}
|
||||||
if (predictionFieldName != null) {
|
if (predictionFieldName != null) {
|
||||||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||||
}
|
}
|
||||||
|
@ -188,8 +201,8 @@ public class Classification implements DataFrameAnalysis {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues,
|
||||||
trainingPercent, randomizeSeed, numTopClasses);
|
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -203,6 +216,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
&& Objects.equals(eta, that.eta)
|
&& Objects.equals(eta, that.eta)
|
||||||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||||
|
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
|
||||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||||
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
||||||
|
@ -221,6 +235,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
private Double eta;
|
private Double eta;
|
||||||
private Integer maximumNumberTrees;
|
private Integer maximumNumberTrees;
|
||||||
private Double featureBagFraction;
|
private Double featureBagFraction;
|
||||||
|
private Integer numTopFeatureImportanceValues;
|
||||||
private String predictionFieldName;
|
private String predictionFieldName;
|
||||||
private Double trainingPercent;
|
private Double trainingPercent;
|
||||||
private Integer numTopClasses;
|
private Integer numTopClasses;
|
||||||
|
@ -255,6 +270,11 @@ public class Classification implements DataFrameAnalysis {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
|
||||||
|
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public Builder setPredictionFieldName(String predictionFieldName) {
|
public Builder setPredictionFieldName(String predictionFieldName) {
|
||||||
this.predictionFieldName = predictionFieldName;
|
this.predictionFieldName = predictionFieldName;
|
||||||
return this;
|
return this;
|
||||||
|
@ -276,8 +296,8 @@ public class Classification implements DataFrameAnalysis {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Classification build() {
|
public Classification build() {
|
||||||
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
|
||||||
trainingPercent, numTopClasses, randomizeSeed);
|
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,6 +46,7 @@ public class Regression implements DataFrameAnalysis {
|
||||||
static final ParseField ETA = new ParseField("eta");
|
static final ParseField ETA = new ParseField("eta");
|
||||||
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||||
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||||
|
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
|
||||||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||||
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||||
|
@ -61,9 +62,10 @@ public class Regression implements DataFrameAnalysis {
|
||||||
(Double) a[3],
|
(Double) a[3],
|
||||||
(Integer) a[4],
|
(Integer) a[4],
|
||||||
(Double) a[5],
|
(Double) a[5],
|
||||||
(String) a[6],
|
(Integer) a[6],
|
||||||
(Double) a[7],
|
(String) a[7],
|
||||||
(Long) a[8]));
|
(Double) a[8],
|
||||||
|
(Long) a[9]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||||
|
@ -72,6 +74,7 @@ public class Regression implements DataFrameAnalysis {
|
||||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
|
||||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||||
|
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
|
||||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
||||||
|
@ -83,12 +86,14 @@ public class Regression implements DataFrameAnalysis {
|
||||||
private final Double eta;
|
private final Double eta;
|
||||||
private final Integer maximumNumberTrees;
|
private final Integer maximumNumberTrees;
|
||||||
private final Double featureBagFraction;
|
private final Double featureBagFraction;
|
||||||
|
private final Integer numTopFeatureImportanceValues;
|
||||||
private final String predictionFieldName;
|
private final String predictionFieldName;
|
||||||
private final Double trainingPercent;
|
private final Double trainingPercent;
|
||||||
private final Long randomizeSeed;
|
private final Long randomizeSeed;
|
||||||
|
|
||||||
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction,
|
||||||
|
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
||||||
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
|
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
|
||||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||||
this.lambda = lambda;
|
this.lambda = lambda;
|
||||||
|
@ -96,6 +101,7 @@ public class Regression implements DataFrameAnalysis {
|
||||||
this.eta = eta;
|
this.eta = eta;
|
||||||
this.maximumNumberTrees = maximumNumberTrees;
|
this.maximumNumberTrees = maximumNumberTrees;
|
||||||
this.featureBagFraction = featureBagFraction;
|
this.featureBagFraction = featureBagFraction;
|
||||||
|
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||||
this.predictionFieldName = predictionFieldName;
|
this.predictionFieldName = predictionFieldName;
|
||||||
this.trainingPercent = trainingPercent;
|
this.trainingPercent = trainingPercent;
|
||||||
this.randomizeSeed = randomizeSeed;
|
this.randomizeSeed = randomizeSeed;
|
||||||
|
@ -130,6 +136,10 @@ public class Regression implements DataFrameAnalysis {
|
||||||
return featureBagFraction;
|
return featureBagFraction;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Integer getNumTopFeatureImportanceValues() {
|
||||||
|
return numTopFeatureImportanceValues;
|
||||||
|
}
|
||||||
|
|
||||||
public String getPredictionFieldName() {
|
public String getPredictionFieldName() {
|
||||||
return predictionFieldName;
|
return predictionFieldName;
|
||||||
}
|
}
|
||||||
|
@ -161,6 +171,9 @@ public class Regression implements DataFrameAnalysis {
|
||||||
if (featureBagFraction != null) {
|
if (featureBagFraction != null) {
|
||||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||||
}
|
}
|
||||||
|
if (numTopFeatureImportanceValues != null) {
|
||||||
|
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
||||||
|
}
|
||||||
if (predictionFieldName != null) {
|
if (predictionFieldName != null) {
|
||||||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||||
}
|
}
|
||||||
|
@ -176,8 +189,8 @@ public class Regression implements DataFrameAnalysis {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues,
|
||||||
trainingPercent, randomizeSeed);
|
predictionFieldName, trainingPercent, randomizeSeed);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -191,6 +204,7 @@ public class Regression implements DataFrameAnalysis {
|
||||||
&& Objects.equals(eta, that.eta)
|
&& Objects.equals(eta, that.eta)
|
||||||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||||
|
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
|
||||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||||
&& Objects.equals(randomizeSeed, that.randomizeSeed);
|
&& Objects.equals(randomizeSeed, that.randomizeSeed);
|
||||||
|
@ -208,6 +222,7 @@ public class Regression implements DataFrameAnalysis {
|
||||||
private Double eta;
|
private Double eta;
|
||||||
private Integer maximumNumberTrees;
|
private Integer maximumNumberTrees;
|
||||||
private Double featureBagFraction;
|
private Double featureBagFraction;
|
||||||
|
private Integer numTopFeatureImportanceValues;
|
||||||
private String predictionFieldName;
|
private String predictionFieldName;
|
||||||
private Double trainingPercent;
|
private Double trainingPercent;
|
||||||
private Long randomizeSeed;
|
private Long randomizeSeed;
|
||||||
|
@ -241,6 +256,11 @@ public class Regression implements DataFrameAnalysis {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
|
||||||
|
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public Builder setPredictionFieldName(String predictionFieldName) {
|
public Builder setPredictionFieldName(String predictionFieldName) {
|
||||||
this.predictionFieldName = predictionFieldName;
|
this.predictionFieldName = predictionFieldName;
|
||||||
return this;
|
return this;
|
||||||
|
@ -257,8 +277,8 @@ public class Regression implements DataFrameAnalysis {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Regression build() {
|
public Regression build() {
|
||||||
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
|
||||||
trainingPercent, randomizeSeed);
|
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1324,6 +1324,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
.setPredictionFieldName("my_dependent_variable_prediction")
|
.setPredictionFieldName("my_dependent_variable_prediction")
|
||||||
.setTrainingPercent(80.0)
|
.setTrainingPercent(80.0)
|
||||||
.setRandomizeSeed(42L)
|
.setRandomizeSeed(42L)
|
||||||
|
.setLambda(1.0)
|
||||||
|
.setGamma(1.0)
|
||||||
|
.setEta(1.0)
|
||||||
|
.setMaximumNumberTrees(10)
|
||||||
|
.setFeatureBagFraction(0.5)
|
||||||
|
.setNumTopFeatureImportanceValues(3)
|
||||||
.build())
|
.build())
|
||||||
.setDescription("this is a regression")
|
.setDescription("this is a regression")
|
||||||
.build();
|
.build();
|
||||||
|
@ -1361,6 +1367,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
.setTrainingPercent(80.0)
|
.setTrainingPercent(80.0)
|
||||||
.setRandomizeSeed(42L)
|
.setRandomizeSeed(42L)
|
||||||
.setNumTopClasses(1)
|
.setNumTopClasses(1)
|
||||||
|
.setLambda(1.0)
|
||||||
|
.setGamma(1.0)
|
||||||
|
.setEta(1.0)
|
||||||
|
.setMaximumNumberTrees(10)
|
||||||
|
.setFeatureBagFraction(0.5)
|
||||||
|
.setNumTopFeatureImportanceValues(3)
|
||||||
.build())
|
.build())
|
||||||
.setDescription("this is a classification")
|
.setDescription("this is a classification")
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -2975,10 +2975,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
.setEta(5.5) // <4>
|
.setEta(5.5) // <4>
|
||||||
.setMaximumNumberTrees(50) // <5>
|
.setMaximumNumberTrees(50) // <5>
|
||||||
.setFeatureBagFraction(0.4) // <6>
|
.setFeatureBagFraction(0.4) // <6>
|
||||||
.setPredictionFieldName("my_prediction_field_name") // <7>
|
.setNumTopFeatureImportanceValues(3) // <7>
|
||||||
.setTrainingPercent(50.0) // <8>
|
.setPredictionFieldName("my_prediction_field_name") // <8>
|
||||||
.setRandomizeSeed(1234L) // <9>
|
.setTrainingPercent(50.0) // <9>
|
||||||
.setNumTopClasses(1) // <10>
|
.setRandomizeSeed(1234L) // <10>
|
||||||
|
.setNumTopClasses(1) // <11>
|
||||||
.build();
|
.build();
|
||||||
// end::put-data-frame-analytics-classification
|
// end::put-data-frame-analytics-classification
|
||||||
|
|
||||||
|
@ -2989,9 +2990,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
.setEta(5.5) // <4>
|
.setEta(5.5) // <4>
|
||||||
.setMaximumNumberTrees(50) // <5>
|
.setMaximumNumberTrees(50) // <5>
|
||||||
.setFeatureBagFraction(0.4) // <6>
|
.setFeatureBagFraction(0.4) // <6>
|
||||||
.setPredictionFieldName("my_prediction_field_name") // <7>
|
.setNumTopFeatureImportanceValues(3) // <7>
|
||||||
.setTrainingPercent(50.0) // <8>
|
.setPredictionFieldName("my_prediction_field_name") // <8>
|
||||||
.setRandomizeSeed(1234L) // <9>
|
.setTrainingPercent(50.0) // <9>
|
||||||
|
.setRandomizeSeed(1234L) // <10>
|
||||||
.build();
|
.build();
|
||||||
// end::put-data-frame-analytics-regression
|
// end::put-data-frame-analytics-regression
|
||||||
|
|
||||||
|
@ -3670,7 +3672,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig);
|
PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig);
|
||||||
|
|
||||||
// tag::put-trained-model-execute-listener
|
// tag::put-trained-model-execute-listener
|
||||||
ActionListener<PutTrainedModelResponse> listener = new ActionListener<PutTrainedModelResponse>() {
|
ActionListener<PutTrainedModelResponse> listener = new ActionListener<PutTrainedModelResponse>() {
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -32,6 +32,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
||||||
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
|
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
|
||||||
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
|
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
|
||||||
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||||
|
.setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
|
||||||
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
||||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||||
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
||||||
|
|
|
@ -32,6 +32,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
||||||
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
|
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
|
||||||
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
|
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
|
||||||
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||||
|
.setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
|
||||||
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
||||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -117,10 +117,11 @@ include-tagged::{doc-tests-file}[{api}-classification]
|
||||||
<4> The applied shrinkage. A double in [0.001, 1].
|
<4> The applied shrinkage. A double in [0.001, 1].
|
||||||
<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
|
<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
|
||||||
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
||||||
<7> The name of the prediction field in the results object.
|
<7> If set, feature importance for the top most important features will be computed.
|
||||||
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
<8> The name of the prediction field in the results object.
|
||||||
<9> The seed to be used by the random generator that picks which rows are used in training.
|
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||||
<10> The number of top classes to be reported in the results. Defaults to 2.
|
<10> The seed to be used by the random generator that picks which rows are used in training.
|
||||||
|
<11> The number of top classes to be reported in the results. Defaults to 2.
|
||||||
|
|
||||||
===== Regression
|
===== Regression
|
||||||
|
|
||||||
|
@ -137,9 +138,10 @@ include-tagged::{doc-tests-file}[{api}-regression]
|
||||||
<4> The applied shrinkage. A double in [0.001, 1].
|
<4> The applied shrinkage. A double in [0.001, 1].
|
||||||
<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
|
<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
|
||||||
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
||||||
<7> The name of the prediction field in the results object.
|
<7> If set, feature importance for the top most important features will be computed.
|
||||||
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
<8> The name of the prediction field in the results object.
|
||||||
<9> The seed to be used by the random generator that picks which rows are used in training.
|
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||||
|
<10> The seed to be used by the random generator that picks which rows are used in training.
|
||||||
|
|
||||||
==== Analyzed fields
|
==== Analyzed fields
|
||||||
|
|
||||||
|
|
|
@ -150,6 +150,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name]
|
||||||
(Optional, long)
|
(Optional, long)
|
||||||
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize-seed]
|
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize-seed]
|
||||||
|
|
||||||
|
`analysis`.`classification`.`num_top_feature_importance_values`::::
|
||||||
|
(Optional, integer)
|
||||||
|
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values]
|
||||||
|
|
||||||
`analysis`.`classification`.`training_percent`::::
|
`analysis`.`classification`.`training_percent`::::
|
||||||
(Optional, integer)
|
(Optional, integer)
|
||||||
include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent]
|
include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent]
|
||||||
|
@ -229,6 +233,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=lambda]
|
||||||
(Optional, string)
|
(Optional, string)
|
||||||
include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name]
|
include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name]
|
||||||
|
|
||||||
|
`analysis`.`regression`.`num_top_feature_importance_values`::::
|
||||||
|
(Optional, integer)
|
||||||
|
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values]
|
||||||
|
|
||||||
`analysis`.`regression`.`training_percent`::::
|
`analysis`.`regression`.`training_percent`::::
|
||||||
(Optional, integer)
|
(Optional, integer)
|
||||||
include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent]
|
include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent]
|
||||||
|
|
|
@ -637,6 +637,14 @@ end::include-model-definition[]
|
||||||
tag::indices[]
|
tag::indices[]
|
||||||
An array of index names. Wildcards are supported. For example:
|
An array of index names. Wildcards are supported. For example:
|
||||||
`["it_ops_metrics", "server*"]`.
|
`["it_ops_metrics", "server*"]`.
|
||||||
|
|
||||||
|
tag::num-top-feature-importance-values[]
|
||||||
|
Advanced configuration option. If set, feature importance for the top
|
||||||
|
most important features will be computed. Importance is calculated
|
||||||
|
using the SHAP (SHapley Additive exPlanations) method as described in
|
||||||
|
https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf[Lundberg, S. M., & Lee, S.-I. A Unified Approach to Interpreting Model Predictions. In NeurIPS 2017.].
|
||||||
|
end::num-top-feature-importance-values[]
|
||||||
|
|
||||||
+
|
+
|
||||||
--
|
--
|
||||||
NOTE: If any indices are in remote clusters then `cluster.remote.connect` must
|
NOTE: If any indices are in remote clusters then `cluster.remote.connect` must
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||||
|
|
||||||
|
import org.elasticsearch.Version;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
@ -34,6 +35,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
||||||
public static final ParseField ETA = new ParseField("eta");
|
public static final ParseField ETA = new ParseField("eta");
|
||||||
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||||
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||||
|
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
|
||||||
|
|
||||||
static void declareFields(AbstractObjectParser<?, Void> parser) {
|
static void declareFields(AbstractObjectParser<?, Void> parser) {
|
||||||
parser.declareDouble(optionalConstructorArg(), LAMBDA);
|
parser.declareDouble(optionalConstructorArg(), LAMBDA);
|
||||||
|
@ -41,6 +43,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
||||||
parser.declareDouble(optionalConstructorArg(), ETA);
|
parser.declareDouble(optionalConstructorArg(), ETA);
|
||||||
parser.declareInt(optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
parser.declareInt(optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||||
parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||||
|
parser.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final Double lambda;
|
private final Double lambda;
|
||||||
|
@ -48,12 +51,14 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
||||||
private final Double eta;
|
private final Double eta;
|
||||||
private final Integer maximumNumberTrees;
|
private final Integer maximumNumberTrees;
|
||||||
private final Double featureBagFraction;
|
private final Double featureBagFraction;
|
||||||
|
private final Integer numTopFeatureImportanceValues;
|
||||||
|
|
||||||
public BoostedTreeParams(@Nullable Double lambda,
|
public BoostedTreeParams(@Nullable Double lambda,
|
||||||
@Nullable Double gamma,
|
@Nullable Double gamma,
|
||||||
@Nullable Double eta,
|
@Nullable Double eta,
|
||||||
@Nullable Integer maximumNumberTrees,
|
@Nullable Integer maximumNumberTrees,
|
||||||
@Nullable Double featureBagFraction) {
|
@Nullable Double featureBagFraction,
|
||||||
|
@Nullable Integer numTopFeatureImportanceValues) {
|
||||||
if (lambda != null && lambda < 0) {
|
if (lambda != null && lambda < 0) {
|
||||||
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
|
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
|
||||||
}
|
}
|
||||||
|
@ -69,15 +74,16 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
||||||
if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) {
|
if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) {
|
||||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName());
|
throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName());
|
||||||
}
|
}
|
||||||
|
if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) {
|
||||||
|
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative integer",
|
||||||
|
NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
|
||||||
|
}
|
||||||
this.lambda = lambda;
|
this.lambda = lambda;
|
||||||
this.gamma = gamma;
|
this.gamma = gamma;
|
||||||
this.eta = eta;
|
this.eta = eta;
|
||||||
this.maximumNumberTrees = maximumNumberTrees;
|
this.maximumNumberTrees = maximumNumberTrees;
|
||||||
this.featureBagFraction = featureBagFraction;
|
this.featureBagFraction = featureBagFraction;
|
||||||
}
|
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||||
|
|
||||||
public BoostedTreeParams() {
|
|
||||||
this(null, null, null, null, null);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BoostedTreeParams(StreamInput in) throws IOException {
|
BoostedTreeParams(StreamInput in) throws IOException {
|
||||||
|
@ -86,6 +92,11 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
||||||
eta = in.readOptionalDouble();
|
eta = in.readOptionalDouble();
|
||||||
maximumNumberTrees = in.readOptionalVInt();
|
maximumNumberTrees = in.readOptionalVInt();
|
||||||
featureBagFraction = in.readOptionalDouble();
|
featureBagFraction = in.readOptionalDouble();
|
||||||
|
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||||
|
numTopFeatureImportanceValues = in.readOptionalInt();
|
||||||
|
} else {
|
||||||
|
numTopFeatureImportanceValues = null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -95,6 +106,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
||||||
out.writeOptionalDouble(eta);
|
out.writeOptionalDouble(eta);
|
||||||
out.writeOptionalVInt(maximumNumberTrees);
|
out.writeOptionalVInt(maximumNumberTrees);
|
||||||
out.writeOptionalDouble(featureBagFraction);
|
out.writeOptionalDouble(featureBagFraction);
|
||||||
|
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||||
|
out.writeOptionalInt(numTopFeatureImportanceValues);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -114,6 +128,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
||||||
if (featureBagFraction != null) {
|
if (featureBagFraction != null) {
|
||||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||||
}
|
}
|
||||||
|
if (numTopFeatureImportanceValues != null) {
|
||||||
|
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
||||||
|
}
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,6 +151,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
||||||
if (featureBagFraction != null) {
|
if (featureBagFraction != null) {
|
||||||
params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||||
}
|
}
|
||||||
|
if (numTopFeatureImportanceValues != null) {
|
||||||
|
params.put(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
||||||
|
}
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,11 +166,62 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
||||||
&& Objects.equals(gamma, that.gamma)
|
&& Objects.equals(gamma, that.gamma)
|
||||||
&& Objects.equals(eta, that.eta)
|
&& Objects.equals(eta, that.eta)
|
||||||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||||
&& Objects.equals(featureBagFraction, that.featureBagFraction);
|
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||||
|
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction);
|
return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
|
||||||
|
private Double lambda;
|
||||||
|
private Double gamma;
|
||||||
|
private Double eta;
|
||||||
|
private Integer maximumNumberTrees;
|
||||||
|
private Double featureBagFraction;
|
||||||
|
private Integer numTopFeatureImportanceValues;
|
||||||
|
|
||||||
|
private Builder() {}
|
||||||
|
|
||||||
|
public Builder setLambda(Double lambda) {
|
||||||
|
this.lambda = lambda;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setGamma(Double gamma) {
|
||||||
|
this.gamma = gamma;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setEta(Double eta) {
|
||||||
|
this.eta = eta;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setMaximumNumberTrees(Integer maximumNumberTrees) {
|
||||||
|
this.maximumNumberTrees = maximumNumberTrees;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setFeatureBagFraction(Double featureBagFraction) {
|
||||||
|
this.featureBagFraction = featureBagFraction;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
|
||||||
|
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public BoostedTreeParams build() {
|
||||||
|
return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,11 +50,11 @@ public class Classification implements DataFrameAnalysis {
|
||||||
lenient,
|
lenient,
|
||||||
a -> new Classification(
|
a -> new Classification(
|
||||||
(String) a[0],
|
(String) a[0],
|
||||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
|
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
|
||||||
(String) a[6],
|
(String) a[7],
|
||||||
(Integer) a[7],
|
(Integer) a[8],
|
||||||
(Double) a[8],
|
(Double) a[9],
|
||||||
(Long) a[9]));
|
(Long) a[10]));
|
||||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||||
BoostedTreeParams.declareFields(parser);
|
BoostedTreeParams.declareFields(parser);
|
||||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||||
|
@ -114,7 +114,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Classification(String dependentVariable) {
|
public Classification(String dependentVariable) {
|
||||||
this(dependentVariable, new BoostedTreeParams(), null, null, null, null);
|
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Classification(StreamInput in) throws IOException {
|
public Classification(StreamInput in) throws IOException {
|
||||||
|
|
|
@ -47,10 +47,10 @@ public class Regression implements DataFrameAnalysis {
|
||||||
lenient,
|
lenient,
|
||||||
a -> new Regression(
|
a -> new Regression(
|
||||||
(String) a[0],
|
(String) a[0],
|
||||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
|
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
|
||||||
(String) a[6],
|
(String) a[7],
|
||||||
(Double) a[7],
|
(Double) a[8],
|
||||||
(Long) a[8]));
|
(Long) a[9]));
|
||||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||||
BoostedTreeParams.declareFields(parser);
|
BoostedTreeParams.declareFields(parser);
|
||||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||||
|
@ -85,7 +85,7 @@ public class Regression implements DataFrameAnalysis {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Regression(String dependentVariable) {
|
public Regression(String dependentVariable) {
|
||||||
this(dependentVariable, new BoostedTreeParams(), null, null, null);
|
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Regression(StreamInput in) throws IOException {
|
public Regression(StreamInput in) throws IOException {
|
||||||
|
|
|
@ -471,6 +471,9 @@ public class ElasticsearchMappings {
|
||||||
.startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName())
|
.startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName())
|
||||||
.field(TYPE, DOUBLE)
|
.field(TYPE, DOUBLE)
|
||||||
.endObject()
|
.endObject()
|
||||||
|
.startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName())
|
||||||
|
.field(TYPE, INTEGER)
|
||||||
|
.endObject()
|
||||||
.startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
|
.startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
|
||||||
.field(TYPE, KEYWORD)
|
.field(TYPE, KEYWORD)
|
||||||
.endObject()
|
.endObject()
|
||||||
|
@ -499,6 +502,9 @@ public class ElasticsearchMappings {
|
||||||
.startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName())
|
.startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName())
|
||||||
.field(TYPE, DOUBLE)
|
.field(TYPE, DOUBLE)
|
||||||
.endObject()
|
.endObject()
|
||||||
|
.startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName())
|
||||||
|
.field(TYPE, INTEGER)
|
||||||
|
.endObject()
|
||||||
.startObject(Classification.PREDICTION_FIELD_NAME.getPreferredName())
|
.startObject(Classification.PREDICTION_FIELD_NAME.getPreferredName())
|
||||||
.field(TYPE, KEYWORD)
|
.field(TYPE, KEYWORD)
|
||||||
.endObject()
|
.endObject()
|
||||||
|
|
|
@ -323,6 +323,7 @@ public final class ReservedFieldNames {
|
||||||
BoostedTreeParams.ETA.getPreferredName(),
|
BoostedTreeParams.ETA.getPreferredName(),
|
||||||
BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName(),
|
BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName(),
|
||||||
BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName(),
|
BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName(),
|
||||||
|
BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(),
|
||||||
|
|
||||||
ElasticsearchMappings.CONFIG_TYPE,
|
ElasticsearchMappings.CONFIG_TYPE,
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedT
|
||||||
new ConstructingObjectParser<>(
|
new ConstructingObjectParser<>(
|
||||||
BoostedTreeParams.NAME,
|
BoostedTreeParams.NAME,
|
||||||
true,
|
true,
|
||||||
a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4]));
|
a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4], (Integer) a[5]));
|
||||||
BoostedTreeParams.declareFields(objParser);
|
BoostedTreeParams.declareFields(objParser);
|
||||||
return objParser.apply(parser, null);
|
return objParser.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
@ -34,12 +34,14 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedT
|
||||||
}
|
}
|
||||||
|
|
||||||
public static BoostedTreeParams createRandom() {
|
public static BoostedTreeParams createRandom() {
|
||||||
Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
|
return BoostedTreeParams.builder()
|
||||||
Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
|
.setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
|
||||||
Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true);
|
.setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
|
||||||
Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
|
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
|
||||||
Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
|
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
|
||||||
return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction);
|
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||||
|
.setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
|
||||||
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -49,57 +51,64 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedT
|
||||||
|
|
||||||
public void testConstructor_GivenNegativeLambda() {
|
public void testConstructor_GivenNegativeLambda() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new BoostedTreeParams(-0.00001, 0.0, 0.5, 500, 0.3));
|
() -> BoostedTreeParams.builder().setLambda(-0.00001).build());
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
|
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenNegativeGamma() {
|
public void testConstructor_GivenNegativeGamma() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new BoostedTreeParams(0.0, -0.00001, 0.5, 500, 0.3));
|
() -> BoostedTreeParams.builder().setGamma(-0.00001).build());
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
|
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenEtaIsZero() {
|
public void testConstructor_GivenEtaIsZero() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new BoostedTreeParams(0.0, 0.0, 0.0, 500, 0.3));
|
() -> BoostedTreeParams.builder().setEta(0.0).build());
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenEtaIsGreaterThanOne() {
|
public void testConstructor_GivenEtaIsGreaterThanOne() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new BoostedTreeParams(0.0, 0.0, 1.00001, 500, 0.3));
|
() -> BoostedTreeParams.builder().setEta(1.00001).build());
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenMaximumNumberTreesIsZero() {
|
public void testConstructor_GivenMaximumNumberTreesIsZero() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 0, 0.3));
|
() -> BoostedTreeParams.builder().setMaximumNumberTrees(0).build());
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenMaximumNumberTreesIsGreaterThan2k() {
|
public void testConstructor_GivenMaximumNumberTreesIsGreaterThan2k() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 2001, 0.3));
|
() -> BoostedTreeParams.builder().setMaximumNumberTrees(2001).build());
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenFeatureBagFractionIsLessThanZero() {
|
public void testConstructor_GivenFeatureBagFractionIsLessThanZero() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, -0.00001));
|
() -> BoostedTreeParams.builder().setFeatureBagFraction(-0.00001).build());
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenFeatureBagFractionIsGreaterThanOne() {
|
public void testConstructor_GivenFeatureBagFractionIsGreaterThanOne() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.00001));
|
() -> BoostedTreeParams.builder().setFeatureBagFraction(1.00001).build());
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testConstructor_GivenTopFeatureImportanceValuesIsNegative() {
|
||||||
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
|
() -> BoostedTreeParams.builder().setNumTopFeatureImportanceValues(-1).build());
|
||||||
|
|
||||||
|
assertThat(e.getMessage(), equalTo("[num_top_feature_importance_values] must be a non-negative integer"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ import static org.hamcrest.Matchers.nullValue;
|
||||||
|
|
||||||
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
||||||
|
|
||||||
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
|
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Classification doParseInstance(XContentParser parser) throws IOException {
|
protected Classification doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
|
|
@ -31,7 +31,7 @@ import static org.hamcrest.Matchers.nullValue;
|
||||||
|
|
||||||
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||||
|
|
||||||
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
|
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Regression doParseInstance(XContentParser parser) throws IOException {
|
protected Regression doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
|
|
@ -6,7 +6,6 @@
|
||||||
package org.elasticsearch.xpack.ml.integration;
|
package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
import com.google.common.collect.Ordering;
|
import com.google.common.collect.Ordering;
|
||||||
|
|
||||||
import org.elasticsearch.ElasticsearchStatusException;
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
|
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
|
||||||
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
|
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
|
||||||
|
@ -28,7 +27,6 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||||
|
@ -86,7 +84,14 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
String predictedClassField = KEYWORD_FIELD + "_prediction";
|
String predictedClassField = KEYWORD_FIELD + "_prediction";
|
||||||
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
|
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
|
||||||
|
new Classification(
|
||||||
|
KEYWORD_FIELD,
|
||||||
|
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null));
|
||||||
registerAnalytics(config);
|
registerAnalytics(config);
|
||||||
putAnalytics(config);
|
putAnalytics(config);
|
||||||
|
|
||||||
|
@ -104,6 +109,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
|
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
|
||||||
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
||||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
||||||
|
assertThat(resultsObject.keySet().stream().filter(k -> k.startsWith("feature_importance.")).findAny().isPresent(), is(true));
|
||||||
}
|
}
|
||||||
|
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
|
@ -178,7 +184,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
sourceIndex,
|
sourceIndex,
|
||||||
destIndex,
|
destIndex,
|
||||||
null,
|
null,
|
||||||
new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0, null));
|
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null));
|
||||||
registerAnalytics(config);
|
registerAnalytics(config);
|
||||||
putAnalytics(config);
|
putAnalytics(config);
|
||||||
|
|
||||||
|
@ -413,7 +419,13 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
String firstJobId = "classification_two_jobs_with_same_randomize_seed_1";
|
String firstJobId = "classification_two_jobs_with_same_randomize_seed_1";
|
||||||
String firstJobDestIndex = firstJobId + "_dest";
|
String firstJobDestIndex = firstJobId + "_dest";
|
||||||
|
|
||||||
BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0);
|
BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder()
|
||||||
|
.setLambda(1.0)
|
||||||
|
.setGamma(1.0)
|
||||||
|
.setEta(1.0)
|
||||||
|
.setFeatureBagFraction(1.0)
|
||||||
|
.setMaximumNumberTrees(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||||
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null));
|
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null));
|
||||||
|
|
|
@ -18,7 +18,6 @@ import org.elasticsearch.search.SearchHit;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
|
||||||
|
@ -53,7 +52,14 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
||||||
indexData(sourceIndex, 300, 50);
|
indexData(sourceIndex, 300, 50);
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
|
||||||
|
new Regression(
|
||||||
|
DEPENDENT_VARIABLE_FIELD,
|
||||||
|
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null)
|
||||||
|
);
|
||||||
registerAnalytics(config);
|
registerAnalytics(config);
|
||||||
putAnalytics(config);
|
putAnalytics(config);
|
||||||
|
|
||||||
|
@@ -78,6 +84,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||||
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
||||||
|
assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true));
|
||||||
}
|
}
|
||||||
|
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
|
@@ -141,7 +148,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
sourceIndex,
|
sourceIndex,
|
||||||
destIndex,
|
destIndex,
|
||||||
null,
|
null,
|
||||||
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0, null));
|
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null));
|
||||||
registerAnalytics(config);
|
registerAnalytics(config);
|
||||||
putAnalytics(config);
|
putAnalytics(config);
|
||||||
|
|
||||||
|
@@ -244,7 +251,13 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
String firstJobId = "regression_two_jobs_with_same_randomize_seed_1";
|
String firstJobId = "regression_two_jobs_with_same_randomize_seed_1";
|
||||||
String firstJobDestIndex = firstJobId + "_dest";
|
String firstJobDestIndex = firstJobId + "_dest";
|
||||||
|
|
||||||
BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0);
|
BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder()
|
||||||
|
.setLambda(1.0)
|
||||||
|
.setGamma(1.0)
|
||||||
|
.setEta(1.0)
|
||||||
|
.setFeatureBagFraction(1.0)
|
||||||
|
.setMaximumNumberTrees(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));
|
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));
|
||||||
|
|
Loading…
Reference in New Issue