From 1d8cb3c741650b54f9b7cacf514e0af5b25b768b Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 14 Jan 2020 16:46:09 +0200 Subject: [PATCH] =?UTF-8?q?[7.x][ML]=20Add=20num=5Ftop=5Ffeature=5Fimporta?= =?UTF-8?q?nce=5Fvalues=20param=20to=20regression=20and=20classi=E2=80=A6?= =?UTF-8?q?=20(#50914)=20(#50976)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new parameter to regression and classification that enables computation of importance for the top most important features. The computation of the importance is based on SHAP (SHapley Additive exPlanations) method. Backport of #50914 --- .../client/ml/dataframe/Classification.java | 38 ++++++-- .../client/ml/dataframe/Regression.java | 38 ++++++-- .../client/MachineLearningIT.java | 12 +++ .../MlClientDocumentationIT.java | 18 ++-- .../ml/dataframe/ClassificationTests.java | 1 + .../client/ml/dataframe/RegressionTests.java | 1 + .../ml/put-data-frame-analytics.asciidoc | 16 ++-- .../apis/put-dfanalytics.asciidoc | 8 ++ docs/reference/ml/ml-shared.asciidoc | 8 ++ .../dataframe/analyses/BoostedTreeParams.java | 91 +++++++++++++++++-- .../ml/dataframe/analyses/Classification.java | 12 +-- .../ml/dataframe/analyses/Regression.java | 10 +- .../persistence/ElasticsearchMappings.java | 6 ++ .../ml/job/results/ReservedFieldNames.java | 1 + .../analyses/BoostedTreeParamsTests.java | 39 +++++--- .../analyses/ClassificationTests.java | 2 +- .../dataframe/analyses/RegressionTests.java | 2 +- .../ml/integration/ClassificationIT.java | 22 ++++- .../xpack/ml/integration/RegressionIT.java | 21 ++++- 19 files changed, 266 insertions(+), 80 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index 9d384e6d867..02861adc738 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -46,6 +46,7 @@ public class Classification implements DataFrameAnalysis { static final ParseField ETA = new ParseField("eta"); static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); @@ -62,10 +63,11 @@ public class Classification implements DataFrameAnalysis { (Double) a[3], (Integer) a[4], (Double) a[5], - (String) a[6], - (Double) a[7], - (Integer) a[8], - (Long) a[9])); + (Integer) a[6], + (String) a[7], + (Double) a[8], + (Integer) a[9], + (Long) a[10])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -74,6 +76,7 @@ public class Classification implements DataFrameAnalysis { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); @@ -86,13 +89,15 @@ public class Classification implements DataFrameAnalysis { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; private final String predictionFieldName; private final Double trainingPercent; private final Integer numTopClasses; private final Long randomizeSeed; private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, - @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; @@ -100,6 +105,7 @@ public class Classification implements DataFrameAnalysis { this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; this.numTopClasses = numTopClasses; @@ -135,6 +141,10 @@ public class Classification implements DataFrameAnalysis { return featureBagFraction; } + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -170,6 +180,9 @@ public class Classification implements DataFrameAnalysis { if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } @@ -188,8 +201,8 @@ public class Classification implements DataFrameAnalysis { @Override public int hashCode() { - return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed, numTopClasses); + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues, + predictionFieldName, trainingPercent, randomizeSeed, numTopClasses); } @Override @@ -203,6 +216,7 @@ public class Classification implements DataFrameAnalysis { && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) && Objects.equals(randomizeSeed, that.randomizeSeed) @@ -221,6 +235,7 @@ public class Classification implements DataFrameAnalysis { private Double eta; private Integer maximumNumberTrees; private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; private String predictionFieldName; private Double trainingPercent; private Integer numTopClasses; @@ -255,6 +270,11 @@ public class Classification implements DataFrameAnalysis { return this; } + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + public Builder setPredictionFieldName(String predictionFieldName) { this.predictionFieldName = predictionFieldName; return this; @@ -276,8 +296,8 @@ public class Classification implements DataFrameAnalysis { } public Classification build() { - return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, numTopClasses, randomizeSeed); + return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed); } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java index fa55ee40b27..d7e374a2563 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java @@ -46,6 +46,7 @@ public class Regression implements DataFrameAnalysis { static final ParseField ETA = new ParseField("eta"); static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); @@ -61,9 +62,10 @@ public class Regression implements DataFrameAnalysis { (Double) a[3], (Integer) a[4], (Double) a[5], - (String) a[6], - (Double) a[7], - (Long) a[8])); + (Integer) a[6], + (String) a[7], + (Double) a[8], + (Long) a[9])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -72,6 +74,7 @@ public class Regression implements DataFrameAnalysis { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); @@ -83,12 +86,14 @@ public class Regression implements DataFrameAnalysis { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; private final String predictionFieldName; private final Double trainingPercent; private final Long randomizeSeed; - private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, - @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, @Nullable Double trainingPercent, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; @@ -96,6 +101,7 @@ public class Regression implements DataFrameAnalysis { this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; this.randomizeSeed = randomizeSeed; @@ -130,6 +136,10 @@ public class Regression implements DataFrameAnalysis { return featureBagFraction; } + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -161,6 +171,9 @@ public class Regression implements DataFrameAnalysis { if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } @@ -176,8 +189,8 @@ public class Regression implements DataFrameAnalysis { @Override public int hashCode() { - return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed); + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues, + predictionFieldName, trainingPercent, randomizeSeed); } @Override @@ -191,6 +204,7 @@ public class Regression implements DataFrameAnalysis { && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) && Objects.equals(randomizeSeed, that.randomizeSeed); @@ -208,6 +222,7 @@ public class Regression implements DataFrameAnalysis { private Double eta; private Integer maximumNumberTrees; private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; private String predictionFieldName; private Double trainingPercent; private Long randomizeSeed; @@ -241,6 +256,11 @@ public class Regression implements DataFrameAnalysis { return this; } + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + public Builder setPredictionFieldName(String predictionFieldName) { this.predictionFieldName = predictionFieldName; return this; @@ -257,8 +277,8 @@ public class Regression implements DataFrameAnalysis { } public Regression build() { - return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed); + return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 247b726e008..f9ed6f4e259 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1324,6 +1324,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) .setRandomizeSeed(42L) + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setMaximumNumberTrees(10) + .setFeatureBagFraction(0.5) + .setNumTopFeatureImportanceValues(3) .build()) .setDescription("this is a regression") .build(); @@ -1361,6 +1367,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .setTrainingPercent(80.0) .setRandomizeSeed(42L) .setNumTopClasses(1) + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setMaximumNumberTrees(10) + .setFeatureBagFraction(0.5) + .setNumTopFeatureImportanceValues(3) .build()) .setDescription("this is a classification") .build(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 860fe533fd3..142f1f1f660 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -2975,10 +2975,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setEta(5.5) // <4> .setMaximumNumberTrees(50) // <5> .setFeatureBagFraction(0.4) // <6> - .setPredictionFieldName("my_prediction_field_name") // <7> - .setTrainingPercent(50.0) // <8> - .setRandomizeSeed(1234L) // <9> - .setNumTopClasses(1) // <10> + .setNumTopFeatureImportanceValues(3) // <7> + .setPredictionFieldName("my_prediction_field_name") // <8> + .setTrainingPercent(50.0) // <9> + .setRandomizeSeed(1234L) // <10> + .setNumTopClasses(1) // <11> .build(); // end::put-data-frame-analytics-classification @@ -2989,9 +2990,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setEta(5.5) // <4> .setMaximumNumberTrees(50) // <5> .setFeatureBagFraction(0.4) // <6> - .setPredictionFieldName("my_prediction_field_name") // <7> - .setTrainingPercent(50.0) // <8> - .setRandomizeSeed(1234L) // <9> + .setNumTopFeatureImportanceValues(3) // <7> + .setPredictionFieldName("my_prediction_field_name") // <8> + .setTrainingPercent(50.0) // <9> + .setRandomizeSeed(1234L) // <10> .build(); // end::put-data-frame-analytics-regression @@ -3670,7 +3672,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { } { PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); - + // tag::put-trained-model-execute-listener ActionListener listener = new ActionListener() { @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 5ef8fdaef5a..79d78c88888 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -32,6 +32,7 @@ public class ClassificationTests extends AbstractXContentTestCase { .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true)) .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000)) .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) + .setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE)) .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) .build(); diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index 2152eff5c08..4be20113402 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -117,10 +117,11 @@ include-tagged::{doc-tests-file}[{api}-classification] <4> The applied shrinkage. A double in [0.001, 1]. <5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000]. <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. -<7> The name of the prediction field in the results object. -<8> The percentage of training-eligible rows to be used in training. Defaults to 100%. -<9> The seed to be used by the random generator that picks which rows are used in training. -<10> The number of top classes to be reported in the results. Defaults to 2. +<7> If set, feature importance for the top most important features will be computed. +<8> The name of the prediction field in the results object. +<9> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<10> The seed to be used by the random generator that picks which rows are used in training. +<11> The number of top classes to be reported in the results. Defaults to 2. ===== Regression @@ -137,9 +138,10 @@ include-tagged::{doc-tests-file}[{api}-regression] <4> The applied shrinkage. A double in [0.001, 1]. <5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000]. <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. -<7> The name of the prediction field in the results object. -<8> The percentage of training-eligible rows to be used in training. Defaults to 100%. -<9> The seed to be used by the random generator that picks which rows are used in training. +<7> If set, feature importance for the top most important features will be computed. +<8> The name of the prediction field in the results object. +<9> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<10> The seed to be used by the random generator that picks which rows are used in training. ==== Analyzed fields diff --git a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc index 8ecc11e115f..b38b42f3af8 100644 --- a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc @@ -150,6 +150,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name] (Optional, long) include::{docdir}/ml/ml-shared.asciidoc[tag=randomize-seed] +`analysis`.`classification`.`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values] + `analysis`.`classification`.`training_percent`:::: (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent] @@ -229,6 +233,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=lambda] (Optional, string) include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name] +`analysis`.`regression`.`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values] + `analysis`.`regression`.`training_percent`:::: (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 07e7f38d42f..8d6022232e8 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -637,6 +637,14 @@ end::include-model-definition[] tag::indices[] An array of index names. Wildcards are supported. For example: `["it_ops_metrics", "server*"]`. + +tag::num-top-feature-importance-values[] +Advanced configuration option. If set, feature importance for the top +most important features will be computed. Importance is calculated +using the SHAP (SHapley Additive exPlanations) method as described in +https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf[Lundberg, S. M., & Lee, S.-I. A Unified Approach to Interpreting Model Predictions. In NeurIPS 2017.]. +end::num-top-feature-importance-values[] + + -- NOTE: If any indices are in remote clusters then `cluster.remote.connect` must diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java index 0f06b08444f..e0890c21377 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -34,6 +35,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { public static final ParseField ETA = new ParseField("eta"); public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static void declareFields(AbstractObjectParser parser) { parser.declareDouble(optionalConstructorArg(), LAMBDA); @@ -41,6 +43,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { parser.declareDouble(optionalConstructorArg(), ETA); parser.declareInt(optionalConstructorArg(), MAXIMUM_NUMBER_TREES); parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION); + parser.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); } private final Double lambda; @@ -48,12 +51,14 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; public BoostedTreeParams(@Nullable Double lambda, - @Nullable Double gamma, - @Nullable Double eta, - @Nullable Integer maximumNumberTrees, - @Nullable Double featureBagFraction) { + @Nullable Double gamma, + @Nullable Double eta, + @Nullable Integer maximumNumberTrees, + @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues) { if (lambda != null && lambda < 0) { throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName()); } @@ -69,15 +74,16 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName()); } + if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a non-negative integer", + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + } this.lambda = lambda; this.gamma = gamma; this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; - } - - public BoostedTreeParams() { - this(null, null, null, null, null); + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; } BoostedTreeParams(StreamInput in) throws IOException { @@ -86,6 +92,11 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { eta = in.readOptionalDouble(); maximumNumberTrees = in.readOptionalVInt(); featureBagFraction = in.readOptionalDouble(); + if (in.getVersion().onOrAfter(Version.V_7_6_0)) { + numTopFeatureImportanceValues = in.readOptionalInt(); + } else { + numTopFeatureImportanceValues = null; + } } @Override @@ -95,6 +106,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { out.writeOptionalDouble(eta); out.writeOptionalVInt(maximumNumberTrees); out.writeOptionalDouble(featureBagFraction); + if (out.getVersion().onOrAfter(Version.V_7_6_0)) { + out.writeOptionalInt(numTopFeatureImportanceValues); + } } @Override @@ -114,6 +128,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } return builder; } @@ -134,6 +151,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { if (featureBagFraction != null) { params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + params.put(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } return params; } @@ -146,11 +166,62 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { && Objects.equals(gamma, that.gamma) && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) - && Objects.equals(featureBagFraction, that.featureBagFraction); + && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); } @Override public int hashCode() { - return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction); + return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Double lambda; + private Double gamma; + private Double eta; + private Integer maximumNumberTrees; + private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; + + private Builder() {} + + public Builder setLambda(Double lambda) { + this.lambda = lambda; + return this; + } + + public Builder setGamma(Double gamma) { + this.gamma = gamma; + return this; + } + + public Builder setEta(Double eta) { + this.eta = eta; + return this; + } + + public Builder setMaximumNumberTrees(Integer maximumNumberTrees) { + this.maximumNumberTrees = maximumNumberTrees; + return this; + } + + public Builder setFeatureBagFraction(Double featureBagFraction) { + this.featureBagFraction = featureBagFraction; + return this; + } + + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public BoostedTreeParams build() { + return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 5a6cc664edf..24b814d19ed 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -50,11 +50,11 @@ public class Classification implements DataFrameAnalysis { lenient, a -> new Classification( (String) a[0], - new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), - (String) a[6], - (Integer) a[7], - (Double) a[8], - (Long) a[9])); + new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]), + (String) a[7], + (Integer) a[8], + (Double) a[9], + (Long) a[10])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -114,7 +114,7 @@ public class Classification implements DataFrameAnalysis { } public Classification(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null); } public Classification(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index fe292759131..83174a9aebf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -47,10 +47,10 @@ public class Regression implements DataFrameAnalysis { lenient, a -> new Regression( (String) a[0], - new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), - (String) a[6], - (Double) a[7], - (Long) a[8])); + new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]), + (String) a[7], + (Double) a[8], + (Long) a[9])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -85,7 +85,7 @@ public class Regression implements DataFrameAnalysis { } public Regression(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null); } public Regression(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index 46294386109..a90f0d91970 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -471,6 +471,9 @@ public class ElasticsearchMappings { .startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName()) .field(TYPE, DOUBLE) .endObject() + .startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() .startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName()) .field(TYPE, KEYWORD) .endObject() @@ -499,6 +502,9 @@ public class ElasticsearchMappings { .startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName()) .field(TYPE, DOUBLE) .endObject() + .startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() .startObject(Classification.PREDICTION_FIELD_NAME.getPreferredName()) .field(TYPE, KEYWORD) .endObject() diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index d96f57d0681..23075b2b9df 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -323,6 +323,7 @@ public final class ReservedFieldNames { BoostedTreeParams.ETA.getPreferredName(), BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName(), BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName(), + BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), ElasticsearchMappings.CONFIG_TYPE, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java index 145533df407..6f3aff88846 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java @@ -23,7 +23,7 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase( BoostedTreeParams.NAME, true, - a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4])); + a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4], (Integer) a[5])); BoostedTreeParams.declareFields(objParser); return objParser.apply(parser, null); } @@ -34,12 +34,14 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase new BoostedTreeParams(-0.00001, 0.0, 0.5, 500, 0.3)); + () -> BoostedTreeParams.builder().setLambda(-0.00001).build()); assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double")); } public void testConstructor_GivenNegativeGamma() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, -0.00001, 0.5, 500, 0.3)); + () -> BoostedTreeParams.builder().setGamma(-0.00001).build()); assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double")); } public void testConstructor_GivenEtaIsZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.0, 500, 0.3)); + () -> BoostedTreeParams.builder().setEta(0.0).build()); assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); } public void testConstructor_GivenEtaIsGreaterThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 1.00001, 500, 0.3)); + () -> BoostedTreeParams.builder().setEta(1.00001).build()); assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); } public void testConstructor_GivenMaximumNumberTreesIsZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 0, 0.3)); + () -> BoostedTreeParams.builder().setMaximumNumberTrees(0).build()); assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); } public void testConstructor_GivenMaximumNumberTreesIsGreaterThan2k() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 2001, 0.3)); + () -> BoostedTreeParams.builder().setMaximumNumberTrees(2001).build()); assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); } public void testConstructor_GivenFeatureBagFractionIsLessThanZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, -0.00001)); + () -> BoostedTreeParams.builder().setFeatureBagFraction(-0.00001).build()); assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); } public void testConstructor_GivenFeatureBagFractionIsGreaterThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.00001)); + () -> BoostedTreeParams.builder().setFeatureBagFraction(1.00001).build()); assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); } + + public void testConstructor_GivenTopFeatureImportanceValuesIsNegative() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> BoostedTreeParams.builder().setNumTopFeatureImportanceValues(-1).build()); + + assertThat(e.getMessage(), equalTo("[num_top_feature_importance_values] must be a non-negative integer")); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 7a0af05071b..55afb76ef5c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -37,7 +37,7 @@ import static org.hamcrest.Matchers.nullValue; public class ClassificationTests extends AbstractSerializingTestCase { - private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0); + private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build(); @Override protected Classification doParseInstance(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index c123a0553d1..ab9e12650e8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -31,7 +31,7 @@ import static org.hamcrest.Matchers.nullValue; public class RegressionTests extends AbstractSerializingTestCase { - private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0); + private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build(); @Override protected Regression doParseInstance(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 50af362e088..639c4da5df0 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.integration; import com.google.common.collect.Ordering; - import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.admin.indices.get.GetIndexAction; import org.elasticsearch.action.admin.indices.get.GetIndexRequest; @@ -28,7 +27,6 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; @@ -86,7 +84,14 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { String predictedClassField = KEYWORD_FIELD + "_prediction"; indexData(sourceIndex, 300, 50, KEYWORD_FIELD); - DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Classification( + KEYWORD_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null, + null)); registerAnalytics(config); putAnalytics(config); @@ -104,6 +109,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); + assertThat(resultsObject.keySet().stream().filter(k -> k.startsWith("feature_importance.")).findAny().isPresent(), is(true)); } assertProgress(jobId, 100, 100, 100, 100); @@ -178,7 +184,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0, null)); + new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -413,7 +419,13 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { String firstJobId = "classification_two_jobs_with_same_randomize_seed_1"; String firstJobDestIndex = firstJobId + "_dest"; - BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder() + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setFeatureBagFraction(1.0) + .setMaximumNumberTrees(1) + .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 5ecab6f69d4..8b7350d9e13 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -18,7 +18,6 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.junit.After; @@ -53,7 +52,14 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { initialize("regression_single_numeric_feature_and_mixed_data_set"); indexData(sourceIndex, 300, 50); - DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Regression( + DEPENDENT_VARIABLE_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null) + ); registerAnalytics(config); putAnalytics(config); @@ -78,6 +84,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(resultsObject.containsKey("variable_prediction"), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); + assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true)); } assertProgress(jobId, 100, 100, 100, 100); @@ -141,7 +148,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0, null)); + new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -244,7 +251,13 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { String firstJobId = "regression_two_jobs_with_same_randomize_seed_1"; String firstJobDestIndex = firstJobId + "_dest"; - BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder() + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setFeatureBagFraction(1.0) + .setMaximumNumberTrees(1) + .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));