[7.x] Do not copy mapping from dependent variable to prediction field in regression analysis (#51227) (#51288)
This commit is contained in:
parent
1009f92b03
commit
bfcfcdee33
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.mapper.FieldAliasMapper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -28,6 +29,7 @@ import java.util.stream.Stream;
|
|||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
|
||||
|
||||
public class Classification implements DataFrameAnalysis {
|
||||
|
||||
|
@ -248,12 +250,32 @@ public class Classification implements DataFrameAnalysis {
|
|||
return Collections.singletonMap(dependentVariable, 2L);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
|
||||
return new HashMap<String, String>() {{
|
||||
put(resultsFieldName + "." + predictionFieldName, dependentVariable);
|
||||
put(resultsFieldName + ".top_classes.class_name", dependentVariable);
|
||||
}};
|
||||
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
||||
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
|
||||
if ((dependentVariableMapping instanceof Map) == false) {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
|
||||
// If the source field is an alias, fetch the concrete field that the alias points to.
|
||||
if (FieldAliasMapper.CONTENT_TYPE.equals(dependentVariableMappingAsMap.get("type"))) {
|
||||
String path = (String) dependentVariableMappingAsMap.get(FieldAliasMapper.Names.PATH);
|
||||
dependentVariableMapping = extractMapping(path, mappingsProperties);
|
||||
}
|
||||
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
|
||||
// Hence, we need to check the "instanceof" condition again.
|
||||
if ((dependentVariableMapping instanceof Map) == false) {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
Map<String, Object> additionalProperties = new HashMap<>();
|
||||
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
|
||||
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
|
||||
return additionalProperties;
|
||||
}
|
||||
|
||||
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
|
||||
return extractValue(String.join(".properties.", path.split("\\.")), mappingsProperties);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -42,15 +42,13 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
|
|||
Map<String, Long> getFieldCardinalityLimits();
|
||||
|
||||
/**
|
||||
* Returns fields for which the mappings should be copied from source index to destination index.
|
||||
* Each entry of the returned {@link Map} is of the form:
|
||||
* key - field path in the destination index
|
||||
* value - field path in the source index from which the mapping should be taken
|
||||
* Returns fields for which the mappings should be either predefined or copied from source index to destination index.
|
||||
*
|
||||
* @param mappingsProperties mappings.properties portion of the index mappings
|
||||
* @param resultsFieldName name of the results field under which all the results are stored
|
||||
* @return {@link Map} containing fields for which the mappings should be copied from source index to destination index
|
||||
* @return {@link Map} containing fields for which the mappings should be handled explicitly
|
||||
*/
|
||||
Map<String, String> getExplicitlyMappedFields(String resultsFieldName);
|
||||
Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName);
|
||||
|
||||
/**
|
||||
* @return {@code true} if this analysis supports data frame rows with missing values
|
||||
|
|
|
@ -230,7 +230,7 @@ public class OutlierDetection implements DataFrameAnalysis {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
|
||||
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
|
|
|
@ -187,8 +187,10 @@ public class Regression implements DataFrameAnalysis {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
|
||||
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, dependentVariable);
|
||||
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
||||
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
|
||||
// high (over 10M) values of dependent variable.
|
||||
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double"));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.anEmptyMap;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
|
@ -171,8 +172,40 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
|
||||
}
|
||||
|
||||
public void testFieldMappingsToCopyIsNonEmpty() {
|
||||
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
|
||||
public void testGetExplicitlyMappedFields() {
|
||||
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap()));
|
||||
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap()));
|
||||
assertThat(
|
||||
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
|
||||
is(anEmptyMap()));
|
||||
assertThat(
|
||||
new Classification("foo").getExplicitlyMappedFields(
|
||||
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
|
||||
"results"),
|
||||
allOf(
|
||||
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
|
||||
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
|
||||
assertThat(
|
||||
new Classification("foo").getExplicitlyMappedFields(
|
||||
new HashMap<String, Object>() {{
|
||||
put("foo", new HashMap<String, String>() {{
|
||||
put("type", "alias");
|
||||
put("path", "bar");
|
||||
}});
|
||||
put("bar", Collections.singletonMap("type", "long"));
|
||||
}},
|
||||
"results"),
|
||||
allOf(
|
||||
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
|
||||
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
|
||||
assertThat(
|
||||
new Classification("foo").getExplicitlyMappedFields(
|
||||
Collections.singletonMap("foo", new HashMap<String, String>() {{
|
||||
put("type", "alias");
|
||||
put("path", "missing");
|
||||
}}),
|
||||
"results"),
|
||||
is(anEmptyMap()));
|
||||
}
|
||||
|
||||
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
|
||||
|
|
|
@ -92,8 +92,8 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
|
|||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
|
||||
}
|
||||
|
||||
public void testFieldMappingsToCopyIsEmpty() {
|
||||
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(anEmptyMap()));
|
||||
public void testGetExplicitlyMappedFields() {
|
||||
assertThat(createTestInstance().getExplicitlyMappedFields(null, null), is(anEmptyMap()));
|
||||
}
|
||||
|
||||
public void testGetStateDocId() {
|
||||
|
|
|
@ -43,7 +43,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
return createRandom();
|
||||
}
|
||||
|
||||
public static Regression createRandom() {
|
||||
private static Regression createRandom() {
|
||||
String dependentVariableName = randomAlphaOfLength(10);
|
||||
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||
|
@ -110,8 +110,10 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
|
||||
}
|
||||
|
||||
public void testFieldMappingsToCopyIsNonEmpty() {
|
||||
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
|
||||
public void testGetExplicitlyMappedFields() {
|
||||
assertThat(
|
||||
new Regression("foo").getExplicitlyMappedFields(null, "results"),
|
||||
hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
|
||||
}
|
||||
|
||||
public void testGetStateDocId() {
|
||||
|
|
|
@ -7,8 +7,6 @@ package org.elasticsearch.xpack.ml.integration;
|
|||
|
||||
import com.google.common.collect.Ordering;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
|
||||
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
|
||||
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
|
@ -42,7 +40,6 @@ import java.util.Map;
|
|||
import java.util.Set;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.anyOf;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
@ -116,7 +113,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
|
@ -157,7 +154,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
|
@ -220,7 +217,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(predictedClassField, expectedMappingTypeForPredictedField);
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, expectedMappingTypeForPredictedField);
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
|
@ -308,7 +305,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||
}
|
||||
|
||||
|
@ -365,7 +362,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||
assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||
}
|
||||
|
||||
|
@ -384,7 +381,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||
assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||
}
|
||||
|
||||
|
@ -403,7 +400,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||
assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||
}
|
||||
|
||||
|
@ -564,15 +561,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
return destDoc;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrapper around extractValue that:
|
||||
* - allows dots (".") in the path elements provided as arguments
|
||||
* - supports implicit casting to the appropriate type
|
||||
*/
|
||||
private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
|
||||
return (T)extractValue(String.join(".", path), doc);
|
||||
}
|
||||
|
||||
private static <T> void assertTopClasses(Map<String, Object> resultsObject,
|
||||
int numTopClasses,
|
||||
String dependentVariable,
|
||||
|
@ -656,27 +644,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
private void assertMlResultsFieldMappings(String predictedClassField, String expectedType) {
|
||||
Map<String, Object> mappings =
|
||||
client()
|
||||
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(destIndex))
|
||||
.actionGet()
|
||||
.mappings()
|
||||
.get(destIndex)
|
||||
.get("_doc")
|
||||
.sourceAsMap();
|
||||
assertThat(
|
||||
mappings.toString(),
|
||||
getFieldValue(
|
||||
mappings,
|
||||
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
|
||||
equalTo(expectedType));
|
||||
assertThat(
|
||||
mappings.toString(),
|
||||
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
|
||||
equalTo(expectedType));
|
||||
}
|
||||
|
||||
private String stateDocId() {
|
||||
return jobId + "_classification_state#1";
|
||||
}
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
|
||||
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
|
||||
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
|
||||
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
||||
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
|
||||
|
@ -53,6 +55,7 @@ import java.util.Set;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
|
||||
import static org.hamcrest.Matchers.anyOf;
|
||||
import static org.hamcrest.Matchers.arrayWithSize;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
@ -281,4 +284,36 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
|
|||
.get();
|
||||
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
|
||||
}
|
||||
|
||||
protected static void assertMlResultsFieldMappings(String index, String predictedClassField, String expectedType) {
|
||||
Map<String, Object> mappings =
|
||||
client()
|
||||
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(index))
|
||||
.actionGet()
|
||||
.mappings()
|
||||
.get(index)
|
||||
.get("_doc")
|
||||
.sourceAsMap();
|
||||
assertThat(
|
||||
mappings.toString(),
|
||||
getFieldValue(
|
||||
mappings,
|
||||
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
|
||||
equalTo(expectedType));
|
||||
if (getFieldValue(mappings, "properties", "ml", "properties", "top_classes") != null) {
|
||||
assertThat(
|
||||
mappings.toString(),
|
||||
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
|
||||
equalTo(expectedType));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrapper around extractValue that:
|
||||
* - allows dots (".") in the path elements provided as arguments
|
||||
* - supports implicit casting to the appropriate type
|
||||
*/
|
||||
protected static <T> T getFieldValue(Map<String, Object> doc, String... path) {
|
||||
return (T)extractValue(String.join(".", path), doc);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,8 +35,10 @@ import static org.hamcrest.Matchers.is;
|
|||
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
private static final String NUMERICAL_FEATURE_FIELD = "feature";
|
||||
private static final String DISCRETE_NUMERICAL_FEATURE_FIELD = "discrete-feature";
|
||||
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
|
||||
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
|
||||
private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(10L, 20L, 30L));
|
||||
private static final List<Double> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0));
|
||||
|
||||
private String jobId;
|
||||
|
@ -50,6 +52,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
||||
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 300, 50);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
|
||||
|
@ -78,19 +81,24 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
// it seems for this case values can be as far off as 2.0
|
||||
|
||||
// double featureValue = (double) destDoc.get(NUMERICAL_FEATURE_FIELD);
|
||||
// double predictionValue = (double) resultsObject.get("variable_prediction");
|
||||
// double predictionValue = (double) resultsObject.get(predictedClassField);
|
||||
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
||||
assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true));
|
||||
assertThat(
|
||||
resultsObject.toString(),
|
||||
resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD)
|
||||
|| resultsObject.containsKey("feature_importance." + DISCRETE_NUMERICAL_FEATURE_FIELD),
|
||||
is(true));
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [regression]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
|
@ -103,6 +111,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
||||
initialize("regression_only_training_data_and_training_percent_is_100");
|
||||
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 350, 0);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||
|
@ -119,7 +128,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(true));
|
||||
}
|
||||
|
@ -128,6 +137,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [regression]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
|
@ -140,6 +150,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
||||
initialize("regression_only_training_data_and_training_percent_is_50");
|
||||
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 350, 0);
|
||||
|
||||
DataFrameAnalyticsConfig config =
|
||||
|
@ -164,7 +175,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
// Let's just assert there's both training and non-training results
|
||||
if ((boolean) resultsObject.get("is_training")) {
|
||||
|
@ -180,6 +191,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [regression]",
|
||||
"Estimated memory usage for this analytics to be",
|
||||
|
@ -192,6 +204,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
public void testStopAndRestart() throws Exception {
|
||||
initialize("regression_stop_and_restart");
|
||||
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 350, 0);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||
|
@ -233,7 +246,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(true));
|
||||
}
|
||||
|
@ -242,6 +255,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||
}
|
||||
|
||||
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
|
||||
|
@ -289,6 +303,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
public void testDeleteExpiredData_RemovesUnusedState() throws Exception {
|
||||
initialize("regression_delete_expired_data");
|
||||
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 100, 0);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||
|
@ -301,6 +316,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||
|
||||
// Call _delete_expired_data API and check nothing was deleted
|
||||
assertThat(deleteExpiredData().isDeleted(), is(true));
|
||||
|
@ -319,6 +335,31 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(stateIndexSearchResponse.getHits().getTotalHits().value, equalTo(0L));
|
||||
}
|
||||
|
||||
public void testDependentVariableIsLong() throws Exception {
|
||||
initialize("regression_dependent_variable_is_long");
|
||||
String predictedClassField = DISCRETE_NUMERICAL_FEATURE_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 100, 0);
|
||||
|
||||
DataFrameAnalyticsConfig config =
|
||||
buildAnalytics(
|
||||
jobId,
|
||||
sourceIndex,
|
||||
destIndex,
|
||||
null,
|
||||
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
assertIsStopped(jobId);
|
||||
assertProgress(jobId, 0, 0, 0, 0);
|
||||
|
||||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
|
||||
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||
}
|
||||
|
||||
private void initialize(String jobId) {
|
||||
this.jobId = jobId;
|
||||
this.sourceIndex = jobId + "_source_index";
|
||||
|
@ -327,7 +368,10 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
|
||||
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
|
||||
client().admin().indices().prepareCreate(sourceIndex)
|
||||
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
|
||||
.addMapping("_doc",
|
||||
NUMERICAL_FEATURE_FIELD, "type=double",
|
||||
DISCRETE_NUMERICAL_FEATURE_FIELD, "type=long",
|
||||
DEPENDENT_VARIABLE_FIELD, "type=double")
|
||||
.get();
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
|
@ -335,12 +379,15 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
for (int i = 0; i < numTrainingRows; i++) {
|
||||
List<Object> source = Arrays.asList(
|
||||
NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
|
||||
DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()),
|
||||
DEPENDENT_VARIABLE_FIELD, DEPENDENT_VARIABLE_VALUES.get(i % DEPENDENT_VARIABLE_VALUES.size()));
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
|
||||
List<Object> source = Arrays.asList(NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()));
|
||||
List<Object> source = Arrays.asList(
|
||||
NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
|
||||
DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()));
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
|
@ -363,10 +410,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
}
|
||||
|
||||
private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Object> destDoc) {
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
return resultsObject;
|
||||
return getFieldValue(destDoc, "ml");
|
||||
}
|
||||
|
||||
protected String stateDocId() {
|
||||
|
|
|
@ -25,7 +25,6 @@ import org.elasticsearch.common.Nullable;
|
|||
import org.elasticsearch.common.collect.ImmutableOpenMap;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.index.IndexSortConfig;
|
||||
import org.elasticsearch.index.mapper.FieldAliasMapper;
|
||||
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
|
@ -41,7 +40,6 @@ import java.util.Map;
|
|||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
||||
|
||||
/**
|
||||
|
@ -163,38 +161,15 @@ public final class DataFrameAnalyticsIndex {
|
|||
return maxValue;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static Map<String, Object> createAdditionalMappings(DataFrameAnalyticsConfig config, Map<String, Object> mappingsProperties) {
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
Map<String, String> idCopyMapping = new HashMap<>();
|
||||
idCopyMapping.put("type", KeywordFieldMapper.CONTENT_TYPE);
|
||||
properties.put(ID_COPY, idCopyMapping);
|
||||
for (Map.Entry<String, String> entry
|
||||
: config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) {
|
||||
String destFieldPath = entry.getKey();
|
||||
String sourceFieldPath = entry.getValue();
|
||||
Object sourceFieldMapping = extractMapping(sourceFieldPath, mappingsProperties);
|
||||
if (sourceFieldMapping instanceof Map) {
|
||||
Map<String, Object> sourceFieldMappingAsMap = (Map) sourceFieldMapping;
|
||||
// If the source field is an alias, fetch the concrete field that the alias points to.
|
||||
if (FieldAliasMapper.CONTENT_TYPE.equals(sourceFieldMappingAsMap.get("type"))) {
|
||||
String path = (String) sourceFieldMappingAsMap.get(FieldAliasMapper.Names.PATH);
|
||||
sourceFieldMapping = extractMapping(path, mappingsProperties);
|
||||
}
|
||||
}
|
||||
// We may have updated the value of {@code sourceFieldMapping} in the "if" block above.
|
||||
// Hence, we need to check the "instanceof" condition again.
|
||||
if (sourceFieldMapping instanceof Map) {
|
||||
properties.put(destFieldPath, sourceFieldMapping);
|
||||
}
|
||||
}
|
||||
properties.putAll(config.getAnalysis().getExplicitlyMappedFields(mappingsProperties, config.getDest().getResultsField()));
|
||||
return properties;
|
||||
}
|
||||
|
||||
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
|
||||
return extractValue(String.join("." + PROPERTIES + ".", path.split("\\.")), mappingsProperties);
|
||||
}
|
||||
|
||||
private static Map<String, Object> createMetaData(String analyticsId, Clock clock) {
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
metadata.put(CREATION_DATE_MILLIS, clock.millis());
|
||||
|
|
|
@ -203,7 +203,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase {
|
|||
|
||||
public void testCreateDestinationIndex_Regression() throws IOException {
|
||||
Map<String, Object> map = testCreateDestinationIndex(new Regression(NUMERICAL_FIELD));
|
||||
assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
|
||||
assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("double"));
|
||||
}
|
||||
|
||||
public void testCreateDestinationIndex_Classification() throws IOException {
|
||||
|
@ -319,7 +319,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase {
|
|||
|
||||
public void testUpdateMappingsToDestIndex_Regression() throws IOException {
|
||||
Map<String, Object> map = testUpdateMappingsToDestIndex(new Regression(NUMERICAL_FIELD));
|
||||
assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
|
||||
assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("double"));
|
||||
}
|
||||
|
||||
public void testUpdateMappingsToDestIndex_Classification() throws IOException {
|
||||
|
|
Loading…
Reference in New Issue