diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 0341a634c9a..1d9f84471a6 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -48,34 +48,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { initialize("regression_single_numeric_feature_and_mixed_data_set"); - - { // Index 350 rows, 300 of them being training rows. - client().admin().indices().prepareCreate(sourceIndex) - .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double") - .get(); - - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - for (int i = 0; i < 300; i++) { - Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); - Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3); - - IndexRequest indexRequest = new IndexRequest(sourceIndex) - .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value); - bulkRequestBuilder.add(indexRequest); - } - for (int i = 300; i < 350; i++) { - Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); - - IndexRequest indexRequest = new IndexRequest(sourceIndex) - .source(NUMERICAL_FEATURE_FIELD, field); - bulkRequestBuilder.add(indexRequest); - } - BulkResponse bulkResponse = bulkRequestBuilder.get(); - if (bulkResponse.hasFailures()) { - fail("Failed to index data: " + bulkResponse.buildFailureMessage()); - } - } + indexData(sourceIndex, 300, 50); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); registerAnalytics(config); @@ -120,23 +93,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { initialize("regression_only_training_data_and_training_percent_is_100"); - - { // Index 350 rows, all of them being training rows. - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - for (int i = 0; i < 350; i++) { - Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); - Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3); - - IndexRequest indexRequest = new IndexRequest(sourceIndex) - .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value); - bulkRequestBuilder.add(indexRequest); - } - BulkResponse bulkResponse = bulkRequestBuilder.get(); - if (bulkResponse.hasFailures()) { - fail("Failed to index data: " + bulkResponse.buildFailureMessage()); - } - } + indexData(sourceIndex, 350, 0); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); registerAnalytics(config); @@ -173,23 +130,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception { initialize("regression_only_training_data_and_training_percent_is_50"); - - { // Index 350 rows, all of them being training rows. - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - for (int i = 0; i < 350; i++) { - Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); - Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3); - - IndexRequest indexRequest = new IndexRequest(sourceIndex) - .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value); - bulkRequestBuilder.add(indexRequest); - } - BulkResponse bulkResponse = bulkRequestBuilder.get(); - if (bulkResponse.hasFailures()) { - fail("Failed to index data: " + bulkResponse.buildFailureMessage()); - } - } + indexData(sourceIndex, 350, 0); DataFrameAnalyticsConfig config = buildAnalytics( @@ -242,21 +183,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { @AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/47612") public void testStopAndRestart() throws Exception { initialize("regression_stop_and_restart"); - - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - for (int i = 0; i < 350; i++) { - Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); - Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3); - - IndexRequest indexRequest = new IndexRequest(sourceIndex) - .source("feature", field, "variable", value); - bulkRequestBuilder.add(indexRequest); - } - BulkResponse bulkResponse = bulkRequestBuilder.get(); - if (bulkResponse.hasFailures()) { - fail("Failed to index data: " + bulkResponse.buildFailureMessage()); - } + indexData(sourceIndex, 350, 0); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); registerAnalytics(config); @@ -310,6 +237,31 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { this.destIndex = sourceIndex + "_results"; } + private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) { + client().admin().indices().prepareCreate(sourceIndex) + .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < numTrainingRows; i++) { + List source = Arrays.asList( + NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()), + DEPENDENT_VARIABLE_FIELD, DEPENDENT_VARIABLE_VALUES.get(i % DEPENDENT_VARIABLE_VALUES.size())); + IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()); + bulkRequestBuilder.add(indexRequest); + } + for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) { + List source = Arrays.asList(NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size())); + IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + } + private static Map getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) { GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); assertThat(destDocGetResponse.isExists(), is(true));