diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index fef3cb0fb33..a7eefe199d4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -241,6 +241,7 @@ public class Classification implements DataFrameAnalysis { params.put(PREDICTION_FIELD_TYPE, predictionFieldType); } params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable)); + params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent); return params; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 0d40ef6ba35..824d4f95a17 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -163,6 +163,7 @@ public class Regression implements DataFrameAnalysis { if (predictionFieldName != null) { params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } + params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent); return params; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index d7285fb8600..14cbf338687 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -17,7 +17,6 @@ import org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; -import org.hamcrest.Matchers; import java.io.IOException; import java.util.Collections; @@ -201,31 +200,43 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase>allOf( - hasEntry("dependent_variable", "foo"), - hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), - hasEntry("num_top_classes", 2), - hasEntry("prediction_field_name", "foo_prediction"), - hasEntry("prediction_field_type", "bool"), - hasEntry("num_classes", 10L))); + equalTo( + org.elasticsearch.common.collect.Map.of( + "dependent_variable", "foo", + "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, + "num_top_classes", 2, + "prediction_field_name", "foo_prediction", + "prediction_field_type", "bool", + "num_classes", 10L, + "training_percent", 100.0))); assertThat( new Classification("bar").getParams(fieldInfo), - Matchers.>allOf( - hasEntry("dependent_variable", "bar"), - hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), - hasEntry("num_top_classes", 2), - hasEntry("prediction_field_name", "bar_prediction"), - hasEntry("prediction_field_type", "int"), - hasEntry("num_classes", 20L))); + equalTo( + org.elasticsearch.common.collect.Map.of( + "dependent_variable", "bar", + "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, + "num_top_classes", 2, + "prediction_field_name", "bar_prediction", + "prediction_field_type", "int", + "num_classes", 20L, + "training_percent", 100.0))); assertThat( - new Classification("baz").getParams(fieldInfo), - Matchers.>allOf( - hasEntry("dependent_variable", "baz"), - hasEntry("class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL), - hasEntry("num_top_classes", 2), - hasEntry("prediction_field_name", "baz_prediction"), - hasEntry("prediction_field_type", "string"), - hasEntry("num_classes", 30L))); + new Classification("baz", + BoostedTreeParams.builder().build() , + null, + null, + null, + 50.0, + null).getParams(fieldInfo), + equalTo( + org.elasticsearch.common.collect.Map.of( + "dependent_variable", "baz", + "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, + "num_top_classes", 2, + "prediction_field_name", "baz_prediction", + "prediction_field_type", "string", + "num_classes", 30L, + "training_percent", 50.0))); } public void testRequiredFieldsIsNonEmpty() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index 68ff2f71fb9..c7fbcadad45 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Map; import java.util.Collections; -import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -132,7 +131,20 @@ public class RegressionTests extends AbstractBWCSerializationTestCase NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0)); - private static final List DISCRETE_NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(10L, 20L, 30L)); - private static final List DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0)); + static final String DEPENDENT_VARIABLE_FIELD = "variable"; + private static final List NUMERICAL_FEATURE_VALUES = org.elasticsearch.common.collect.List.of(1.0, 2.0, 3.0); + private static final List DISCRETE_NUMERICAL_FEATURE_VALUES = org.elasticsearch.common.collect.List.of(10L, 20L, 30L); + private static final List DEPENDENT_VARIABLE_VALUES = org.elasticsearch.common.collect.List.of(10.0, 20.0, 30.0); private String jobId; private String sourceIndex; @@ -399,7 +398,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { this.destIndex = sourceIndex + "_results"; } - private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) { + static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) { client().admin().indices().prepareCreate(sourceIndex) .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double",