[7.x] Do not copy mapping from dependent variable to prediction field in regression analysis (#51227) (#51288)

This commit is contained in:
Przemysław Witek 2020-01-22 12:36:24 +01:00 committed by GitHub
parent 1009f92b03
commit bfcfcdee33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 179 additions and 101 deletions

View File

@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.FieldAliasMapper;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -28,6 +29,7 @@ import java.util.stream.Stream;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
public class Classification implements DataFrameAnalysis {
@ -248,12 +250,32 @@ public class Classification implements DataFrameAnalysis {
return Collections.singletonMap(dependentVariable, 2L);
}
@SuppressWarnings("unchecked")
@Override
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
return new HashMap<String, String>() {{
put(resultsFieldName + "." + predictionFieldName, dependentVariable);
put(resultsFieldName + ".top_classes.class_name", dependentVariable);
}};
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
if ((dependentVariableMapping instanceof Map) == false) {
return Collections.emptyMap();
}
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
// If the source field is an alias, fetch the concrete field that the alias points to.
if (FieldAliasMapper.CONTENT_TYPE.equals(dependentVariableMappingAsMap.get("type"))) {
String path = (String) dependentVariableMappingAsMap.get(FieldAliasMapper.Names.PATH);
dependentVariableMapping = extractMapping(path, mappingsProperties);
}
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
// Hence, we need to check the "instanceof" condition again.
if ((dependentVariableMapping instanceof Map) == false) {
return Collections.emptyMap();
}
Map<String, Object> additionalProperties = new HashMap<>();
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
return additionalProperties;
}
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
return extractValue(String.join(".properties.", path.split("\\.")), mappingsProperties);
}
@Override

View File

@ -42,15 +42,13 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
Map<String, Long> getFieldCardinalityLimits();
/**
* Returns fields for which the mappings should be copied from source index to destination index.
* Each entry of the returned {@link Map} is of the form:
* key - field path in the destination index
* value - field path in the source index from which the mapping should be taken
* Returns fields for which the mappings should be either predefined or copied from source index to destination index.
*
* @param mappingsProperties mappings.properties portion of the index mappings
* @param resultsFieldName name of the results field under which all the results are stored
* @return {@link Map} containing fields for which the mappings should be copied from source index to destination index
* @return {@link Map} containing fields for which the mappings should be handled explicitly
*/
Map<String, String> getExplicitlyMappedFields(String resultsFieldName);
Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName);
/**
* @return {@code true} if this analysis supports data frame rows with missing values

View File

@ -230,7 +230,7 @@ public class OutlierDetection implements DataFrameAnalysis {
}
@Override
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
return Collections.emptyMap();
}

View File

@ -187,8 +187,10 @@ public class Regression implements DataFrameAnalysis {
}
@Override
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, dependentVariable);
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
// high (over 10M) values of dependent variable.
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double"));
}
@Override

View File

@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
@ -171,8 +172,40 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
}
public void testFieldMappingsToCopyIsNonEmpty() {
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
public void testGetExplicitlyMappedFields() {
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap()));
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap()));
assertThat(
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
is(anEmptyMap()));
assertThat(
new Classification("foo").getExplicitlyMappedFields(
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
"results"),
allOf(
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
assertThat(
new Classification("foo").getExplicitlyMappedFields(
new HashMap<String, Object>() {{
put("foo", new HashMap<String, String>() {{
put("type", "alias");
put("path", "bar");
}});
put("bar", Collections.singletonMap("type", "long"));
}},
"results"),
allOf(
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
assertThat(
new Classification("foo").getExplicitlyMappedFields(
Collections.singletonMap("foo", new HashMap<String, String>() {{
put("type", "alias");
put("path", "missing");
}}),
"results"),
is(anEmptyMap()));
}
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {

View File

@ -92,8 +92,8 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
}
public void testFieldMappingsToCopyIsEmpty() {
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(anEmptyMap()));
public void testGetExplicitlyMappedFields() {
assertThat(createTestInstance().getExplicitlyMappedFields(null, null), is(anEmptyMap()));
}
public void testGetStateDocId() {

View File

@ -43,7 +43,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
return createRandom();
}
public static Regression createRandom() {
private static Regression createRandom() {
String dependentVariableName = randomAlphaOfLength(10);
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
@ -110,8 +110,10 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
}
public void testFieldMappingsToCopyIsNonEmpty() {
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
public void testGetExplicitlyMappedFields() {
assertThat(
new Regression("foo").getExplicitlyMappedFields(null, "results"),
hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
}
public void testGetStateDocId() {

View File

@ -7,8 +7,6 @@ package org.elasticsearch.xpack.ml.integration;
import com.google.common.collect.Ordering;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
@ -42,7 +40,6 @@ import java.util.Map;
import java.util.Set;
import static java.util.stream.Collectors.toList;
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
@ -116,7 +113,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
@ -157,7 +154,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
@ -220,7 +217,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, expectedMappingTypeForPredictedField);
assertMlResultsFieldMappings(destIndex, predictedClassField, expectedMappingTypeForPredictedField);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
@ -308,7 +305,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}
@ -365,7 +362,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}
@ -384,7 +381,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}
@ -403,7 +400,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}
@ -564,15 +561,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
return destDoc;
}
/**
* Wrapper around extractValue that:
* - allows dots (".") in the path elements provided as arguments
* - supports implicit casting to the appropriate type
*/
private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
return (T)extractValue(String.join(".", path), doc);
}
private static <T> void assertTopClasses(Map<String, Object> resultsObject,
int numTopClasses,
String dependentVariable,
@ -656,27 +644,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
}
}
private void assertMlResultsFieldMappings(String predictedClassField, String expectedType) {
Map<String, Object> mappings =
client()
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(destIndex))
.actionGet()
.mappings()
.get(destIndex)
.get("_doc")
.sourceAsMap();
assertThat(
mappings.toString(),
getFieldValue(
mappings,
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
equalTo(expectedType));
assertThat(
mappings.toString(),
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
equalTo(expectedType));
}
private String stateDocId() {
return jobId + "_classification_state#1";
}

