diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index 9d384e6d867..02861adc738 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -46,6 +46,7 @@ public class Classification implements DataFrameAnalysis { static final ParseField ETA = new ParseField("eta"); static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); @@ -62,10 +63,11 @@ public class Classification implements DataFrameAnalysis { (Double) a[3], (Integer) a[4], (Double) a[5], - (String) a[6], - (Double) a[7], - (Integer) a[8], - (Long) a[9])); + (Integer) a[6], + (String) a[7], + (Double) a[8], + (Integer) a[9], + (Long) a[10])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -74,6 +76,7 @@ public class Classification implements DataFrameAnalysis { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); @@ -86,13 +89,15 @@ public class Classification implements DataFrameAnalysis { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; private final String predictionFieldName; private final Double trainingPercent; private final Integer numTopClasses; private final Long randomizeSeed; private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, - @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; @@ -100,6 +105,7 @@ public class Classification implements DataFrameAnalysis { this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; this.numTopClasses = numTopClasses; @@ -135,6 +141,10 @@ public class Classification implements DataFrameAnalysis { return featureBagFraction; } + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -170,6 +180,9 @@ public class Classification implements DataFrameAnalysis { if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } @@ -188,8 +201,8 @@ public class Classification implements DataFrameAnalysis { @Override public int hashCode() { - return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed, numTopClasses); + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues, + predictionFieldName, trainingPercent, randomizeSeed, numTopClasses); } @Override @@ -203,6 +216,7 @@ public class Classification implements DataFrameAnalysis { && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) && Objects.equals(randomizeSeed, that.randomizeSeed) @@ -221,6 +235,7 @@ public class Classification implements DataFrameAnalysis { private Double eta; private Integer maximumNumberTrees; private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; private String predictionFieldName; private Double trainingPercent; private Integer numTopClasses; @@ -255,6 +270,11 @@ public class Classification implements DataFrameAnalysis { return this; } + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + public Builder setPredictionFieldName(String predictionFieldName) { this.predictionFieldName = predictionFieldName; return this; @@ -276,8 +296,8 @@ public class Classification implements DataFrameAnalysis { } public Classification build() { - return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, numTopClasses, randomizeSeed); + return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed); } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java index fa55ee40b27..d7e374a2563 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java @@ -46,6 +46,7 @@ public class Regression implements DataFrameAnalysis { static final ParseField ETA = new ParseField("eta"); static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); @@ -61,9 +62,10 @@ public class Regression implements DataFrameAnalysis { (Double) a[3], (Integer) a[4], (Double) a[5], - (String) a[6], - (Double) a[7], - (Long) a[8])); + (Integer) a[6], + (String) a[7], + (Double) a[8], + (Long) a[9])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -72,6 +74,7 @@ public class Regression implements DataFrameAnalysis { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); @@ -83,12 +86,14 @@ public class Regression implements DataFrameAnalysis { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; private final String predictionFieldName; private final Double trainingPercent; private final Long randomizeSeed; - private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, - @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, @Nullable Double trainingPercent, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; @@ -96,6 +101,7 @@ public class Regression implements DataFrameAnalysis { this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; this.randomizeSeed = randomizeSeed; @@ -130,6 +136,10 @@ public class Regression implements DataFrameAnalysis { return featureBagFraction; } + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -161,6 +171,9 @@ public class Regression implements DataFrameAnalysis { if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } @@ -176,8 +189,8 @@ public class Regression implements DataFrameAnalysis { @Override public int hashCode() { - return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed); + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues, + predictionFieldName, trainingPercent, randomizeSeed); } @Override @@ -191,6 +204,7 @@ public class Regression implements DataFrameAnalysis { && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) && Objects.equals(randomizeSeed, that.randomizeSeed); @@ -208,6 +222,7 @@ public class Regression implements DataFrameAnalysis { private Double eta; private Integer maximumNumberTrees; private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; private String predictionFieldName; private Double trainingPercent; private Long randomizeSeed; @@ -241,6 +256,11 @@ public class Regression implements DataFrameAnalysis { return this; } + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + public Builder setPredictionFieldName(String predictionFieldName) { this.predictionFieldName = predictionFieldName; return this; @@ -257,8 +277,8 @@ public class Regression implements DataFrameAnalysis { } public Regression build() { - return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed); + return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 247b726e008..f9ed6f4e259 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1324,6 +1324,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) .setRandomizeSeed(42L) + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setMaximumNumberTrees(10) + .setFeatureBagFraction(0.5) + .setNumTopFeatureImportanceValues(3) .build()) .setDescription("this is a regression") .build(); @@ -1361,6 +1367,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .setTrainingPercent(80.0) .setRandomizeSeed(42L) .setNumTopClasses(1) + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setMaximumNumberTrees(10) + .setFeatureBagFraction(0.5) + .setNumTopFeatureImportanceValues(3) .build()) .setDescription("this is a classification") .build(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 860fe533fd3..142f1f1f660 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -2975,10 +2975,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setEta(5.5) // <4> .setMaximumNumberTrees(50) // <5> .setFeatureBagFraction(0.4) // <6> - .setPredictionFieldName("my_prediction_field_name") // <7> - .setTrainingPercent(50.0) // <8> - .setRandomizeSeed(1234L) // <9> - .setNumTopClasses(1) // <10> + .setNumTopFeatureImportanceValues(3) // <7> + .setPredictionFieldName("my_prediction_field_name") // <8> + .setTrainingPercent(50.0) // <9> + .setRandomizeSeed(1234L) // <10> + .setNumTopClasses(1) // <11> .build(); // end::put-data-frame-analytics-classification @@ -2989,9 +2990,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setEta(5.5) // <4> .setMaximumNumberTrees(50) // <5> .setFeatureBagFraction(0.4) // <6> - .setPredictionFieldName("my_prediction_field_name") // <7> - .setTrainingPercent(50.0) // <8> - .setRandomizeSeed(1234L) // <9> + .setNumTopFeatureImportanceValues(3) // <7> + .setPredictionFieldName("my_prediction_field_name") // <8> + .setTrainingPercent(50.0) // <9> + .setRandomizeSeed(1234L) // <10> .build(); // end::put-data-frame-analytics-regression @@ -3670,7 +3672,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { } { PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); - + // tag::put-trained-model-execute-listener ActionListener listener = new ActionListener() { @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 5ef8fdaef5a..79d78c88888 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -32,6 +32,7 @@ public class ClassificationTests extends AbstractXContentTestCase { .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true)) .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000)) .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) + .setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE)) .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) .build(); diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index 2152eff5c08..4be20113402 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -117,10 +117,11 @@ include-tagged::{doc-tests-file}[{api}-classification] <4> The applied shrinkage. A double in [0.001, 1]. <5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000]. <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. -<7> The name of the prediction field in the results object. -<8> The percentage of training-eligible rows to be used in training. Defaults to 100%. -<9> The seed to be used by the random generator that picks which rows are used in training. -<10> The number of top classes to be reported in the results. Defaults to 2. +<7> If set, feature importance for the top most important features will be computed. +<8> The name of the prediction field in the results object. +<9> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<10> The seed to be used by the random generator that picks which rows are used in training. +<11> The number of top classes to be reported in the results. Defaults to 2. ===== Regression @@ -137,9 +138,10 @@ include-tagged::{doc-tests-file}[{api}-regression] <4> The applied shrinkage. A double in [0.001, 1]. <5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000]. <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. -<7> The name of the prediction field in the results object. -<8> The percentage of training-eligible rows to be used in training. Defaults to 100%. -<9> The seed to be used by the random generator that picks which rows are used in training. +<7> If set, feature importance for the top most important features will be computed. +<8> The name of the prediction field in the results object. +<9> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<10> The seed to be used by the random generator that picks which rows are used in training. ==== Analyzed fields diff --git a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc index 8ecc11e115f..b38b42f3af8 100644 --- a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc @@ -150,6 +150,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name] (Optional, long) include::{docdir}/ml/ml-shared.asciidoc[tag=randomize-seed] +`analysis`.`classification`.`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values] + `analysis`.`classification`.`training_percent`:::: (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent] @@ -229,6 +233,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=lambda] (Optional, string) include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name] +`analysis`.`regression`.`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values] + `analysis`.`regression`.`training_percent`:::: (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 07e7f38d42f..8d6022232e8 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -637,6 +637,14 @@ end::include-model-definition[] tag::indices[] An array of index names. Wildcards are supported. For example: `["it_ops_metrics", "server*"]`. + +tag::num-top-feature-importance-values[] +Advanced configuration option. If set, feature importance for the top +most important features will be computed. Importance is calculated +using the SHAP (SHapley Additive exPlanations) method as described in +https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf[Lundberg, S. M., & Lee, S.-I. A Unified Approach to Interpreting Model Predictions. In NeurIPS 2017.]. +end::num-top-feature-importance-values[] + + -- NOTE: If any indices are in remote clusters then `cluster.remote.connect` must diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java index 0f06b08444f..e0890c21377 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -34,6 +35,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { public static final ParseField ETA = new ParseField("eta"); public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static void declareFields(AbstractObjectParser parser) { parser.declareDouble(optionalConstructorArg(), LAMBDA); @@ -41,6 +43,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { parser.declareDouble(optionalConstructorArg(), ETA); parser.declareInt(optionalConstructorArg(), MAXIMUM_NUMBER_TREES); parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION); + parser.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); } private final Double lambda; @@ -48,12 +51,14 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; public BoostedTreeParams(@Nullable Double lambda, - @Nullable Double gamma, - @Nullable Double eta, - @Nullable Integer maximumNumberTrees, - @Nullable Double featureBagFraction) { + @Nullable Double gamma, + @Nullable Double eta, + @Nullable Integer maximumNumberTrees, + @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues) { if (lambda != null && lambda < 0) { throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName()); } @@ -69,15 +74,16 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName()); } + if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a non-negative integer", + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + } this.lambda = lambda; this.gamma = gamma; this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; - } - - public BoostedTreeParams() { - this(null, null, null, null, null); + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; } BoostedTreeParams(StreamInput in) throws IOException { @@ -86,6 +92,11 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { eta = in.readOptionalDouble(); maximumNumberTrees = in.readOptionalVInt(); featureBagFraction = in.readOptionalDouble(); + if (in.getVersion().onOrAfter(Version.V_7_6_0)) { + numTopFeatureImportanceValues = in.readOptionalInt(); + } else { + numTopFeatureImportanceValues = null; + } } @Override @@ -95,6 +106,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { out.writeOptionalDouble(eta); out.writeOptionalVInt(maximumNumberTrees); out.writeOptionalDouble(featureBagFraction); + if (out.getVersion().onOrAfter(Version.V_7_6_0)) { + out.writeOptionalInt(numTopFeatureImportanceValues); + } } @Override @@ -114,6 +128,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } return builder; } @@ -134,6 +151,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { if (featureBagFraction != null) { params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + params.put(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } return params; } @@ -146,11 +166,62 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { && Objects.equals(gamma, that.gamma) && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) - && Objects.equals(featureBagFraction, that.featureBagFraction); + && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); } @Override public int hashCode() { - return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction); + return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Double lambda; + private Double gamma; + private Double eta; + private Integer maximumNumberTrees; + private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; + + private Builder() {} + + public Builder setLambda(Double lambda) { + this.lambda = lambda; + return this; + } + + public Builder setGamma(Double gamma) { + this.gamma = gamma; + return this; + } + + public Builder setEta(Double eta) { + this.eta = eta; + return this; + } + + public Builder setMaximumNumberTrees(Integer maximumNumberTrees) { + this.maximumNumberTrees = maximumNumberTrees; + return this; + } + + public Builder setFeatureBagFraction(Double featureBagFraction) { + this.featureBagFraction = featureBagFraction; + return this; + } + + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public BoostedTreeParams build() { + return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 5a6cc664edf..24b814d19ed 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -50,11 +50,11 @@ public class Classification implements DataFrameAnalysis { lenient, a -> new Classification( (String) a[0], - new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), - (String) a[6], - (Integer) a[7], - (Double) a[8], - (Long) a[9])); + new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]), + (String) a[7], + (Integer) a[8], + (Double) a[9], + (Long) a[10])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -114,7 +114,7 @@ public class Classification implements DataFrameAnalysis { } public Classification(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null); } public Classification(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index fe292759131..83174a9aebf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -47,10 +47,10 @@ public class Regression implements DataFrameAnalysis { lenient, a -> new Regression( (String) a[0], - new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), - (String) a[6], - (Double) a[7], - (Long) a[8])); + new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]), + (String) a[7], + (Double) a[8], + (Long) a[9])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -85,7 +85,7 @@ public class Regression implements DataFrameAnalysis { } public Regression(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null); } public Regression(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index 46294386109..a90f0d91970 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -471,6 +471,9 @@ public class ElasticsearchMappings { .startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName()) .field(TYPE, DOUBLE) .endObject() + .startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() .startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName()) .field(TYPE, KEYWORD) .endObject() @@ -499,6 +502,9 @@ public class ElasticsearchMappings { .startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName()) .field(TYPE, DOUBLE) .endObject() + .startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() .startObject(Classification.PREDICTION_FIELD_NAME.getPreferredName()) .field(TYPE, KEYWORD) .endObject() diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index d96f57d0681..23075b2b9df 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -323,6 +323,7 @@ public final class ReservedFieldNames { BoostedTreeParams.ETA.getPreferredName(), BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName(), BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName(), + BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), ElasticsearchMappings.CONFIG_TYPE, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java index 145533df407..6f3aff88846 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java @@ -23,7 +23,7 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase( BoostedTreeParams.NAME, true, - a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4])); + a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4], (Integer) a[5])); BoostedTreeParams.declareFields(objParser); return objParser.apply(parser, null); } @@ -34,12 +34,14 @@ public class BoostedTreeParamsTests extends AbstractSerializingTestCase new BoostedTreeParams(-0.00001, 0.0, 0.5, 500, 0.3)); + () -> BoostedTreeParams.builder().setLambda(-0.00001).build()); assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double")); } public void testConstructor_GivenNegativeGamma() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, -0.00001, 0.5, 500, 0.3)); + () -> BoostedTreeParams.builder().setGamma(-0.00001).build()); assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double")); } public void testConstructor_GivenEtaIsZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.0, 500, 0.3)); + () -> BoostedTreeParams.builder().setEta(0.0).build()); assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); } public void testConstructor_GivenEtaIsGreaterThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 1.00001, 500, 0.3)); + () -> BoostedTreeParams.builder().setEta(1.00001).build()); assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); } public void testConstructor_GivenMaximumNumberTreesIsZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 0, 0.3)); + () -> BoostedTreeParams.builder().setMaximumNumberTrees(0).build()); assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); } public void testConstructor_GivenMaximumNumberTreesIsGreaterThan2k() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 2001, 0.3)); + () -> BoostedTreeParams.builder().setMaximumNumberTrees(2001).build()); assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); } public void testConstructor_GivenFeatureBagFractionIsLessThanZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, -0.00001)); + () -> BoostedTreeParams.builder().setFeatureBagFraction(-0.00001).build()); assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); } public void testConstructor_GivenFeatureBagFractionIsGreaterThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.00001)); + () -> BoostedTreeParams.builder().setFeatureBagFraction(1.00001).build()); assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); } + + public void testConstructor_GivenTopFeatureImportanceValuesIsNegative() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> BoostedTreeParams.builder().setNumTopFeatureImportanceValues(-1).build()); + + assertThat(e.getMessage(), equalTo("[num_top_feature_importance_values] must be a non-negative integer")); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 7a0af05071b..55afb76ef5c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -37,7 +37,7 @@ import static org.hamcrest.Matchers.nullValue; public class ClassificationTests extends AbstractSerializingTestCase { - private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0); + private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build(); @Override protected Classification doParseInstance(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index c123a0553d1..ab9e12650e8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -31,7 +31,7 @@ import static org.hamcrest.Matchers.nullValue; public class RegressionTests extends AbstractSerializingTestCase { - private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0); + private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build(); @Override protected Regression doParseInstance(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 50af362e088..639c4da5df0 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.integration; import com.google.common.collect.Ordering; - import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.admin.indices.get.GetIndexAction; import org.elasticsearch.action.admin.indices.get.GetIndexRequest; @@ -28,7 +27,6 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; @@ -86,7 +84,14 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { String predictedClassField = KEYWORD_FIELD + "_prediction"; indexData(sourceIndex, 300, 50, KEYWORD_FIELD); - DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Classification( + KEYWORD_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null, + null)); registerAnalytics(config); putAnalytics(config); @@ -104,6 +109,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); + assertThat(resultsObject.keySet().stream().filter(k -> k.startsWith("feature_importance.")).findAny().isPresent(), is(true)); } assertProgress(jobId, 100, 100, 100, 100); @@ -178,7 +184,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0, null)); + new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -413,7 +419,13 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { String firstJobId = "classification_two_jobs_with_same_randomize_seed_1"; String firstJobDestIndex = firstJobId + "_dest"; - BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder() + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setFeatureBagFraction(1.0) + .setMaximumNumberTrees(1) + .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 5ecab6f69d4..8b7350d9e13 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -18,7 +18,6 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.junit.After; @@ -53,7 +52,14 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { initialize("regression_single_numeric_feature_and_mixed_data_set"); indexData(sourceIndex, 300, 50); - DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Regression( + DEPENDENT_VARIABLE_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null) + ); registerAnalytics(config); putAnalytics(config); @@ -78,6 +84,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(resultsObject.containsKey("variable_prediction"), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); + assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true)); } assertProgress(jobId, 100, 100, 100, 100); @@ -141,7 +148,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0, null)); + new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -244,7 +251,13 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { String firstJobId = "regression_two_jobs_with_same_randomize_seed_1"; String firstJobDestIndex = firstJobId + "_dest"; - BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder() + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setFeatureBagFraction(1.0) + .setMaximumNumberTrees(1) + .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));