Adds a new parameter to regression and classification analyses that enables computation of feature importance for the most important features. The importance computation is based on the SHAP (SHapley Additive exPlanations) method. Backport of #50914
This commit is contained in:
parent
0178c7c5d0
commit
1d8cb3c741
|
@ -46,6 +46,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
static final ParseField ETA = new ParseField("eta");
|
||||
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
|
||||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
||||
|
@ -62,10 +63,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
(Double) a[3],
|
||||
(Integer) a[4],
|
||||
(Double) a[5],
|
||||
(String) a[6],
|
||||
(Double) a[7],
|
||||
(Integer) a[8],
|
||||
(Long) a[9]));
|
||||
(Integer) a[6],
|
||||
(String) a[7],
|
||||
(Double) a[8],
|
||||
(Integer) a[9],
|
||||
(Long) a[10]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||
|
@ -74,6 +76,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
|
||||
|
@ -86,13 +89,15 @@ public class Classification implements DataFrameAnalysis {
|
|||
private final Double eta;
|
||||
private final Integer maximumNumberTrees;
|
||||
private final Double featureBagFraction;
|
||||
private final Integer numTopFeatureImportanceValues;
|
||||
private final String predictionFieldName;
|
||||
private final Double trainingPercent;
|
||||
private final Integer numTopClasses;
|
||||
private final Long randomizeSeed;
|
||||
|
||||
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction,
|
||||
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
this.lambda = lambda;
|
||||
|
@ -100,6 +105,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
this.eta = eta;
|
||||
this.maximumNumberTrees = maximumNumberTrees;
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.trainingPercent = trainingPercent;
|
||||
this.numTopClasses = numTopClasses;
|
||||
|
@ -135,6 +141,10 @@ public class Classification implements DataFrameAnalysis {
|
|||
return featureBagFraction;
|
||||
}
|
||||
|
||||
public Integer getNumTopFeatureImportanceValues() {
|
||||
return numTopFeatureImportanceValues;
|
||||
}
|
||||
|
||||
public String getPredictionFieldName() {
|
||||
return predictionFieldName;
|
||||
}
|
||||
|
@ -170,6 +180,9 @@ public class Classification implements DataFrameAnalysis {
|
|||
if (featureBagFraction != null) {
|
||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
if (numTopFeatureImportanceValues != null) {
|
||||
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
||||
}
|
||||
if (predictionFieldName != null) {
|
||||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
|
@ -188,8 +201,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent, randomizeSeed, numTopClasses);
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues,
|
||||
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -203,6 +216,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
&& Objects.equals(eta, that.eta)
|
||||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
||||
|
@ -221,6 +235,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
private Double eta;
|
||||
private Integer maximumNumberTrees;
|
||||
private Double featureBagFraction;
|
||||
private Integer numTopFeatureImportanceValues;
|
||||
private String predictionFieldName;
|
||||
private Double trainingPercent;
|
||||
private Integer numTopClasses;
|
||||
|
@ -255,6 +270,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
|
||||
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setPredictionFieldName(String predictionFieldName) {
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
return this;
|
||||
|
@ -276,8 +296,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
}
|
||||
|
||||
public Classification build() {
|
||||
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent, numTopClasses, randomizeSeed);
|
||||
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
|
||||
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,6 +46,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
static final ParseField ETA = new ParseField("eta");
|
||||
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
|
||||
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
|
||||
|
@ -61,9 +62,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
(Double) a[3],
|
||||
(Integer) a[4],
|
||||
(Double) a[5],
|
||||
(String) a[6],
|
||||
(Double) a[7],
|
||||
(Long) a[8]));
|
||||
(Integer) a[6],
|
||||
(String) a[7],
|
||||
(Double) a[8],
|
||||
(Long) a[9]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||
|
@ -72,6 +74,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
|
||||
|
@ -83,12 +86,14 @@ public class Regression implements DataFrameAnalysis {
|
|||
private final Double eta;
|
||||
private final Integer maximumNumberTrees;
|
||||
private final Double featureBagFraction;
|
||||
private final Integer numTopFeatureImportanceValues;
|
||||
private final String predictionFieldName;
|
||||
private final Double trainingPercent;
|
||||
private final Long randomizeSeed;
|
||||
|
||||
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction,
|
||||
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
this.lambda = lambda;
|
||||
|
@ -96,6 +101,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
this.eta = eta;
|
||||
this.maximumNumberTrees = maximumNumberTrees;
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
this.trainingPercent = trainingPercent;
|
||||
this.randomizeSeed = randomizeSeed;
|
||||
|
@ -130,6 +136,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
return featureBagFraction;
|
||||
}
|
||||
|
||||
public Integer getNumTopFeatureImportanceValues() {
|
||||
return numTopFeatureImportanceValues;
|
||||
}
|
||||
|
||||
public String getPredictionFieldName() {
|
||||
return predictionFieldName;
|
||||
}
|
||||
|
@ -161,6 +171,9 @@ public class Regression implements DataFrameAnalysis {
|
|||
if (featureBagFraction != null) {
|
||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
if (numTopFeatureImportanceValues != null) {
|
||||
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
||||
}
|
||||
if (predictionFieldName != null) {
|
||||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
|
@ -176,8 +189,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent, randomizeSeed);
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues,
|
||||
predictionFieldName, trainingPercent, randomizeSeed);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -191,6 +204,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
&& Objects.equals(eta, that.eta)
|
||||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& Objects.equals(trainingPercent, that.trainingPercent)
|
||||
&& Objects.equals(randomizeSeed, that.randomizeSeed);
|
||||
|
@ -208,6 +222,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
private Double eta;
|
||||
private Integer maximumNumberTrees;
|
||||
private Double featureBagFraction;
|
||||
private Integer numTopFeatureImportanceValues;
|
||||
private String predictionFieldName;
|
||||
private Double trainingPercent;
|
||||
private Long randomizeSeed;
|
||||
|
@ -241,6 +256,11 @@ public class Regression implements DataFrameAnalysis {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
|
||||
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setPredictionFieldName(String predictionFieldName) {
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
return this;
|
||||
|
@ -257,8 +277,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
}
|
||||
|
||||
public Regression build() {
|
||||
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent, randomizeSeed);
|
||||
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
|
||||
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1324,6 +1324,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.setPredictionFieldName("my_dependent_variable_prediction")
|
||||
.setTrainingPercent(80.0)
|
||||
.setRandomizeSeed(42L)
|
||||
.setLambda(1.0)
|
||||
.setGamma(1.0)
|
||||
.setEta(1.0)
|
||||
.setMaximumNumberTrees(10)
|
||||
.setFeatureBagFraction(0.5)
|
||||
.setNumTopFeatureImportanceValues(3)
|
||||
.build())
|
||||
.setDescription("this is a regression")
|
||||
.build();
|
||||
|
@ -1361,6 +1367,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.setTrainingPercent(80.0)
|
||||
.setRandomizeSeed(42L)
|
||||
.setNumTopClasses(1)
|
||||
.setLambda(1.0)
|
||||
.setGamma(1.0)
|
||||
.setEta(1.0)
|
||||
.setMaximumNumberTrees(10)
|
||||
.setFeatureBagFraction(0.5)
|
||||
.setNumTopFeatureImportanceValues(3)
|
||||
.build())
|
||||
.setDescription("this is a classification")
|
||||
.build();
|
||||
|
|
|
@ -2975,10 +2975,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
.setEta(5.5) // <4>
|
||||
.setMaximumNumberTrees(50) // <5>
|
||||
.setFeatureBagFraction(0.4) // <6>
|
||||
.setPredictionFieldName("my_prediction_field_name") // <7>
|
||||
.setTrainingPercent(50.0) // <8>
|
||||
.setRandomizeSeed(1234L) // <9>
|
||||
.setNumTopClasses(1) // <10>
|
||||
.setNumTopFeatureImportanceValues(3) // <7>
|
||||
.setPredictionFieldName("my_prediction_field_name") // <8>
|
||||
.setTrainingPercent(50.0) // <9>
|
||||
.setRandomizeSeed(1234L) // <10>
|
||||
.setNumTopClasses(1) // <11>
|
||||
.build();
|
||||
// end::put-data-frame-analytics-classification
|
||||
|
||||
|
@ -2989,9 +2990,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
.setEta(5.5) // <4>
|
||||
.setMaximumNumberTrees(50) // <5>
|
||||
.setFeatureBagFraction(0.4) // <6>
|
||||
.setPredictionFieldName("my_prediction_field_name") // <7>
|
||||
.setTrainingPercent(50.0) // <8>
|
||||
.setRandomizeSeed(1234L) // <9>
|
||||
.setNumTopFeatureImportanceValues(3) // <7>
|
||||
.setPredictionFieldName("my_prediction_field_name") // <8>
|
||||
.setTrainingPercent(50.0) // <9>
|
||||
.setRandomizeSeed(1234L) // <10>
|
||||
.build();
|
||||
// end::put-data-frame-analytics-regression
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
|||
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
|
||||
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
|
||||
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||
.setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
|
||||
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
||||
|
|
|
@ -32,6 +32,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
|||
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
|
||||
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
|
||||
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||
.setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
|
||||
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
|
||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||
.build();
|
||||
|
|
|
@ -117,10 +117,11 @@ include-tagged::{doc-tests-file}[{api}-classification]
|
|||
<4> The applied shrinkage. A double in [0.001, 1].
|
||||
<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
|
||||
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
||||
<7> The name of the prediction field in the results object.
|
||||
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||
<9> The seed to be used by the random generator that picks which rows are used in training.
|
||||
<10> The number of top classes to be reported in the results. Defaults to 2.
|
||||
<7> If set, feature importance for the top most important features will be computed.
|
||||
<8> The name of the prediction field in the results object.
|
||||
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||
<10> The seed to be used by the random generator that picks which rows are used in training.
|
||||
<11> The number of top classes to be reported in the results. Defaults to 2.
|
||||
|
||||
===== Regression
|
||||
|
||||
|
@ -137,9 +138,10 @@ include-tagged::{doc-tests-file}[{api}-regression]
|
|||
<4> The applied shrinkage. A double in [0.001, 1].
|
||||
<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000].
|
||||
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
|
||||
<7> The name of the prediction field in the results object.
|
||||
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||
<9> The seed to be used by the random generator that picks which rows are used in training.
|
||||
<7> If set, feature importance for the top most important features will be computed.
|
||||
<8> The name of the prediction field in the results object.
|
||||
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||
<10> The seed to be used by the random generator that picks which rows are used in training.
|
||||
|
||||
==== Analyzed fields
|
||||
|
||||
|
|
|
@ -150,6 +150,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name]
|
|||
(Optional, long)
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize-seed]
|
||||
|
||||
`analysis`.`classification`.`num_top_feature_importance_values`::::
|
||||
(Optional, integer)
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values]
|
||||
|
||||
`analysis`.`classification`.`training_percent`::::
|
||||
(Optional, integer)
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent]
|
||||
|
@ -229,6 +233,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=lambda]
|
|||
(Optional, string)
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name]
|
||||
|
||||
`analysis`.`regression`.`num_top_feature_importance_values`::::
|
||||
(Optional, integer)
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values]
|
||||
|
||||
`analysis`.`regression`.`training_percent`::::
|
||||
(Optional, integer)
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent]
|
||||
|
|
|
@ -637,6 +637,14 @@ end::include-model-definition[]
|
|||
tag::indices[]
|
||||
An array of index names. Wildcards are supported. For example:
|
||||
`["it_ops_metrics", "server*"]`.
|
||||
|
||||
tag::num-top-feature-importance-values[]
|
||||
Advanced configuration option. If set, feature importance for the top
|
||||
most important features will be computed. Importance is calculated
|
||||
using the SHAP (SHapley Additive exPlanations) method as described in
|
||||
https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf[Lundberg, S. M., & Lee, S.-I. A Unified Approach to Interpreting Model Predictions. In NeurIPS 2017.].
|
||||
end::num-top-feature-importance-values[]
|
||||
|
||||
+
|
||||
--
|
||||
NOTE: If any indices are in remote clusters then `cluster.remote.connect` must
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -34,6 +35,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
public static final ParseField ETA = new ParseField("eta");
|
||||
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
|
||||
|
||||
static void declareFields(AbstractObjectParser<?, Void> parser) {
|
||||
parser.declareDouble(optionalConstructorArg(), LAMBDA);
|
||||
|
@ -41,6 +43,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
parser.declareDouble(optionalConstructorArg(), ETA);
|
||||
parser.declareInt(optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||
parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
parser.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
|
||||
}
|
||||
|
||||
private final Double lambda;
|
||||
|
@ -48,12 +51,14 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
private final Double eta;
|
||||
private final Integer maximumNumberTrees;
|
||||
private final Double featureBagFraction;
|
||||
private final Integer numTopFeatureImportanceValues;
|
||||
|
||||
public BoostedTreeParams(@Nullable Double lambda,
|
||||
@Nullable Double gamma,
|
||||
@Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees,
|
||||
@Nullable Double featureBagFraction) {
|
||||
@Nullable Double featureBagFraction,
|
||||
@Nullable Integer numTopFeatureImportanceValues) {
|
||||
if (lambda != null && lambda < 0) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
|
||||
}
|
||||
|
@ -69,15 +74,16 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName());
|
||||
}
|
||||
if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative integer",
|
||||
NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
|
||||
}
|
||||
this.lambda = lambda;
|
||||
this.gamma = gamma;
|
||||
this.eta = eta;
|
||||
this.maximumNumberTrees = maximumNumberTrees;
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
}
|
||||
|
||||
public BoostedTreeParams() {
|
||||
this(null, null, null, null, null);
|
||||
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||
}
|
||||
|
||||
BoostedTreeParams(StreamInput in) throws IOException {
|
||||
|
@ -86,6 +92,11 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
eta = in.readOptionalDouble();
|
||||
maximumNumberTrees = in.readOptionalVInt();
|
||||
featureBagFraction = in.readOptionalDouble();
|
||||
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||
numTopFeatureImportanceValues = in.readOptionalInt();
|
||||
} else {
|
||||
numTopFeatureImportanceValues = null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -95,6 +106,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
out.writeOptionalDouble(eta);
|
||||
out.writeOptionalVInt(maximumNumberTrees);
|
||||
out.writeOptionalDouble(featureBagFraction);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||
out.writeOptionalInt(numTopFeatureImportanceValues);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -114,6 +128,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
if (featureBagFraction != null) {
|
||||
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
if (numTopFeatureImportanceValues != null) {
|
||||
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
|
@ -134,6 +151,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
if (featureBagFraction != null) {
|
||||
params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
||||
}
|
||||
if (numTopFeatureImportanceValues != null) {
|
||||
params.put(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
|
@ -146,11 +166,62 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|||
&& Objects.equals(gamma, that.gamma)
|
||||
&& Objects.equals(eta, that.eta)
|
||||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction);
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction);
|
||||
return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues);
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private Double lambda;
|
||||
private Double gamma;
|
||||
private Double eta;
|
||||
private Integer maximumNumberTrees;
|
||||
private Double featureBagFraction;
|
||||
private Integer numTopFeatureImportanceValues;
|
||||
|
||||
private Builder() {}
|
||||
|
||||
public Builder setLambda(Double lambda) {
|
||||
this.lambda = lambda;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setGamma(Double gamma) {
|
||||
this.gamma = gamma;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setEta(Double eta) {
|
||||
this.eta = eta;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setMaximumNumberTrees(Integer maximumNumberTrees) {
|
||||
this.maximumNumberTrees = maximumNumberTrees;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setFeatureBagFraction(Double featureBagFraction) {
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
|
||||
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
||||
return this;
|
||||
}
|
||||
|
||||
public BoostedTreeParams build() {
|
||||
return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,11 +50,11 @@ public class Classification implements DataFrameAnalysis {
|
|||
lenient,
|
||||
a -> new Classification(
|
||||
(String) a[0],
|
||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
|
||||
(String) a[6],
|
||||
(Integer) a[7],
|
||||
(Double) a[8],
|
||||
(Long) a[9]));
|
||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
|
||||
(String) a[7],
|
||||
(Integer) a[8],
|
||||
(Double) a[9],
|
||||
(Long) a[10]));
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
|
@ -114,7 +114,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
}
|
||||
|
||||
public Classification(String dependentVariable) {
|
||||
this(dependentVariable, new BoostedTreeParams(), null, null, null, null);
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null);
|
||||
}
|
||||
|
||||
public Classification(StreamInput in) throws IOException {
|
||||
|
|
|
@ -47,10 +47,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
lenient,
|
||||
a -> new Regression(
|
||||
(String) a[0],
|
||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]),
|
||||
(String) a[6],
|
||||
(Double) a[7],
|
||||
(Long) a[8]));
|
||||
new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
|
||||
(String) a[7],
|
||||
(Double) a[8],
|
||||
(Long) a[9]));
|
||||
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
|
||||
BoostedTreeParams.declareFields(parser);
|
||||
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
|
@ -85,7 +85,7 @@ public class Regression implements DataFrameAnalysis {
|
|||
}
|
||||
|
||||
public Regression(String dependentVariable) {
|
||||
this(dependentVariable, new BoostedTreeParams(), null, null, null);
|
||||
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null);
|
||||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
|
|
|
@ -471,6 +471,9 @@ public class ElasticsearchMappings {
|
|||
.startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName())
|
||||
.field(TYPE, INTEGER)
|
||||
.endObject()
|
||||
.startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
|
||||
.field(TYPE, KEYWORD)
|
||||
.endObject()
|
||||
|
@ -499,6 +502,9 @@ public class ElasticsearchMappings {
|
|||
.startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName())
|
||||
.field(TYPE, INTEGER)
|
||||
.endObject()
|
||||
.startObject(Classification.PREDICTION_FIELD_NAME.getPreferredName())
|
||||
.field(TYPE, KEYWORD)
|
||||
.endObject()
|
||||
|
|
|
@ -323,6 +323,7 @@ public final class ReservedFieldNames {
|
|||
BoostedTreeParams.ETA.getPreferredName(),
|
||||
BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName(),
|
||||
BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName(),
|
||||
BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(),
|
||||
|
||||
ElasticsearchMappings.CONFIG_TYPE,
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedT
|
|||
new ConstructingObjectParser<>(
|
||||
BoostedTreeParams.NAME,
|
||||
true,
|
||||
a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4]));
|
||||
a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4], (Integer) a[5]));
|
||||
BoostedTreeParams.declareFields(objParser);
|
||||
return objParser.apply(parser, null);
|
||||
}
|
||||
|
@ -34,12 +34,14 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedT
|
|||
}
|
||||
|
||||
public static BoostedTreeParams createRandom() {
|
||||
Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
|
||||
Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true);
|
||||
Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true);
|
||||
Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
|
||||
Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
|
||||
return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction);
|
||||
return BoostedTreeParams.builder()
|
||||
.setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
|
||||
.setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
|
||||
.setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true))
|
||||
.setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
|
||||
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
|
||||
.setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -49,57 +51,64 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase<BoostedT
|
|||
|
||||
public void testConstructor_GivenNegativeLambda() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(-0.00001, 0.0, 0.5, 500, 0.3));
|
||||
() -> BoostedTreeParams.builder().setLambda(-0.00001).build());
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNegativeGamma() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, -0.00001, 0.5, 500, 0.3));
|
||||
() -> BoostedTreeParams.builder().setGamma(-0.00001).build());
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenEtaIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 0.0, 500, 0.3));
|
||||
() -> BoostedTreeParams.builder().setEta(0.0).build());
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenEtaIsGreaterThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 1.00001, 500, 0.3));
|
||||
() -> BoostedTreeParams.builder().setEta(1.00001).build());
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenMaximumNumberTreesIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 0, 0.3));
|
||||
() -> BoostedTreeParams.builder().setMaximumNumberTrees(0).build());
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenMaximumNumberTreesIsGreaterThan2k() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 2001, 0.3));
|
||||
() -> BoostedTreeParams.builder().setMaximumNumberTrees(2001).build());
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenFeatureBagFractionIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, -0.00001));
|
||||
() -> BoostedTreeParams.builder().setFeatureBagFraction(-0.00001).build());
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenFeatureBagFractionIsGreaterThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.00001));
|
||||
() -> BoostedTreeParams.builder().setFeatureBagFraction(1.00001).build());
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTopFeatureImportanceValuesIsNegative() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> BoostedTreeParams.builder().setNumTopFeatureImportanceValues(-1).build());
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_feature_importance_values] must be a non-negative integer"));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ import static org.hamcrest.Matchers.nullValue;
|
|||
|
||||
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
||||
|
||||
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
|
||||
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build();
|
||||
|
||||
@Override
|
||||
protected Classification doParseInstance(XContentParser parser) throws IOException {
|
||||
|
|
|
@ -31,7 +31,7 @@ import static org.hamcrest.Matchers.nullValue;
|
|||
|
||||
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
|
||||
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
|
||||
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build();
|
||||
|
||||
@Override
|
||||
protected Regression doParseInstance(XContentParser parser) throws IOException {
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import com.google.common.collect.Ordering;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
|
||||
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
|
||||
|
@ -28,7 +27,6 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
|
@ -86,7 +84,14 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
String predictedClassField = KEYWORD_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
|
||||
new Classification(
|
||||
KEYWORD_FIELD,
|
||||
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -104,6 +109,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
|
||||
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
||||
assertThat(resultsObject.keySet().stream().filter(k -> k.startsWith("feature_importance.")).findAny().isPresent(), is(true));
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
|
@ -178,7 +184,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0, null));
|
||||
new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -413,7 +419,13 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
String firstJobId = "classification_two_jobs_with_same_randomize_seed_1";
|
||||
String firstJobDestIndex = firstJobId + "_dest";
|
||||
|
||||
BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0);
|
||||
BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder()
|
||||
.setLambda(1.0)
|
||||
.setGamma(1.0)
|
||||
.setEta(1.0)
|
||||
.setFeatureBagFraction(1.0)
|
||||
.setMaximumNumberTrees(1)
|
||||
.build();
|
||||
|
||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||
new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null));
|
||||
|
|
|
@ -18,7 +18,6 @@ import org.elasticsearch.search.SearchHit;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.junit.After;
|
||||
|
||||
|
@ -53,7 +52,14 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
||||
indexData(sourceIndex, 300, 50);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
|
||||
new Regression(
|
||||
DEPENDENT_VARIABLE_FIELD,
|
||||
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
|
||||
null,
|
||||
null,
|
||||
null)
|
||||
);
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -78,6 +84,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
||||
assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true));
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
|
@ -141,7 +148,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0, null));
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
|
@ -244,7 +251,13 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
String firstJobId = "regression_two_jobs_with_same_randomize_seed_1";
|
||||
String firstJobDestIndex = firstJobId + "_dest";
|
||||
|
||||
BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0);
|
||||
BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder()
|
||||
.setLambda(1.0)
|
||||
.setGamma(1.0)
|
||||
.setEta(1.0)
|
||||
.setFeatureBagFraction(1.0)
|
||||
.setMaximumNumberTrees(1)
|
||||
.build();
|
||||
|
||||
DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
|
||||
new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));
|
||||
|
|
Loading…
Reference in New Issue