View File

@ -5,6 +5,8 @@
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
@ -53,6 +55,7 @@ import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo;
@ -281,4 +284,36 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
.get();
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
}
protected static void assertMlResultsFieldMappings(String index, String predictedClassField, String expectedType) {
Map<String, Object> mappings =
client()
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(index))
.actionGet()
.mappings()
.get(index)
.get("_doc")
.sourceAsMap();
assertThat(
mappings.toString(),
getFieldValue(
mappings,
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
equalTo(expectedType));
if (getFieldValue(mappings, "properties", "ml", "properties", "top_classes") != null) {
assertThat(
mappings.toString(),
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
equalTo(expectedType));
}
}
/**
* Wrapper around extractValue that:
* - allows dots (".") in the path elements provided as arguments
* - supports implicit casting to the appropriate type
*/
protected static <T> T getFieldValue(Map<String, Object> doc, String... path) {
return (T)extractValue(String.join(".", path), doc);
}
}

View File

@ -35,8 +35,10 @@ import static org.hamcrest.Matchers.is;
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String NUMERICAL_FEATURE_FIELD = "feature";
private static final String DISCRETE_NUMERICAL_FEATURE_FIELD = "discrete-feature";
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(10L, 20L, 30L));
private static final List<Double> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0));
private String jobId;
@ -50,6 +52,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
initialize("regression_single_numeric_feature_and_mixed_data_set");
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
indexData(sourceIndex, 300, 50);
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
@ -78,19 +81,24 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
// it seems for this case values can be as far off as 2.0
// double featureValue = (double) destDoc.get(NUMERICAL_FEATURE_FIELD);
// double predictionValue = (double) resultsObject.get("variable_prediction");
// double predictionValue = (double) resultsObject.get(predictedClassField);
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat(resultsObject.containsKey(predictedClassField), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true));
assertThat(
resultsObject.toString(),
resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD)
|| resultsObject.containsKey("feature_importance." + DISCRETE_NUMERICAL_FEATURE_FIELD),
is(true));
}
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
"Estimated memory usage for this analytics to be",
@ -103,6 +111,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
initialize("regression_only_training_data_and_training_percent_is_100");
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
indexData(sourceIndex, 350, 0);
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
@ -119,7 +128,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat(resultsObject.containsKey(predictedClassField), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(true));
}
@ -128,6 +137,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
"Estimated memory usage for this analytics to be",
@ -140,6 +150,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
initialize("regression_only_training_data_and_training_percent_is_50");
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
indexData(sourceIndex, 350, 0);
DataFrameAnalyticsConfig config =
@ -164,7 +175,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat(resultsObject.containsKey(predictedClassField), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
// Let's just assert there's both training and non-training results
if ((boolean) resultsObject.get("is_training")) {
@ -180,6 +191,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
"Estimated memory usage for this analytics to be",
@ -192,6 +204,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testStopAndRestart() throws Exception {
initialize("regression_stop_and_restart");
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
indexData(sourceIndex, 350, 0);
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
@ -233,7 +246,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat(resultsObject.containsKey(predictedClassField), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(true));
}
@ -242,6 +255,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
}
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
@ -289,6 +303,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testDeleteExpiredData_RemovesUnusedState() throws Exception {
initialize("regression_delete_expired_data");
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
indexData(sourceIndex, 100, 0);
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
@ -301,6 +316,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
// Call _delete_expired_data API and check nothing was deleted
assertThat(deleteExpiredData().isDeleted(), is(true));
@ -319,6 +335,31 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(stateIndexSearchResponse.getHits().getTotalHits().value, equalTo(0L));
}
public void testDependentVariableIsLong() throws Exception {
initialize("regression_dependent_variable_is_long");
String predictedClassField = DISCRETE_NUMERICAL_FEATURE_FIELD + "_prediction";
indexData(sourceIndex, 100, 0);
DataFrameAnalyticsConfig config =
buildAnalytics(
jobId,
sourceIndex,
destIndex,
null,
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null));
registerAnalytics(config);
putAnalytics(config);
assertIsStopped(jobId);
assertProgress(jobId, 0, 0, 0, 0);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
assertProgress(jobId, 100, 100, 100, 100);
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
}
private void initialize(String jobId) {
this.jobId = jobId;
this.sourceIndex = jobId + "_source_index";
@ -327,7 +368,10 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
.addMapping("_doc",
NUMERICAL_FEATURE_FIELD, "type=double",
DISCRETE_NUMERICAL_FEATURE_FIELD, "type=long",
DEPENDENT_VARIABLE_FIELD, "type=double")
.get();
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
@ -335,12 +379,15 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
for (int i = 0; i < numTrainingRows; i++) {
List<Object> source = Arrays.asList(
NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()),
DEPENDENT_VARIABLE_FIELD, DEPENDENT_VARIABLE_VALUES.get(i % DEPENDENT_VARIABLE_VALUES.size()));
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
bulkRequestBuilder.add(indexRequest);
}
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
List<Object> source = Arrays.asList(NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()));
List<Object> source = Arrays.asList(
NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()));
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
bulkRequestBuilder.add(indexRequest);
}
@ -363,10 +410,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
}
private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Object> destDoc) {
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
return resultsObject;
return getFieldValue(destDoc, "ml");
}
protected String stateDocId() {

View File

@ -25,7 +25,6 @@ import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.ImmutableOpenMap;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSortConfig;
import org.elasticsearch.index.mapper.FieldAliasMapper;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper;
@ -41,7 +40,6 @@ import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
/**
@ -163,38 +161,15 @@ public final class DataFrameAnalyticsIndex {
return maxValue;
}
@SuppressWarnings("unchecked")
private static Map<String, Object> createAdditionalMappings(DataFrameAnalyticsConfig config, Map<String, Object> mappingsProperties) {
Map<String, Object> properties = new HashMap<>();
Map<String, String> idCopyMapping = new HashMap<>();
idCopyMapping.put("type", KeywordFieldMapper.CONTENT_TYPE);
properties.put(ID_COPY, idCopyMapping);
for (Map.Entry<String, String> entry
: config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) {
String destFieldPath = entry.getKey();
String sourceFieldPath = entry.getValue();
Object sourceFieldMapping = extractMapping(sourceFieldPath, mappingsProperties);
if (sourceFieldMapping instanceof Map) {
Map<String, Object> sourceFieldMappingAsMap = (Map) sourceFieldMapping;
// If the source field is an alias, fetch the concrete field that the alias points to.
if (FieldAliasMapper.CONTENT_TYPE.equals(sourceFieldMappingAsMap.get("type"))) {
String path = (String) sourceFieldMappingAsMap.get(FieldAliasMapper.Names.PATH);
sourceFieldMapping = extractMapping(path, mappingsProperties);
}
}
// We may have updated the value of {@code sourceFieldMapping} in the "if" block above.
// Hence, we need to check the "instanceof" condition again.
if (sourceFieldMapping instanceof Map) {
properties.put(destFieldPath, sourceFieldMapping);
}
}
properties.putAll(config.getAnalysis().getExplicitlyMappedFields(mappingsProperties, config.getDest().getResultsField()));
return properties;
}
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
return extractValue(String.join("." + PROPERTIES + ".", path.split("\\.")), mappingsProperties);
}
private static Map<String, Object> createMetaData(String analyticsId, Clock clock) {
Map<String, Object> metadata = new HashMap<>();
metadata.put(CREATION_DATE_MILLIS, clock.millis());

View File

@ -203,7 +203,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase {
public void testCreateDestinationIndex_Regression() throws IOException {
Map<String, Object> map = testCreateDestinationIndex(new Regression(NUMERICAL_FIELD));
assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("double"));
}
public void testCreateDestinationIndex_Classification() throws IOException {
@ -319,7 +319,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase {
public void testUpdateMappingsToDestIndex_Regression() throws IOException {
Map<String, Object> map = testUpdateMappingsToDestIndex(new Regression(NUMERICAL_FIELD));
assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("double"));
}
public void testUpdateMappingsToDestIndex_Classification() throws IOException {