This commit is contained in:
parent
17358b5af7
commit
42bb8ae525
|
@ -48,34 +48,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||||
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
||||||
|
indexData(sourceIndex, 300, 50);
|
||||||
{ // Index 350 rows, 300 of them being training rows.
|
|
||||||
client().admin().indices().prepareCreate(sourceIndex)
|
|
||||||
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
|
|
||||||
.get();
|
|
||||||
|
|
||||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
|
||||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
||||||
for (int i = 0; i < 300; i++) {
|
|
||||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
|
||||||
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
|
||||||
|
|
||||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
|
||||||
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
|
|
||||||
bulkRequestBuilder.add(indexRequest);
|
|
||||||
}
|
|
||||||
for (int i = 300; i < 350; i++) {
|
|
||||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
|
||||||
|
|
||||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
|
||||||
.source(NUMERICAL_FEATURE_FIELD, field);
|
|
||||||
bulkRequestBuilder.add(indexRequest);
|
|
||||||
}
|
|
||||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
|
||||||
if (bulkResponse.hasFailures()) {
|
|
||||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||||
registerAnalytics(config);
|
registerAnalytics(config);
|
||||||
|
@ -120,23 +93,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
||||||
initialize("regression_only_training_data_and_training_percent_is_100");
|
initialize("regression_only_training_data_and_training_percent_is_100");
|
||||||
|
indexData(sourceIndex, 350, 0);
|
||||||
{ // Index 350 rows, all of them being training rows.
|
|
||||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
|
||||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
||||||
for (int i = 0; i < 350; i++) {
|
|
||||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
|
||||||
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
|
||||||
|
|
||||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
|
||||||
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
|
|
||||||
bulkRequestBuilder.add(indexRequest);
|
|
||||||
}
|
|
||||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
|
||||||
if (bulkResponse.hasFailures()) {
|
|
||||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||||
registerAnalytics(config);
|
registerAnalytics(config);
|
||||||
|
@ -173,23 +130,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
||||||
initialize("regression_only_training_data_and_training_percent_is_50");
|
initialize("regression_only_training_data_and_training_percent_is_50");
|
||||||
|
indexData(sourceIndex, 350, 0);
|
||||||
{ // Index 350 rows, all of them being training rows.
|
|
||||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
|
||||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
||||||
for (int i = 0; i < 350; i++) {
|
|
||||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
|
||||||
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
|
||||||
|
|
||||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
|
||||||
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
|
|
||||||
bulkRequestBuilder.add(indexRequest);
|
|
||||||
}
|
|
||||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
|
||||||
if (bulkResponse.hasFailures()) {
|
|
||||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config =
|
DataFrameAnalyticsConfig config =
|
||||||
buildAnalytics(
|
buildAnalytics(
|
||||||
|
@ -242,21 +183,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
@AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/47612")
|
@AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/47612")
|
||||||
public void testStopAndRestart() throws Exception {
|
public void testStopAndRestart() throws Exception {
|
||||||
initialize("regression_stop_and_restart");
|
initialize("regression_stop_and_restart");
|
||||||
|
indexData(sourceIndex, 350, 0);
|
||||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
|
||||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
||||||
for (int i = 0; i < 350; i++) {
|
|
||||||
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
|
||||||
Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
|
||||||
|
|
||||||
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
|
||||||
.source("feature", field, "variable", value);
|
|
||||||
bulkRequestBuilder.add(indexRequest);
|
|
||||||
}
|
|
||||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
|
||||||
if (bulkResponse.hasFailures()) {
|
|
||||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||||
registerAnalytics(config);
|
registerAnalytics(config);
|
||||||
|
@ -310,6 +237,31 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
this.destIndex = sourceIndex + "_results";
|
this.destIndex = sourceIndex + "_results";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
|
||||||
|
client().admin().indices().prepareCreate(sourceIndex)
|
||||||
|
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
|
||||||
|
.get();
|
||||||
|
|
||||||
|
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||||
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||||
|
for (int i = 0; i < numTrainingRows; i++) {
|
||||||
|
List<Object> source = Arrays.asList(
|
||||||
|
NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
|
||||||
|
DEPENDENT_VARIABLE_FIELD, DEPENDENT_VARIABLE_VALUES.get(i % DEPENDENT_VARIABLE_VALUES.size()));
|
||||||
|
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
||||||
|
bulkRequestBuilder.add(indexRequest);
|
||||||
|
}
|
||||||
|
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
|
||||||
|
List<Object> source = Arrays.asList(NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()));
|
||||||
|
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
||||||
|
bulkRequestBuilder.add(indexRequest);
|
||||||
|
}
|
||||||
|
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||||
|
if (bulkResponse.hasFailures()) {
|
||||||
|
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) {
|
private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) {
|
||||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||||
assertThat(destDocGetResponse.isExists(), is(true));
|
assertThat(destDocGetResponse.isExists(), is(true));
|
||||||
|
|
Loading…
Reference in New Issue