From bfcfcdee33999ef4f9ad74aeec43b26a15bd5ad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Wed, 22 Jan 2020 12:36:24 +0100 Subject: [PATCH] [7.x] Do not copy mapping from dependent variable to prediction field in regression analysis (#51227) (#51288) --- .../ml/dataframe/analyses/Classification.java | 32 +++++++-- .../dataframe/analyses/DataFrameAnalysis.java | 10 ++- .../dataframe/analyses/OutlierDetection.java | 2 +- .../ml/dataframe/analyses/Regression.java | 6 +- .../analyses/ClassificationTests.java | 37 +++++++++- .../analyses/OutlierDetectionTests.java | 4 +- .../dataframe/analyses/RegressionTests.java | 8 ++- .../ml/integration/ClassificationIT.java | 47 ++----------- ...NativeDataFrameAnalyticsIntegTestCase.java | 35 ++++++++++ .../xpack/ml/integration/RegressionIT.java | 68 +++++++++++++++---- .../ml/dataframe/DataFrameAnalyticsIndex.java | 27 +------- .../DataFrameAnalyticsIndexTests.java | 4 +- 12 files changed, 179 insertions(+), 101 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 24b814d19ed..89de6cadd83 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.mapper.FieldAliasMapper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -28,6 +29,7 @@ import java.util.stream.Stream; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue; public class Classification implements DataFrameAnalysis { @@ -248,12 +250,32 @@ public class Classification implements DataFrameAnalysis { return Collections.singletonMap(dependentVariable, 2L); } + @SuppressWarnings("unchecked") @Override - public Map getExplicitlyMappedFields(String resultsFieldName) { - return new HashMap() {{ - put(resultsFieldName + "." + predictionFieldName, dependentVariable); - put(resultsFieldName + ".top_classes.class_name", dependentVariable); - }}; + public Map getExplicitlyMappedFields(Map mappingsProperties, String resultsFieldName) { + Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties); + if ((dependentVariableMapping instanceof Map) == false) { + return Collections.emptyMap(); + } + Map dependentVariableMappingAsMap = (Map) dependentVariableMapping; + // If the source field is an alias, fetch the concrete field that the alias points to. + if (FieldAliasMapper.CONTENT_TYPE.equals(dependentVariableMappingAsMap.get("type"))) { + String path = (String) dependentVariableMappingAsMap.get(FieldAliasMapper.Names.PATH); + dependentVariableMapping = extractMapping(path, mappingsProperties); + } + // We may have updated the value of {@code dependentVariableMapping} in the "if" block above. + // Hence, we need to check the "instanceof" condition again. + if ((dependentVariableMapping instanceof Map) == false) { + return Collections.emptyMap(); + } + Map additionalProperties = new HashMap<>(); + additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping); + additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping); + return additionalProperties; + } + + private static Object extractMapping(String path, Map mappingsProperties) { + return extractValue(String.join(".properties.", path.split("\\.")), mappingsProperties); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index 74cdc5824cb..e79458abe38 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -42,15 +42,13 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { Map getFieldCardinalityLimits(); /** - * Returns fields for which the mappings should be copied from source index to destination index. - * Each entry of the returned {@link Map} is of the form: - * key - field path in the destination index - * value - field path in the source index from which the mapping should be taken + * Returns fields for which the mappings should be either predefined or copied from source index to destination index. * + * @param mappingsProperties mappings.properties portion of the index mappings * @param resultsFieldName name of the results field under which all the results are stored - * @return {@link Map} containing fields for which the mappings should be copied from source index to destination index + * @return {@link Map} containing fields for which the mappings should be handled explicitly */ - Map getExplicitlyMappedFields(String resultsFieldName); + Map getExplicitlyMappedFields(Map mappingsProperties, String resultsFieldName); /** * @return {@code true} if this analysis supports data frame rows with missing values diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 81c46738093..af7d4d79ae3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -230,7 +230,7 @@ public class OutlierDetection implements DataFrameAnalysis { } @Override - public Map getExplicitlyMappedFields(String resultsFieldName) { + public Map getExplicitlyMappedFields(Map mappingsProperties, String resultsFieldName) { return Collections.emptyMap(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 83174a9aebf..996e28c60e9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -187,8 +187,10 @@ public class Regression implements DataFrameAnalysis { } @Override - public Map getExplicitlyMappedFields(String resultsFieldName) { - return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, dependentVariable); + public Map getExplicitlyMappedFields(Map mappingsProperties, String resultsFieldName) { + // Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of + // high (over 10M) values of dependent variable. + return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double")); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 55afb76ef5c..5e1b87ff483 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -25,6 +25,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Set; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; @@ -171,8 +172,40 @@ public class ClassificationTests extends AbstractSerializingTestCase() {{ + put("foo", new HashMap() {{ + put("type", "alias"); + put("path", "bar"); + }}); + put("bar", Collections.singletonMap("type", "long")); + }}, + "results"), + allOf( + hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")), + hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long")))); + assertThat( + new Classification("foo").getExplicitlyMappedFields( + Collections.singletonMap("foo", new HashMap() {{ + put("type", "alias"); + put("path", "missing"); + }}), + "results"), + is(anEmptyMap())); } public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java index 5b7a23b46ff..4ac525e6e48 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java @@ -92,8 +92,8 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase { return createRandom(); } - public static Regression createRandom() { + private static Regression createRandom() { String dependentVariableName = randomAlphaOfLength(10); BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom(); String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10); @@ -110,8 +110,10 @@ public class RegressionTests extends AbstractSerializingTestCase { assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap())); } - public void testFieldMappingsToCopyIsNonEmpty() { - assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap()))); + public void testGetExplicitlyMappedFields() { + assertThat( + new Regression("foo").getExplicitlyMappedFields(null, "results"), + hasEntry("results.foo_prediction", Collections.singletonMap("type", "double"))); } public void testGetStateDocId() { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index a6165ceb357..43bdc91e660 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.ml.integration; import com.google.common.collect.Ordering; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.admin.indices.get.GetIndexAction; -import org.elasticsearch.action.admin.indices.get.GetIndexRequest; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; @@ -42,7 +40,6 @@ import java.util.Map; import java.util.Set; import static java.util.stream.Collectors.toList; -import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; @@ -116,7 +113,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); - assertMlResultsFieldMappings(predictedClassField, "keyword"); + assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword"); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [classification]", "Estimated memory usage for this analytics to be", @@ -157,7 +154,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); - assertMlResultsFieldMappings(predictedClassField, "keyword"); + assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword"); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [classification]", "Estimated memory usage for this analytics to be", @@ -220,7 +217,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); - assertMlResultsFieldMappings(predictedClassField, expectedMappingTypeForPredictedField); + assertMlResultsFieldMappings(destIndex, predictedClassField, expectedMappingTypeForPredictedField); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [classification]", "Estimated memory usage for this analytics to be", @@ -308,7 +305,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); - assertMlResultsFieldMappings(predictedClassField, "keyword"); + assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword"); assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); } @@ -365,7 +362,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); - assertMlResultsFieldMappings(predictedClassField, "keyword"); + assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword"); assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); } @@ -384,7 +381,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); - assertMlResultsFieldMappings(predictedClassField, "keyword"); + assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword"); assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); } @@ -403,7 +400,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); - assertMlResultsFieldMappings(predictedClassField, "keyword"); + assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword"); assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); } @@ -564,15 +561,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { return destDoc; } - /** - * Wrapper around extractValue that: - * - allows dots (".") in the path elements provided as arguments - * - supports implicit casting to the appropriate type - */ - private static T getFieldValue(Map doc, String... path) { - return (T)extractValue(String.join(".", path), doc); - } - private static void assertTopClasses(Map resultsObject, int numTopClasses, String dependentVariable, @@ -656,27 +644,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { } } - private void assertMlResultsFieldMappings(String predictedClassField, String expectedType) { - Map mappings = - client() - .execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(destIndex)) - .actionGet() - .mappings() - .get(destIndex) - .get("_doc") - .sourceAsMap(); - assertThat( - mappings.toString(), - getFieldValue( - mappings, - "properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"), - equalTo(expectedType)); - assertThat( - mappings.toString(), - getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"), - equalTo(expectedType)); - } - private String stateDocId() { return jobId + "_classification_state#1"; } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index d99f58e608e..2c586b34e28 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -5,6 +5,8 @@ */ package org.elasticsearch.xpack.ml.integration; +import org.elasticsearch.action.admin.indices.get.GetIndexAction; +import org.elasticsearch.action.admin.indices.get.GetIndexRequest; import org.elasticsearch.action.admin.indices.refresh.RefreshAction; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; @@ -53,6 +55,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.arrayWithSize; import static org.hamcrest.Matchers.equalTo; @@ -281,4 +284,36 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest .get(); assertThat(searchResponse.getHits().getHits().length, equalTo(1)); } + + protected static void assertMlResultsFieldMappings(String index, String predictedClassField, String expectedType) { + Map mappings = + client() + .execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(index)) + .actionGet() + .mappings() + .get(index) + .get("_doc") + .sourceAsMap(); + assertThat( + mappings.toString(), + getFieldValue( + mappings, + "properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"), + equalTo(expectedType)); + if (getFieldValue(mappings, "properties", "ml", "properties", "top_classes") != null) { + assertThat( + mappings.toString(), + getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"), + equalTo(expectedType)); + } + } + + /** + * Wrapper around extractValue that: + * - allows dots (".") in the path elements provided as arguments + * - supports implicit casting to the appropriate type + */ + protected static T getFieldValue(Map doc, String... path) { + return (T)extractValue(String.join(".", path), doc); + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 8b7350d9e13..cb0147a6de5 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -35,8 +35,10 @@ import static org.hamcrest.Matchers.is; public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String NUMERICAL_FEATURE_FIELD = "feature"; + private static final String DISCRETE_NUMERICAL_FEATURE_FIELD = "discrete-feature"; private static final String DEPENDENT_VARIABLE_FIELD = "variable"; private static final List NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0)); + private static final List DISCRETE_NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(10L, 20L, 30L)); private static final List DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0)); private String jobId; @@ -50,6 +52,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { initialize("regression_single_numeric_feature_and_mixed_data_set"); + String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction"; indexData(sourceIndex, 300, 50); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, @@ -78,19 +81,24 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { // it seems for this case values can be as far off as 2.0 // double featureValue = (double) destDoc.get(NUMERICAL_FEATURE_FIELD); - // double predictionValue = (double) resultsObject.get("variable_prediction"); + // double predictionValue = (double) resultsObject.get(predictedClassField); // assertThat(predictionValue, closeTo(10 * featureValue, 2.0)); - assertThat(resultsObject.containsKey("variable_prediction"), is(true)); + assertThat(resultsObject.containsKey(predictedClassField), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); - assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true)); + assertThat( + resultsObject.toString(), + resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD) + || resultsObject.containsKey("feature_importance." + DISCRETE_NUMERICAL_FEATURE_FIELD), + is(true)); } assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [regression]", "Estimated memory usage for this analytics to be", @@ -103,6 +111,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { initialize("regression_only_training_data_and_training_percent_is_100"); + String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction"; indexData(sourceIndex, 350, 0); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); @@ -119,7 +128,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { for (SearchHit hit : sourceData.getHits()) { Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); - assertThat(resultsObject.containsKey("variable_prediction"), is(true)); + assertThat(resultsObject.containsKey(predictedClassField), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(true)); } @@ -128,6 +137,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [regression]", "Estimated memory usage for this analytics to be", @@ -140,6 +150,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception { initialize("regression_only_training_data_and_training_percent_is_50"); + String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction"; indexData(sourceIndex, 350, 0); DataFrameAnalyticsConfig config = @@ -164,7 +175,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { for (SearchHit hit : sourceData.getHits()) { Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); - assertThat(resultsObject.containsKey("variable_prediction"), is(true)); + assertThat(resultsObject.containsKey(predictedClassField), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); // Let's just assert there's both training and non-training results if ((boolean) resultsObject.get("is_training")) { @@ -180,6 +191,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [regression]", "Estimated memory usage for this analytics to be", @@ -192,6 +204,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testStopAndRestart() throws Exception { initialize("regression_stop_and_restart"); + String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction"; indexData(sourceIndex, 350, 0); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); @@ -233,7 +246,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { for (SearchHit hit : sourceData.getHits()) { Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); - assertThat(resultsObject.containsKey("variable_prediction"), is(true)); + assertThat(resultsObject.containsKey(predictedClassField), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(true)); } @@ -242,6 +255,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); } public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception { @@ -289,6 +303,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testDeleteExpiredData_RemovesUnusedState() throws Exception { initialize("regression_delete_expired_data"); + String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction"; indexData(sourceIndex, 100, 0); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); @@ -301,6 +316,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); // Call _delete_expired_data API and check nothing was deleted assertThat(deleteExpiredData().isDeleted(), is(true)); @@ -319,6 +335,31 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { assertThat(stateIndexSearchResponse.getHits().getTotalHits().value, equalTo(0L)); } + public void testDependentVariableIsLong() throws Exception { + initialize("regression_dependent_variable_is_long"); + String predictedClassField = DISCRETE_NUMERICAL_FEATURE_FIELD + "_prediction"; + indexData(sourceIndex, 100, 0); + + DataFrameAnalyticsConfig config = + buildAnalytics( + jobId, + sourceIndex, + destIndex, + null, + new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null)); + registerAnalytics(config); + putAnalytics(config); + + assertIsStopped(jobId); + assertProgress(jobId, 0, 0, 0, 0); + + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + assertProgress(jobId, 100, 100, 100, 100); + + assertMlResultsFieldMappings(destIndex, predictedClassField, "double"); + } + private void initialize(String jobId) { this.jobId = jobId; this.sourceIndex = jobId + "_source_index"; @@ -327,7 +368,10 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) { client().admin().indices().prepareCreate(sourceIndex) - .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double") + .addMapping("_doc", + NUMERICAL_FEATURE_FIELD, "type=double", + DISCRETE_NUMERICAL_FEATURE_FIELD, "type=long", + DEPENDENT_VARIABLE_FIELD, "type=double") .get(); BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() @@ -335,12 +379,15 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { for (int i = 0; i < numTrainingRows; i++) { List source = Arrays.asList( NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()), + DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()), DEPENDENT_VARIABLE_FIELD, DEPENDENT_VARIABLE_VALUES.get(i % DEPENDENT_VARIABLE_VALUES.size())); IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()); bulkRequestBuilder.add(indexRequest); } for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) { - List source = Arrays.asList(NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size())); + List source = Arrays.asList( + NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()), + DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size())); IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()); bulkRequestBuilder.add(indexRequest); } @@ -363,10 +410,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { } private static Map getMlResultsObjectFromDestDoc(Map destDoc) { - assertThat(destDoc.containsKey("ml"), is(true)); - @SuppressWarnings("unchecked") - Map resultsObject = (Map) destDoc.get("ml"); - return resultsObject; + return getFieldValue(destDoc, "ml"); } protected String stateDocId() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java index b1b97c0b103..f7387aa4da9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java @@ -25,7 +25,6 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSortConfig; -import org.elasticsearch.index.mapper.FieldAliasMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; @@ -41,7 +40,6 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; -import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; /** @@ -163,38 +161,15 @@ public final class DataFrameAnalyticsIndex { return maxValue; } - @SuppressWarnings("unchecked") private static Map createAdditionalMappings(DataFrameAnalyticsConfig config, Map mappingsProperties) { Map properties = new HashMap<>(); Map idCopyMapping = new HashMap<>(); idCopyMapping.put("type", KeywordFieldMapper.CONTENT_TYPE); properties.put(ID_COPY, idCopyMapping); - for (Map.Entry entry - : config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) { - String destFieldPath = entry.getKey(); - String sourceFieldPath = entry.getValue(); - Object sourceFieldMapping = extractMapping(sourceFieldPath, mappingsProperties); - if (sourceFieldMapping instanceof Map) { - Map sourceFieldMappingAsMap = (Map) sourceFieldMapping; - // If the source field is an alias, fetch the concrete field that the alias points to. - if (FieldAliasMapper.CONTENT_TYPE.equals(sourceFieldMappingAsMap.get("type"))) { - String path = (String) sourceFieldMappingAsMap.get(FieldAliasMapper.Names.PATH); - sourceFieldMapping = extractMapping(path, mappingsProperties); - } - } - // We may have updated the value of {@code sourceFieldMapping} in the "if" block above. - // Hence, we need to check the "instanceof" condition again. - if (sourceFieldMapping instanceof Map) { - properties.put(destFieldPath, sourceFieldMapping); - } - } + properties.putAll(config.getAnalysis().getExplicitlyMappedFields(mappingsProperties, config.getDest().getResultsField())); return properties; } - private static Object extractMapping(String path, Map mappingsProperties) { - return extractValue(String.join("." + PROPERTIES + ".", path.split("\\.")), mappingsProperties); - } - private static Map createMetaData(String analyticsId, Clock clock) { Map metadata = new HashMap<>(); metadata.put(CREATION_DATE_MILLIS, clock.millis()); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java index c2564842c8b..785dbe869d8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java @@ -203,7 +203,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase { public void testCreateDestinationIndex_Regression() throws IOException { Map map = testCreateDestinationIndex(new Regression(NUMERICAL_FIELD)); - assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("double")); } public void testCreateDestinationIndex_Classification() throws IOException { @@ -319,7 +319,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase { public void testUpdateMappingsToDestIndex_Regression() throws IOException { Map map = testUpdateMappingsToDestIndex(new Regression(NUMERICAL_FIELD)); - assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("integer")); + assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("double")); } public void testUpdateMappingsToDestIndex_Classification() throws IOException {