[7.x] Do not copy mapping from dependent variable to prediction field in regression analysis (#51227) (#51288)
This commit is contained in:
parent
1009f92b03
commit
bfcfcdee33
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.index.mapper.FieldAliasMapper;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -28,6 +29,7 @@ import java.util.stream.Stream;
|
||||||
|
|
||||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||||
|
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
|
||||||
|
|
||||||
public class Classification implements DataFrameAnalysis {
|
public class Classification implements DataFrameAnalysis {
|
||||||
|
|
||||||
|
@ -248,12 +250,32 @@ public class Classification implements DataFrameAnalysis {
|
||||||
return Collections.singletonMap(dependentVariable, 2L);
|
return Collections.singletonMap(dependentVariable, 2L);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
@Override
|
@Override
|
||||||
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
|
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
||||||
return new HashMap<String, String>() {{
|
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
|
||||||
put(resultsFieldName + "." + predictionFieldName, dependentVariable);
|
if ((dependentVariableMapping instanceof Map) == false) {
|
||||||
put(resultsFieldName + ".top_classes.class_name", dependentVariable);
|
return Collections.emptyMap();
|
||||||
}};
|
}
|
||||||
|
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
|
||||||
|
// If the source field is an alias, fetch the concrete field that the alias points to.
|
||||||
|
if (FieldAliasMapper.CONTENT_TYPE.equals(dependentVariableMappingAsMap.get("type"))) {
|
||||||
|
String path = (String) dependentVariableMappingAsMap.get(FieldAliasMapper.Names.PATH);
|
||||||
|
dependentVariableMapping = extractMapping(path, mappingsProperties);
|
||||||
|
}
|
||||||
|
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
|
||||||
|
// Hence, we need to check the "instanceof" condition again.
|
||||||
|
if ((dependentVariableMapping instanceof Map) == false) {
|
||||||
|
return Collections.emptyMap();
|
||||||
|
}
|
||||||
|
Map<String, Object> additionalProperties = new HashMap<>();
|
||||||
|
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
|
||||||
|
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
|
||||||
|
return additionalProperties;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
|
||||||
|
return extractValue(String.join(".properties.", path.split("\\.")), mappingsProperties);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -42,15 +42,13 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
|
||||||
Map<String, Long> getFieldCardinalityLimits();
|
Map<String, Long> getFieldCardinalityLimits();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns fields for which the mappings should be copied from source index to destination index.
|
* Returns fields for which the mappings should be either predefined or copied from source index to destination index.
|
||||||
* Each entry of the returned {@link Map} is of the form:
|
|
||||||
* key - field path in the destination index
|
|
||||||
* value - field path in the source index from which the mapping should be taken
|
|
||||||
*
|
*
|
||||||
|
* @param mappingsProperties mappings.properties portion of the index mappings
|
||||||
* @param resultsFieldName name of the results field under which all the results are stored
|
* @param resultsFieldName name of the results field under which all the results are stored
|
||||||
* @return {@link Map} containing fields for which the mappings should be copied from source index to destination index
|
* @return {@link Map} containing fields for which the mappings should be handled explicitly
|
||||||
*/
|
*/
|
||||||
Map<String, String> getExplicitlyMappedFields(String resultsFieldName);
|
Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return {@code true} if this analysis supports data frame rows with missing values
|
* @return {@code true} if this analysis supports data frame rows with missing values
|
||||||
|
|
|
@ -230,7 +230,7 @@ public class OutlierDetection implements DataFrameAnalysis {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
|
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
||||||
return Collections.emptyMap();
|
return Collections.emptyMap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -187,8 +187,10 @@ public class Regression implements DataFrameAnalysis {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
|
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
||||||
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, dependentVariable);
|
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
|
||||||
|
// high (over 10M) values of dependent variable.
|
||||||
|
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -25,6 +25,7 @@ import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.allOf;
|
||||||
import static org.hamcrest.Matchers.anEmptyMap;
|
import static org.hamcrest.Matchers.anEmptyMap;
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.empty;
|
import static org.hamcrest.Matchers.empty;
|
||||||
|
@ -171,8 +172,40 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
||||||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
|
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testFieldMappingsToCopyIsNonEmpty() {
|
public void testGetExplicitlyMappedFields() {
|
||||||
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
|
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap()));
|
||||||
|
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap()));
|
||||||
|
assertThat(
|
||||||
|
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
|
||||||
|
is(anEmptyMap()));
|
||||||
|
assertThat(
|
||||||
|
new Classification("foo").getExplicitlyMappedFields(
|
||||||
|
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
|
||||||
|
"results"),
|
||||||
|
allOf(
|
||||||
|
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
|
||||||
|
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
|
||||||
|
assertThat(
|
||||||
|
new Classification("foo").getExplicitlyMappedFields(
|
||||||
|
new HashMap<String, Object>() {{
|
||||||
|
put("foo", new HashMap<String, String>() {{
|
||||||
|
put("type", "alias");
|
||||||
|
put("path", "bar");
|
||||||
|
}});
|
||||||
|
put("bar", Collections.singletonMap("type", "long"));
|
||||||
|
}},
|
||||||
|
"results"),
|
||||||
|
allOf(
|
||||||
|
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
|
||||||
|
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
|
||||||
|
assertThat(
|
||||||
|
new Classification("foo").getExplicitlyMappedFields(
|
||||||
|
Collections.singletonMap("foo", new HashMap<String, String>() {{
|
||||||
|
put("type", "alias");
|
||||||
|
put("path", "missing");
|
||||||
|
}}),
|
||||||
|
"results"),
|
||||||
|
is(anEmptyMap()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
|
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
|
||||||
|
|
|
@ -92,8 +92,8 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
|
||||||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
|
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testFieldMappingsToCopyIsEmpty() {
|
public void testGetExplicitlyMappedFields() {
|
||||||
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(anEmptyMap()));
|
assertThat(createTestInstance().getExplicitlyMappedFields(null, null), is(anEmptyMap()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetStateDocId() {
|
public void testGetStateDocId() {
|
||||||
|
|
|
@ -43,7 +43,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||||
return createRandom();
|
return createRandom();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Regression createRandom() {
|
private static Regression createRandom() {
|
||||||
String dependentVariableName = randomAlphaOfLength(10);
|
String dependentVariableName = randomAlphaOfLength(10);
|
||||||
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
||||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||||
|
@ -110,8 +110,10 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||||
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
|
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testFieldMappingsToCopyIsNonEmpty() {
|
public void testGetExplicitlyMappedFields() {
|
||||||
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
|
assertThat(
|
||||||
|
new Regression("foo").getExplicitlyMappedFields(null, "results"),
|
||||||
|
hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetStateDocId() {
|
public void testGetStateDocId() {
|
||||||
|
|
|
@ -7,8 +7,6 @@ package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
import com.google.common.collect.Ordering;
|
import com.google.common.collect.Ordering;
|
||||||
import org.elasticsearch.ElasticsearchStatusException;
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
|
|
||||||
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
|
|
||||||
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
||||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||||
import org.elasticsearch.action.bulk.BulkResponse;
|
import org.elasticsearch.action.bulk.BulkResponse;
|
||||||
|
@ -42,7 +40,6 @@ import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import static java.util.stream.Collectors.toList;
|
import static java.util.stream.Collectors.toList;
|
||||||
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
|
|
||||||
import static org.hamcrest.Matchers.allOf;
|
import static org.hamcrest.Matchers.allOf;
|
||||||
import static org.hamcrest.Matchers.anyOf;
|
import static org.hamcrest.Matchers.anyOf;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
@ -116,7 +113,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [classification]",
|
"Created analytics with analysis type [classification]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -157,7 +154,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [classification]",
|
"Created analytics with analysis type [classification]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -220,7 +217,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
assertMlResultsFieldMappings(predictedClassField, expectedMappingTypeForPredictedField);
|
assertMlResultsFieldMappings(destIndex, predictedClassField, expectedMappingTypeForPredictedField);
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [classification]",
|
"Created analytics with analysis type [classification]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -308,7 +305,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,7 +362,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||||
assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -384,7 +381,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||||
assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -403,7 +400,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
|
||||||
assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -564,15 +561,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
return destDoc;
|
return destDoc;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Wrapper around extractValue that:
|
|
||||||
* - allows dots (".") in the path elements provided as arguments
|
|
||||||
* - supports implicit casting to the appropriate type
|
|
||||||
*/
|
|
||||||
private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
|
|
||||||
return (T)extractValue(String.join(".", path), doc);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static <T> void assertTopClasses(Map<String, Object> resultsObject,
|
private static <T> void assertTopClasses(Map<String, Object> resultsObject,
|
||||||
int numTopClasses,
|
int numTopClasses,
|
||||||
String dependentVariable,
|
String dependentVariable,
|
||||||
|
@ -656,27 +644,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void assertMlResultsFieldMappings(String predictedClassField, String expectedType) {
|
|
||||||
Map<String, Object> mappings =
|
|
||||||
client()
|
|
||||||
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(destIndex))
|
|
||||||
.actionGet()
|
|
||||||
.mappings()
|
|
||||||
.get(destIndex)
|
|
||||||
.get("_doc")
|
|
||||||
.sourceAsMap();
|
|
||||||
assertThat(
|
|
||||||
mappings.toString(),
|
|
||||||
getFieldValue(
|
|
||||||
mappings,
|
|
||||||
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
|
|
||||||
equalTo(expectedType));
|
|
||||||
assertThat(
|
|
||||||
mappings.toString(),
|
|
||||||
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
|
|
||||||
equalTo(expectedType));
|
|
||||||
}
|
|
||||||
|
|
||||||
private String stateDocId() {
|
private String stateDocId() {
|
||||||
return jobId + "_classification_state#1";
|
return jobId + "_classification_state#1";
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,8 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.ml.integration;
|
package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
|
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
|
||||||
|
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
|
||||||
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
|
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
|
||||||
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
||||||
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
|
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
|
||||||
|
@ -53,6 +55,7 @@ import java.util.Set;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
|
||||||
import static org.hamcrest.Matchers.anyOf;
|
import static org.hamcrest.Matchers.anyOf;
|
||||||
import static org.hamcrest.Matchers.arrayWithSize;
|
import static org.hamcrest.Matchers.arrayWithSize;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
@ -281,4 +284,36 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
|
||||||
.get();
|
.get();
|
||||||
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
|
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected static void assertMlResultsFieldMappings(String index, String predictedClassField, String expectedType) {
|
||||||
|
Map<String, Object> mappings =
|
||||||
|
client()
|
||||||
|
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(index))
|
||||||
|
.actionGet()
|
||||||
|
.mappings()
|
||||||
|
.get(index)
|
||||||
|
.get("_doc")
|
||||||
|
.sourceAsMap();
|
||||||
|
assertThat(
|
||||||
|
mappings.toString(),
|
||||||
|
getFieldValue(
|
||||||
|
mappings,
|
||||||
|
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
|
||||||
|
equalTo(expectedType));
|
||||||
|
if (getFieldValue(mappings, "properties", "ml", "properties", "top_classes") != null) {
|
||||||
|
assertThat(
|
||||||
|
mappings.toString(),
|
||||||
|
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
|
||||||
|
equalTo(expectedType));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wrapper around extractValue that:
|
||||||
|
* - allows dots (".") in the path elements provided as arguments
|
||||||
|
* - supports implicit casting to the appropriate type
|
||||||
|
*/
|
||||||
|
protected static <T> T getFieldValue(Map<String, Object> doc, String... path) {
|
||||||
|
return (T)extractValue(String.join(".", path), doc);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,8 +35,10 @@ import static org.hamcrest.Matchers.is;
|
||||||
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
private static final String NUMERICAL_FEATURE_FIELD = "feature";
|
private static final String NUMERICAL_FEATURE_FIELD = "feature";
|
||||||
|
private static final String DISCRETE_NUMERICAL_FEATURE_FIELD = "discrete-feature";
|
||||||
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
|
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
|
||||||
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
|
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
|
||||||
|
private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(10L, 20L, 30L));
|
||||||
private static final List<Double> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0));
|
private static final List<Double> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0));
|
||||||
|
|
||||||
private String jobId;
|
private String jobId;
|
||||||
|
@ -50,6 +52,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||||
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
||||||
|
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||||
indexData(sourceIndex, 300, 50);
|
indexData(sourceIndex, 300, 50);
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
|
||||||
|
@ -78,19 +81,24 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
// it seems for this case values can be as far off as 2.0
|
// it seems for this case values can be as far off as 2.0
|
||||||
|
|
||||||
// double featureValue = (double) destDoc.get(NUMERICAL_FEATURE_FIELD);
|
// double featureValue = (double) destDoc.get(NUMERICAL_FEATURE_FIELD);
|
||||||
// double predictionValue = (double) resultsObject.get("variable_prediction");
|
// double predictionValue = (double) resultsObject.get(predictedClassField);
|
||||||
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
|
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
|
||||||
|
|
||||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||||
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
||||||
assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true));
|
assertThat(
|
||||||
|
resultsObject.toString(),
|
||||||
|
resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD)
|
||||||
|
|| resultsObject.containsKey("feature_importance." + DISCRETE_NUMERICAL_FEATURE_FIELD),
|
||||||
|
is(true));
|
||||||
}
|
}
|
||||||
|
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [regression]",
|
"Created analytics with analysis type [regression]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -103,6 +111,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
||||||
initialize("regression_only_training_data_and_training_percent_is_100");
|
initialize("regression_only_training_data_and_training_percent_is_100");
|
||||||
|
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||||
indexData(sourceIndex, 350, 0);
|
indexData(sourceIndex, 350, 0);
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||||
|
@ -119,7 +128,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
for (SearchHit hit : sourceData.getHits()) {
|
for (SearchHit hit : sourceData.getHits()) {
|
||||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||||
|
|
||||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||||
assertThat(resultsObject.get("is_training"), is(true));
|
assertThat(resultsObject.get("is_training"), is(true));
|
||||||
}
|
}
|
||||||
|
@ -128,6 +137,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [regression]",
|
"Created analytics with analysis type [regression]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -140,6 +150,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
||||||
initialize("regression_only_training_data_and_training_percent_is_50");
|
initialize("regression_only_training_data_and_training_percent_is_50");
|
||||||
|
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||||
indexData(sourceIndex, 350, 0);
|
indexData(sourceIndex, 350, 0);
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config =
|
DataFrameAnalyticsConfig config =
|
||||||
|
@ -164,7 +175,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
for (SearchHit hit : sourceData.getHits()) {
|
for (SearchHit hit : sourceData.getHits()) {
|
||||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||||
|
|
||||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||||
// Let's just assert there's both training and non-training results
|
// Let's just assert there's both training and non-training results
|
||||||
if ((boolean) resultsObject.get("is_training")) {
|
if ((boolean) resultsObject.get("is_training")) {
|
||||||
|
@ -180,6 +191,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [regression]",
|
"Created analytics with analysis type [regression]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -192,6 +204,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
public void testStopAndRestart() throws Exception {
|
public void testStopAndRestart() throws Exception {
|
||||||
initialize("regression_stop_and_restart");
|
initialize("regression_stop_and_restart");
|
||||||
|
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||||
indexData(sourceIndex, 350, 0);
|
indexData(sourceIndex, 350, 0);
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||||
|
@ -233,7 +246,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
for (SearchHit hit : sourceData.getHits()) {
|
for (SearchHit hit : sourceData.getHits()) {
|
||||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
||||||
|
|
||||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||||
assertThat(resultsObject.get("is_training"), is(true));
|
assertThat(resultsObject.get("is_training"), is(true));
|
||||||
}
|
}
|
||||||
|
@ -242,6 +255,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
|
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
|
||||||
|
@ -289,6 +303,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
public void testDeleteExpiredData_RemovesUnusedState() throws Exception {
|
public void testDeleteExpiredData_RemovesUnusedState() throws Exception {
|
||||||
initialize("regression_delete_expired_data");
|
initialize("regression_delete_expired_data");
|
||||||
|
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||||
indexData(sourceIndex, 100, 0);
|
indexData(sourceIndex, 100, 0);
|
||||||
|
|
||||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
||||||
|
@ -301,6 +316,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(stateDocId());
|
assertModelStatePersisted(stateDocId());
|
||||||
assertInferenceModelPersisted(jobId);
|
assertInferenceModelPersisted(jobId);
|
||||||
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||||
|
|
||||||
// Call _delete_expired_data API and check nothing was deleted
|
// Call _delete_expired_data API and check nothing was deleted
|
||||||
assertThat(deleteExpiredData().isDeleted(), is(true));
|
assertThat(deleteExpiredData().isDeleted(), is(true));
|
||||||
|
@ -319,6 +335,31 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertThat(stateIndexSearchResponse.getHits().getTotalHits().value, equalTo(0L));
|
assertThat(stateIndexSearchResponse.getHits().getTotalHits().value, equalTo(0L));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testDependentVariableIsLong() throws Exception {
|
||||||
|
initialize("regression_dependent_variable_is_long");
|
||||||
|
String predictedClassField = DISCRETE_NUMERICAL_FEATURE_FIELD + "_prediction";
|
||||||
|
indexData(sourceIndex, 100, 0);
|
||||||
|
|
||||||
|
DataFrameAnalyticsConfig config =
|
||||||
|
buildAnalytics(
|
||||||
|
jobId,
|
||||||
|
sourceIndex,
|
||||||
|
destIndex,
|
||||||
|
null,
|
||||||
|
new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null));
|
||||||
|
registerAnalytics(config);
|
||||||
|
putAnalytics(config);
|
||||||
|
|
||||||
|
assertIsStopped(jobId);
|
||||||
|
assertProgress(jobId, 0, 0, 0, 0);
|
||||||
|
|
||||||
|
startAnalytics(jobId);
|
||||||
|
waitUntilAnalyticsIsStopped(jobId);
|
||||||
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
|
|
||||||
|
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
||||||
|
}
|
||||||
|
|
||||||
private void initialize(String jobId) {
|
private void initialize(String jobId) {
|
||||||
this.jobId = jobId;
|
this.jobId = jobId;
|
||||||
this.sourceIndex = jobId + "_source_index";
|
this.sourceIndex = jobId + "_source_index";
|
||||||
|
@ -327,7 +368,10 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
|
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
|
||||||
client().admin().indices().prepareCreate(sourceIndex)
|
client().admin().indices().prepareCreate(sourceIndex)
|
||||||
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
|
.addMapping("_doc",
|
||||||
|
NUMERICAL_FEATURE_FIELD, "type=double",
|
||||||
|
DISCRETE_NUMERICAL_FEATURE_FIELD, "type=long",
|
||||||
|
DEPENDENT_VARIABLE_FIELD, "type=double")
|
||||||
.get();
|
.get();
|
||||||
|
|
||||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||||
|
@ -335,12 +379,15 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
for (int i = 0; i < numTrainingRows; i++) {
|
for (int i = 0; i < numTrainingRows; i++) {
|
||||||
List<Object> source = Arrays.asList(
|
List<Object> source = Arrays.asList(
|
||||||
NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
|
NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
|
||||||
|
DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()),
|
||||||
DEPENDENT_VARIABLE_FIELD, DEPENDENT_VARIABLE_VALUES.get(i % DEPENDENT_VARIABLE_VALUES.size()));
|
DEPENDENT_VARIABLE_FIELD, DEPENDENT_VARIABLE_VALUES.get(i % DEPENDENT_VARIABLE_VALUES.size()));
|
||||||
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
||||||
bulkRequestBuilder.add(indexRequest);
|
bulkRequestBuilder.add(indexRequest);
|
||||||
}
|
}
|
||||||
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
|
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
|
||||||
List<Object> source = Arrays.asList(NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()));
|
List<Object> source = Arrays.asList(
|
||||||
|
NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
|
||||||
|
DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()));
|
||||||
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
||||||
bulkRequestBuilder.add(indexRequest);
|
bulkRequestBuilder.add(indexRequest);
|
||||||
}
|
}
|
||||||
|
@ -363,10 +410,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Object> destDoc) {
|
private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Object> destDoc) {
|
||||||
assertThat(destDoc.containsKey("ml"), is(true));
|
return getFieldValue(destDoc, "ml");
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
|
||||||
return resultsObject;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected String stateDocId() {
|
protected String stateDocId() {
|
||||||
|
|
|
@ -25,7 +25,6 @@ import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.collect.ImmutableOpenMap;
|
import org.elasticsearch.common.collect.ImmutableOpenMap;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.index.IndexSortConfig;
|
import org.elasticsearch.index.IndexSortConfig;
|
||||||
import org.elasticsearch.index.mapper.FieldAliasMapper;
|
|
||||||
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
||||||
import org.elasticsearch.search.sort.SortOrder;
|
import org.elasticsearch.search.sort.SortOrder;
|
||||||
import org.elasticsearch.xpack.core.ClientHelper;
|
import org.elasticsearch.xpack.core.ClientHelper;
|
||||||
|
@ -41,7 +40,6 @@ import java.util.Map;
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
|
|
||||||
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
|
|
||||||
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -163,38 +161,15 @@ public final class DataFrameAnalyticsIndex {
|
||||||
return maxValue;
|
return maxValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
private static Map<String, Object> createAdditionalMappings(DataFrameAnalyticsConfig config, Map<String, Object> mappingsProperties) {
|
private static Map<String, Object> createAdditionalMappings(DataFrameAnalyticsConfig config, Map<String, Object> mappingsProperties) {
|
||||||
Map<String, Object> properties = new HashMap<>();
|
Map<String, Object> properties = new HashMap<>();
|
||||||
Map<String, String> idCopyMapping = new HashMap<>();
|
Map<String, String> idCopyMapping = new HashMap<>();
|
||||||
idCopyMapping.put("type", KeywordFieldMapper.CONTENT_TYPE);
|
idCopyMapping.put("type", KeywordFieldMapper.CONTENT_TYPE);
|
||||||
properties.put(ID_COPY, idCopyMapping);
|
properties.put(ID_COPY, idCopyMapping);
|
||||||
for (Map.Entry<String, String> entry
|
properties.putAll(config.getAnalysis().getExplicitlyMappedFields(mappingsProperties, config.getDest().getResultsField()));
|
||||||
: config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) {
|
|
||||||
String destFieldPath = entry.getKey();
|
|
||||||
String sourceFieldPath = entry.getValue();
|
|
||||||
Object sourceFieldMapping = extractMapping(sourceFieldPath, mappingsProperties);
|
|
||||||
if (sourceFieldMapping instanceof Map) {
|
|
||||||
Map<String, Object> sourceFieldMappingAsMap = (Map) sourceFieldMapping;
|
|
||||||
// If the source field is an alias, fetch the concrete field that the alias points to.
|
|
||||||
if (FieldAliasMapper.CONTENT_TYPE.equals(sourceFieldMappingAsMap.get("type"))) {
|
|
||||||
String path = (String) sourceFieldMappingAsMap.get(FieldAliasMapper.Names.PATH);
|
|
||||||
sourceFieldMapping = extractMapping(path, mappingsProperties);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// We may have updated the value of {@code sourceFieldMapping} in the "if" block above.
|
|
||||||
// Hence, we need to check the "instanceof" condition again.
|
|
||||||
if (sourceFieldMapping instanceof Map) {
|
|
||||||
properties.put(destFieldPath, sourceFieldMapping);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return properties;
|
return properties;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
|
|
||||||
return extractValue(String.join("." + PROPERTIES + ".", path.split("\\.")), mappingsProperties);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Map<String, Object> createMetaData(String analyticsId, Clock clock) {
|
private static Map<String, Object> createMetaData(String analyticsId, Clock clock) {
|
||||||
Map<String, Object> metadata = new HashMap<>();
|
Map<String, Object> metadata = new HashMap<>();
|
||||||
metadata.put(CREATION_DATE_MILLIS, clock.millis());
|
metadata.put(CREATION_DATE_MILLIS, clock.millis());
|
||||||
|
|
|
@ -203,7 +203,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase {
|
||||||
|
|
||||||
public void testCreateDestinationIndex_Regression() throws IOException {
|
public void testCreateDestinationIndex_Regression() throws IOException {
|
||||||
Map<String, Object> map = testCreateDestinationIndex(new Regression(NUMERICAL_FIELD));
|
Map<String, Object> map = testCreateDestinationIndex(new Regression(NUMERICAL_FIELD));
|
||||||
assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
|
assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("double"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testCreateDestinationIndex_Classification() throws IOException {
|
public void testCreateDestinationIndex_Classification() throws IOException {
|
||||||
|
@ -319,7 +319,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase {
|
||||||
|
|
||||||
public void testUpdateMappingsToDestIndex_Regression() throws IOException {
|
public void testUpdateMappingsToDestIndex_Regression() throws IOException {
|
||||||
Map<String, Object> map = testUpdateMappingsToDestIndex(new Regression(NUMERICAL_FIELD));
|
Map<String, Object> map = testUpdateMappingsToDestIndex(new Regression(NUMERICAL_FIELD));
|
||||||
assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
|
assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("double"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testUpdateMappingsToDestIndex_Classification() throws IOException {
|
public void testUpdateMappingsToDestIndex_Classification() throws IOException {
|
||||||
|
|
Loading…
Reference in New Issue